enable yuan autotp & add conv tp #5428

Merged
11 commits merged on Jun 18, 2024
26 changes: 19 additions & 7 deletions deepspeed/module_inject/auto_tp.py
@@ -13,7 +13,7 @@
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_chunk_mlp
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


@@ -134,7 +134,7 @@ def is_load_module(module):
load_layer_names = [
"LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
"MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
"Phi3RMSNorm"
"Phi3RMSNorm", "YuanRMSNorm", "YuanRotaryEmbedding"
]
return module.__class__ in load_layers or module._get_name() in load_layer_names

@@ -331,6 +331,16 @@ def _replace(self, child, name, conv_linear_layer):
# For mixtral-7x8b, need to skip MoE gate linear replace.
if name == "block_sparse_moe.gate":
return child
# For Yuan model
if 'Yuan' in str(self.module):
if 'v_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), True)
return LinearLayer(weight=weight, bias=bias)
elif 'o_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), False)
return LinearAllreduce(weight, bias, self.mp_group)
# for phi3.
if 'gate_up_proj' in name:
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
@@ -412,11 +422,13 @@ def _slice_embedding(self, child, name, conv_linear_layer):
def update_mp_params(self, child):
if getattr(child, "replaced", False) == True:
return
for param in [
"n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads",
"all_head_size", "embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads", "kv_n_heads",
"d_model"
]:
param_list = [
"n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads", "all_head_size",
"embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads", "kv_n_heads", "d_model"
]
for param in param_list:
if "Yuan" in str(child) and 'embed_dim' in param_list:
param_list.remove('embed_dim')
if hasattr(child, param):
param_val = getattr(child, param)
setattr(child, param, get_shard_size(param_val, self.mp_size))
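A note on the Yuan-specific replacement above: v_proj becomes a LinearLayer because its output channels are sharded and each rank can apply its slice independently, while o_proj becomes a LinearAllreduce because its input channels are sharded and each rank only produces a partial sum. The following single-process sketch is added for illustration only (not part of the PR); the real Yuan path selects particular heads via shard_value_with_share_qk, but the identity it relies on holds for any matching output/input split.

import torch

# Illustrative sketch: column-shard v_proj (dim 0) and row-shard o_proj (dim 1)
# across two simulated ranks; summing the per-rank partial outputs (the
# all-reduce in real tensor parallelism) matches the unsharded result.
hidden, world_size = 8, 2
x = torch.randn(1, hidden)
v_w = torch.randn(hidden, hidden)  # v_proj weight, shape (out_features, in_features)
o_w = torch.randn(hidden, hidden)  # o_proj weight

reference = (x @ v_w.t()) @ o_w.t()

partial_sum = torch.zeros_like(reference)
for rank in range(world_size):
    rows = slice(rank * hidden // world_size, (rank + 1) * hidden // world_size)
    v_shard = v_w[rows, :]   # v_proj shard: subset of output channels
    o_shard = o_w[:, rows]   # o_proj shard: matching subset of input channels
    partial_sum += (x @ v_shard.t()) @ o_shard.t()

assert torch.allclose(partial_sum, reference, atol=1e-4)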
55 changes: 55 additions & 0 deletions deepspeed/module_inject/fusedqkv_utils.py
@@ -142,6 +142,61 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
return _bloom_type_transpose(src, mp_size)


# For share qk type:
# q = [q1,...,q_{n/4}, q_{n/2+1},...,q_{3n/4}, k1,...,k_{n/4}, k_{n/2+1},...,k_{3n/4}]
# k = [q_{n/4+1},...,q_{n/2}, q_{3n/4+1},...,qn, k_{n/4+1},...,k_{n/2}, k{3n/4+1},...,kn]
# To avoid modifying the modeling code, we adjust the value and o_proj weights to fit this qk layout.
def shard_value_with_share_qk(
Collaborator: Comments here needed (with an example) to help understand functionality of shard_value_with_share_qk()
weight,
bias,
rank,
world_size,
shard_value=True # True -> shard_value; False -> shard_oproj
):
if shard_value:
total_size = weight.shape[0]
weight_cat_dim = 0
else:
total_size = weight.shape[1]
weight_cat_dim = 1
num_heads = get_num_kv_heads()
head_dim = total_size // num_heads
assert (num_heads % world_size == 0)
if world_size > num_heads // 2:
RuntimeError(f"world_size {world_size} is larger than half of num_heads {num_heads}")
head_per_rank = num_heads // world_size
q_head_start = rank * head_per_rank
# mapping q_head to v_head
v_head_ids = []
i = 0
# mapping neighbor q_head to v_head
while i < head_per_rank:
v_head_ids.append(q_head_start // 2)
q_head_start += 2
i = i + 2

# mapping neighbor k_head to v_head
v_head_ids.extend([i + num_heads // 2 for i in v_head_ids])
sharded_weight = []
sharded_bias = []
for head_id in v_head_ids:
if shard_value:
sharded_weight.append(weight[head_id * head_dim:(head_id + 1) * head_dim])
if bias is not None:
sharded_bias.append(bias.data[head_id * head_dim:(head_id + 1) * head_dim])
else:
sharded_weight.append(weight[:, head_id * head_dim:(head_id + 1) * head_dim])
sharded_weight = torch.cat(sharded_weight, dim=weight_cat_dim)
if bias is not None:
if shard_value:
sharded_bias = torch.cat(sharded_bias, dim=0)
        else:
            # o_proj feeds an all-reduce, so keep 1/world_size of the bias on each rank
            sharded_bias = bias / float(world_size)
return torch.nn.Parameter(sharded_weight), torch.nn.Parameter(sharded_bias)
else:
return torch.nn.Parameter(sharded_weight), None
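
To make the head mapping above concrete, here is a toy trace added for illustration only (not part of the PR), assuming 8 kv heads split across 2 ranks: each rank keeps every second q-owned head id and then mirrors the same ids into the k half of the shared layout.

num_heads, world_size = 8, 2

for rank in range(world_size):
    head_per_rank = num_heads // world_size
    q_head_start = rank * head_per_rank
    v_head_ids, i = [], 0
    while i < head_per_rank:               # every second q head owned by this rank
        v_head_ids.append(q_head_start // 2)
        q_head_start += 2
        i += 2
    v_head_ids.extend([h + num_heads // 2 for h in v_head_ids])  # mirror into k half
    print(f"rank {rank} keeps value heads {v_head_ids}")
# rank 0 keeps value heads [0, 1, 4, 5]
# rank 1 keeps value heads [2, 3, 6, 7]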


# For phi3 with chunk mlp, adjust the weight order.
def shard_chunk_mlp(
weight,
62 changes: 62 additions & 0 deletions deepspeed/module_inject/layers.py
@@ -13,6 +13,68 @@
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


class TensorParallelConv2d(nn.Module):

def __init__(self, conv, rank, world_size, shard_by_oc):
super().__init__()
self.rank = rank
self.world_size = world_size
self.shard_by_oc = shard_by_oc
self.shard_weights(conv)

# Split along the input/output channel depending on whether it is the last conv layer.
def shard_weights(self, conv):
Collaborator: Should have some comments to explain the sharding scheme here. Better with a simple example to help understanding.
if self.shard_by_oc:
total_size = conv.weight.shape[0]
else:
total_size = conv.weight.shape[1]
bias_data = None
cols_per_rank = [0]
for i in range(self.world_size - 1, -1, -1):
cols = total_size // self.world_size
if i < total_size % self.world_size:
cols += 1
cols_per_rank.append(cols_per_rank[-1] + cols)
weight_data = conv.weight.data
if self.shard_by_oc:
# not last conv layer, split output channel
weight_data = weight_data[cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
if conv.bias is not None:
bias_data = conv.bias.data[cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
else:
# last conv layer, split input channel
weight_data = weight_data[:, cols_per_rank[self.rank]:cols_per_rank[self.rank + 1]]
if conv.bias is not None:
bias_data = conv.bias.data / float(self.world_size)
self.conv = nn.Conv2d(weight_data.shape[1], weight_data.shape[0], conv.kernel_size, conv.stride, conv.padding,
conv.dilation, conv.groups, conv.bias is not None, conv.padding_mode)
self.conv.weight = torch.nn.Parameter(weight_data)
if conv.bias is not None:
self.conv.bias = torch.nn.Parameter(bias_data)
del conv

def forward(self, input: torch.Tensor) -> torch.Tensor:
return self.conv(input)
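
For the cols_per_rank bookkeeping above, a small illustration (not part of the PR), assuming 10 channels on 3 ranks: each rank gets total_size // world_size channels and the remainder is handed to the highest-numbered ranks, so the boundaries come out as [0, 3, 6, 10].

total_size, world_size = 10, 3
cols_per_rank = [0]
for i in range(world_size - 1, -1, -1):
    cols = total_size // world_size
    if i < total_size % world_size:  # remainder goes to the last ranks
        cols += 1
    cols_per_rank.append(cols_per_rank[-1] + cols)
print(cols_per_rank)  # [0, 3, 6, 10] -> shards of 3, 3 and 4 channels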


class TensorParallelOcShardConv2d(TensorParallelConv2d):

def __init__(self, conv, rank, world_size):
super().__init__(conv, rank, world_size, True)


class TensorParallelIcShardConv2d(TensorParallelConv2d):

def __init__(self, conv, rank, world_size):
super().__init__(conv, rank, world_size, False)

def forward(self, input: torch.Tensor) -> torch.Tensor:
out = self.conv(input)
if self.world_size > 1:
dist.inference_all_reduce(out)
return out


class LinearAllreduce(nn.Module):

def __init__(self, weight, bias=None, mp_group=None):
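The sharding scheme behind these classes, in a single-process sketch added for illustration only (not part of the PR; 1x1 convs, biases omitted, two simulated ranks): conv1 is split along output channels so each rank computes a slice of the intermediate activation, conv2 is split along input channels so each rank produces a partial output, and summing the partials, which is what inference_all_reduce does in TensorParallelIcShardConv2d.forward, reproduces the unsharded stack. In the PR itself, the input-channel-sharded conv divides its bias by world_size so that the all-reduce restores it.

import torch
import torch.nn.functional as F

world_size, channels = 2, 4
x = torch.randn(1, channels, 5, 5)
w1 = torch.randn(channels, channels, 1, 1)  # conv1 weight: (out_c, in_c, kH, kW)
w2 = torch.randn(channels, channels, 1, 1)  # conv2 weight

reference = F.conv2d(F.conv2d(x, w1), w2)

out = torch.zeros_like(reference)
for rank in range(world_size):
    oc = slice(rank * channels // world_size, (rank + 1) * channels // world_size)
    y = F.conv2d(x, w1[oc])        # conv1 shard: subset of output channels
    out += F.conv2d(y, w2[:, oc])  # conv2 shard: matching subset of input channels

assert torch.allclose(out, reference, atol=1e-4)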
27 changes: 27 additions & 0 deletions deepspeed/module_inject/replace_module.py
@@ -14,6 +14,7 @@
from deepspeed.accelerator import get_accelerator
from .replace_policy import replace_policies, generic_policies
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d

from deepspeed import comm as dist
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads
@@ -340,6 +341,28 @@ def set_lm_head(module):
module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
return module

def conv2d_parallel_shard_weights(model, rank, world_size):
# add conv policy
shard_oc_name = ["conv1"]
shard_ic_name = ["conv2"]
for name, sub_m in model.named_children():
for l_name, l_sub_m in sub_m.named_children():
if l_name in shard_oc_name:
TPConv2d = TensorParallelOcShardConv2d(
l_sub_m,
rank,
world_size,
)
setattr(sub_m, l_name, TPConv2d)
if l_name in shard_ic_name:
TPConv2d = TensorParallelIcShardConv2d(
l_sub_m,
rank,
world_size,
)
setattr(sub_m, l_name, TPConv2d)
conv2d_parallel_shard_weights(sub_m, rank, world_size)

if checkpoint_dict is not None and not config.replace_with_kernel_inject:
# AutoTP shard loading
checkpoint = checkpoint_dict["checkpoints"]
@@ -354,6 +377,10 @@ def set_lm_head(module):
pbar.update(1)
gc.collect()
replaced_module = set_lm_head(replaced_module)
        # conv2d TP module replacement
        # Currently only for the Yuan model. Add a model list and conv policy to decide whether to replace conv layers.
if 'Yuan' in str(replaced_module):
Collaborator: Does it mean we apply conv sharding only for models we know there is conv layer?
Contributor (author): Yes. I added the comment for helping to understand this situation~

conv2d_parallel_shard_weights(replaced_module, dist.get_rank(), dist.get_world_size())
else:
replaced_module = replace_module(model=model,
orig_class=orig_layer_impl,
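Finally, a toy stand-in for the recursive traversal performed by conv2d_parallel_shard_weights(), added for illustration only; _Wrapped and shard_convs are hypothetical placeholders, not DeepSpeed classes. The walk visits named_children, replaces children named conv1 or conv2 in place with setattr, and recurses into every submodule.

import torch.nn as nn

class _Wrapped(nn.Module):
    def __init__(self, conv, shard_by_oc):
        super().__init__()
        self.conv, self.shard_by_oc = conv, shard_by_oc

    def forward(self, x):
        return self.conv(x)

def shard_convs(model):
    for _, sub_m in model.named_children():
        for l_name, l_sub_m in sub_m.named_children():
            if l_name == "conv1":   # output-channel shard in the real code
                setattr(sub_m, l_name, _Wrapped(l_sub_m, shard_by_oc=True))
            if l_name == "conv2":   # input-channel shard + all-reduce in the real code
                setattr(sub_m, l_name, _Wrapped(l_sub_m, shard_by_oc=False))
        shard_convs(sub_m)

block = nn.Module()
block.conv1 = nn.Conv2d(4, 8, kernel_size=1)
block.conv2 = nn.Conv2d(8, 4, kernel_size=1)
model = nn.Sequential(block)
shard_convs(model)
print(type(model[0].conv1).__name__, type(model[0].conv2).__name__)  # _Wrapped _Wrapped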