Reverse-Engineering Tensor Splitting from the Bottom Up
The original motivation was to control how tensor parallelism splits tensors. This article traces how a tensor actually gets split and distributed across two GPUs, using the OPT model as the example.
Copying the tensor
The split ultimately happens in the strided_copy method of the ReplaceWithTensorSlicing class (deepspeed/module_inject/auto_tp.py). This class is instantiated inside replace_transformer_layer (deepspeed/module_inject/replace_module.py). The code is shown below. As it shows, the copy decides how to split based on dst.shape and src.shape: if the outermost dimension of dst.shape is half that of src.shape, the tensor is split in half.
    def strided_copy(self,
                     dst: Optional[torch.Tensor],
                     src: Optional[torch.Tensor],
                     num_splits: int,
                     int8: bool = False,
                     allocate_tensor: bool = False):
        if src is None:
            return src
        src_shape = src.shape
        dst_shape = dst.shape
        outer_dim = 0 if int8 else -1
        if allocate_tensor:
            dst = torch.empty_like(dst)

        src_split = torch.split(src.data, src.shape[outer_dim] // num_splits, dim=outer_dim)
        if (len(src_shape) == 2 and len(dst_shape) == 2):
            if src_shape[outer_dim] == dst_shape[self.out_dim]:
                try:
                    dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
                except:
                    print(dst.shape, src.shape)
                    exit()
                dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
                if hasattr(src, 'scale'):
                    dst.scale = src.scale
                return dst
            self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
            qkv_size = dst_shape[self.out_dim] // num_splits
            qkv_split = [torch.split(src_s, qkv_size, dim=outer_dim) for src_s in src_split]
            weight_split = [
                torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=outer_dim) for i in range(len(qkv_split[0]))
            ]
            dst = dst.reshape(-1).data.copy_(weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
                weight_split[self.gpu_index].shape)
        else:
            if src_shape[0] == dst_shape[0]:
                return torch.nn.parameter.Parameter(src)
            qkv_size = dst_shape[0] // num_splits
            qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split]
            bias_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], axis=0) for i in range(len(qkv_split[0]))]
            dst.data.copy_(bias_split[self.gpu_index].contiguous())

        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
        if hasattr(src, 'scale'):
            dst.scale = src.scale
        return dst
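The strided split above (src_split → qkv_split → per-GPU recombination) can be reproduced with plain Python lists. The following is a toy sketch, not DeepSpeed code: it assumes a fused 1-D QKV bias of length 12 (hidden size 4), num_splits=3 and two GPUs, and mirrors the else branch of strided_copy.

```python
num_splits, mp_size = 3, 2
q = [0, 1, 2, 3]
k = [10, 11, 12, 13]
v = [20, 21, 22, 23]
src = q + k + v                       # fused QKV, length 12

dst_size = len(src) // mp_size        # 6 elements per GPU
qkv_size = dst_size // num_splits     # 2 elements of each of q/k/v per GPU

# Step 1: undo the fusion -> [q, k, v]
# (torch.split(src, len(src) // num_splits) in the real code)
piece = len(src) // num_splits
src_split = [src[i:i + piece] for i in range(0, len(src), piece)]

# Step 2: split each of q/k/v into mp_size chunks of qkv_size
qkv_split = [[s[i:i + qkv_size] for i in range(0, len(s), qkv_size)]
             for s in src_split]

# Step 3: shard i = i-th chunk of q + i-th chunk of k + i-th chunk of v
shards = [sum((chunks[i] for chunks in qkv_split), [])
          for i in range(len(qkv_split[0]))]

print(shards[0])  # GPU 0 gets [0, 1, 10, 11, 20, 21]
print(shards[1])  # GPU 1 gets [2, 3, 12, 13, 22, 23]
```

Note that each GPU ends up with the first (or second) half of q, of k, and of v, rather than a contiguous slice of the fused tensor; this is exactly why the copy is "strided".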
Determining the tensor shapes
The function that calls strided_copy above is shown below (deepspeed/module_inject/containers/base.py). Note that dst comes from self.module, while src comes directly from self (here self is the Container).
    def attention_qkv_mp(self, mp_replace, reversed_dim=False):
        self.module.attention.attn_qkvw = mp_replace.strided_copy(self.module.attention.attn_qkvw,
                                                                  self.qkvw,
                                                                  num_splits=3,
                                                                  int8=reversed_dim)
        self.module.attention.attn_qkvb = mp_replace.strided_copy(self.module.attention.attn_qkvb,
                                                                  self.qkvb,
                                                                  num_splits=3,
                                                                  int8=reversed_dim)
self.qkvw ultimately comes from self.policy.client_module (see deepspeed/module_inject/containers/opt.py), i.e. the model pulled down from Hugging Face.
    def attention(self, enable_training=False):
        qw = self.client_module.self_attn.q_proj.weight
        qb = self.client_module.self_attn.q_proj.bias
        kw = self.client_module.self_attn.k_proj.weight
        kb = self.client_module.self_attn.k_proj.bias
        vw = self.client_module.self_attn.v_proj.weight
        vb = self.client_module.self_attn.v_proj.bias

        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=enable_training)
        qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=enable_training)
        return qkvw, \
               qkvb, \
               self.client_module.self_attn.out_proj.weight, \
               self.client_module.self_attn.out_proj.bias
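As a quick shape walk-through of the concatenation above, here is a hypothetical sketch assuming OPT-125m-like dimensions (hidden_size = 768, so each of the q/k/v projection weights is 768×768); only the shape arithmetic is shown, not real tensors.

```python
hidden = 768
qw_shape = kw_shape = vw_shape = (hidden, hidden)

# torch.cat((qw, kw, vw), dim=0) stacks along rows, so the fused weight is:
qkvw_shape = (qw_shape[0] + kw_shape[0] + vw_shape[0], hidden)
print(qkvw_shape)   # (2304, 768)

# and the fused bias torch.cat((qb, kb, vb), dim=0) has 3 * hidden elements:
qkvb_shape = (3 * hidden,)
print(qkvb_shape)   # (2304,)
```

This fused (3 * hidden, hidden) src is what strided_copy later splits against the per-partition dst shape.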
self.module is a DeepSpeedOPTInference, which is essentially a DeepSpeedTransformerInference. self.module.attention is a DeepSpeedSelfAttention (deepspeed/ops/transformer/inference/ds_attention.py), and the size of its attn_qkvw is determined by qkv_size_per_partition.
self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
                                          qkv_size_per_partition,
                                          dtype=data_type,
                                          device=device),
                              requires_grad=False)
self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition,
                                          dtype=data_type_fp,
                                          device=device),
                              requires_grad=False)
qkv_size_per_partition is computed when DeepSpeedSelfAttention is initialized (deepspeed/ops/transformer/inference/ds_attention.py).
qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 if config.num_kv < 0 else \
    ((self.config.heads + self.config.num_kv * 2) // self.config.mp_size) * (self.config.hidden_size // self.config.heads)

self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
                                          qkv_size_per_partition,
                                          dtype=data_type,
                                          device=device),
                              requires_grad=False)
self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition,
                                          dtype=data_type_fp,
                                          device=device),
                              requires_grad=False)
out_size_per_partition = self.config.hidden_size // self.config.mp_size
self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
                                        self.config.hidden_size,
                                        dtype=data_type,
                                        device=device),
                            requires_grad=False)
self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size,
                                        dtype=data_type_fp,
                                        device=device),
                            requires_grad=False)
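The per-partition arithmetic above can be checked by hand. The following is a minimal sketch assuming OPT-125m-like numbers (hidden_size = 768, 12 heads, mp_size = 2); it mirrors both branches of the qkv_size_per_partition expression, where the num_kv >= 0 branch corresponds to grouped-query attention.

```python
hidden_size, heads, mp_size = 768, 12, 2

# Standard multi-head attention (num_kv < 0): each rank holds 1/mp_size
# of the fused QKV output dimension.
num_kv = -1
if num_kv < 0:
    qkv_size_per_partition = (hidden_size // mp_size) * 3
else:
    # Grouped-query attention: heads query heads plus 2 * num_kv kv heads,
    # divided across ranks, times the per-head dimension.
    qkv_size_per_partition = ((heads + num_kv * 2) // mp_size) * (hidden_size // heads)

print(qkv_size_per_partition)  # 1152 = (768 // 2) * 3

out_size_per_partition = hidden_size // mp_size
print(out_size_per_partition)  # 384
```

So on two GPUs, dst attn_qkvw is allocated as (768, 1152) while the fused src qkvw has 3 * 768 = 2304 output features in total, i.e. each rank's dst holds exactly half of each of q, k and v, which is what drives the halving behavior in strided_copy.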