enable yuan autotp & add conv tp #5428
@@ -13,6 +13,68 @@
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list

class TensorParallelConv2d(nn.Module):

    def __init__(self, conv, rank, world_size, shard_by_oc):
        super().__init__()
        self.rank = rank
        self.world_size = world_size
        self.shard_by_oc = shard_by_oc
        self.shard_weights(conv)

    # Split along the input/output channel depending on whether it is the last conv layer.
    def shard_weights(self, conv):
[Review comment] Should have some comments to explain the sharding scheme here. Better with a simple example to help understanding.
        if self.shard_by_oc:
            total_size = conv.weight.shape[0]
        else:
            total_size = conv.weight.shape[1]
        bias_data = None
        # Build cumulative channel boundaries so the channels are split as evenly as
        # possible across ranks; any remainder channels go to the higher ranks.
        cols_per_rank = [0]
        for i in range(self.world_size - 1, -1, -1):
            cols = total_size // self.world_size
            if i < total_size % self.world_size:
                cols += 1
            cols_per_rank.append(cols_per_rank[-1] + cols)
        weight_data = conv.weight.data
        if self.shard_by_oc:
            # not last conv layer, split output channel
            weight_data = weight_data[cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
            if conv.bias is not None:
                bias_data = conv.bias.data[cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
        else:
            # last conv layer, split input channel; the bias is divided by world_size
            # so that the later all-reduce adds it exactly once.
            weight_data = weight_data[:, cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
            if conv.bias is not None:
                bias_data = conv.bias.data / float(self.world_size)
        self.conv = nn.Conv2d(weight_data.shape[1], weight_data.shape[0], conv.kernel_size, conv.stride, conv.padding,
                              conv.dilation, conv.groups, conv.bias is not None, conv.padding_mode)
        self.conv.weight = torch.nn.Parameter(weight_data)
        if conv.bias is not None:
            self.conv.bias = torch.nn.Parameter(bias_data)
        del conv

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.conv(input)
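In response to the review comment above, here is a small standalone sketch of the boundary computation that shard_weights performs, with a concrete split of 10 channels over 4 ranks. It is illustrative only and not part of the diff; the helper name channel_boundaries is made up.

# Illustrative only: mirrors the cols_per_rank loop in shard_weights above.
def channel_boundaries(total_size, world_size):
    # Cumulative offsets; the remainder channels end up on the higher ranks.
    cols_per_rank = [0]
    for i in range(world_size - 1, -1, -1):
        cols = total_size // world_size
        if i < total_size % world_size:
            cols += 1
        cols_per_rank.append(cols_per_rank[-1] + cols)
    return cols_per_rank

print(channel_boundaries(10, 4))  # [0, 2, 4, 7, 10]
# Rank r keeps channels cols_per_rank[r]:cols_per_rank[r + 1], so rank 0 gets
# channels 0-1, rank 1 gets 2-3, rank 2 gets 4-6 and rank 3 gets 7-9.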


class TensorParallelOcShardConv2d(TensorParallelConv2d):

    def __init__(self, conv, rank, world_size):
        super().__init__(conv, rank, world_size, True)


class TensorParallelIcShardConv2d(TensorParallelConv2d):

    def __init__(self, conv, rank, world_size):
        super().__init__(conv, rank, world_size, False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = self.conv(input)
        if self.world_size > 1:
            dist.inference_all_reduce(out)
        return out
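The split can be sanity-checked on a single process. The sketch below (illustrative, not part of the PR) shards a conv1/conv2 pair by output and input channels respectively, assuming the channel count divides evenly, and confirms that summing the per-rank partial outputs, which is what dist.inference_all_reduce does across ranks, reproduces the unsharded result; dividing conv2's bias by world_size keeps it from being counted more than once.

import torch
import torch.nn as nn

torch.manual_seed(0)
world_size = 2
conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)   # would be output-channel sharded
conv2 = nn.Conv2d(8, 4, kernel_size=3, padding=1)   # would be input-channel sharded
x = torch.randn(1, 3, 16, 16)
ref = conv2(conv1(x))

partial_sum = 0
oc = 8 // world_size  # channels per rank (even split for simplicity)
for rank in range(world_size):
    sl = slice(rank * oc, (rank + 1) * oc)
    # This rank's shard of conv1: a slice of output channels (weight rows and bias).
    y = nn.functional.conv2d(x, conv1.weight[sl], conv1.bias[sl], padding=1)
    # This rank's shard of conv2: the matching slice of input channels; bias is scaled
    # by 1/world_size so the summed (all-reduced) result counts it exactly once.
    partial_sum = partial_sum + nn.functional.conv2d(y, conv2.weight[:, sl], conv2.bias / world_size, padding=1)

print(torch.allclose(ref, partial_sum, atol=1e-5))  # True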


class LinearAllreduce(nn.Module):

    def __init__(self, weight, bias=None, mp_group=None):

(second file changed by this PR)
@@ -14,6 +14,7 @@
from deepspeed.accelerator import get_accelerator
from .replace_policy import replace_policies, generic_policies
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d

from deepspeed import comm as dist
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads

@@ -340,6 +341,28 @@ def set_lm_head(module):
        module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
    return module

def conv2d_parallel_shard_weights(model, rank, world_size):
    # Conv policy: layers named "conv1" are sharded by output channel, layers named
    # "conv2" by input channel (their forward pass ends with an all-reduce).
    shard_oc_name = ["conv1"]
    shard_ic_name = ["conv2"]
    for name, sub_m in model.named_children():
        for l_name, l_sub_m in sub_m.named_children():
            if l_name in shard_oc_name:
                TPConv2d = TensorParallelOcShardConv2d(
                    l_sub_m,
                    rank,
                    world_size,
                )
                setattr(sub_m, l_name, TPConv2d)
            if l_name in shard_ic_name:
                TPConv2d = TensorParallelIcShardConv2d(
                    l_sub_m,
                    rank,
                    world_size,
                )
                setattr(sub_m, l_name, TPConv2d)
        conv2d_parallel_shard_weights(sub_m, rank, world_size)
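For context, a hypothetical single-process illustration of what this replacement pass does (YuanLikeBlock and ToyModel are made-up names used only for this sketch): the recursion visits every submodule and swaps children named conv1/conv2 for their sharded counterparts.

import torch.nn as nn

# Hypothetical stand-in for a Yuan-style decoder block holding the conv pair.
class YuanLikeBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(16, 32, 3, padding=1)  # replaced by TensorParallelOcShardConv2d
        self.conv2 = nn.Conv2d(32, 16, 3, padding=1)  # replaced by TensorParallelIcShardConv2d

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([YuanLikeBlock() for _ in range(2)])

model = ToyModel()
# On each rank (after AutoTP has sharded the linear layers):
#   conv2d_parallel_shard_weights(model, dist.get_rank(), dist.get_world_size())
# afterwards every block.conv1 / block.conv2 holds only that rank's weight shard,
# and block.conv2's forward performs the all-reduce shown in TensorParallelIcShardConv2d above.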

if checkpoint_dict is not None and not config.replace_with_kernel_inject:
    # AutoTP shard loading
    checkpoint = checkpoint_dict["checkpoints"]

@@ -354,6 +377,10 @@ def set_lm_head(module):
        pbar.update(1)
        gc.collect()
    replaced_module = set_lm_head(replaced_module)
    # conv2d tp module replace
    # Currently applied only to the Yuan model; a model list and conv policy could be
    # added later to decide whether to replace conv layers for other models.
    if 'Yuan' in str(replaced_module):
[Review comment] Does it mean we apply conv sharding only for models we know there is a conv layer?
[Author] Yes. I added the comment to help explain this situation.
        conv2d_parallel_shard_weights(replaced_module, dist.get_rank(), dist.get_world_size())
else:
    replaced_module = replace_module(model=model,
                                     orig_class=orig_layer_impl,
[Review comment] Comments here needed (with an example) to help understand functionality of shard_value_with_share_qk().