enable yuan autotp & add conv tp #5428
Conversation
@@ -123,3 +123,54 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
    warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type,"
                 f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors")
    return _bloom_type_transpose(src, mp_size)


def shard_value_with_share_qk(
Comments are needed here (with an example) to help understand the functionality of shard_value_with_share_qk().
        self.shard_by_oc = shard_by_oc
        self.shard_weights(conv)

    def shard_weights(self, conv):
There should be some comments here to explain the sharding scheme. A simple example would also help understanding.
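For illustration, a minimal sketch of the kind of scheme such a comment could describe (hypothetical helper, not the PR's code; assumes Conv2d weights of shape [out_channels, in_channels, kH, kW]):

```python
import torch

def shard_conv_weight_sketch(weight, mp_size, rank, shard_by_oc):
    # Column-parallel analogue: each rank keeps a slice of the output channels.
    if shard_by_oc:
        return torch.chunk(weight, mp_size, dim=0)[rank].contiguous()
    # Row-parallel analogue: each rank keeps a slice of the input channels;
    # the conv outputs become partial sums that must be all-reduced.
    return torch.chunk(weight, mp_size, dim=1)[rank].contiguous()
```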
@@ -350,6 +372,9 @@ def set_lm_head(module):
    pbar.update(1)
    gc.collect()
    replaced_module = set_lm_head(replaced_module)
    # conv2d tp module replace
    if 'Yuan' in str(replaced_module):
Does this mean we apply conv sharding only for models we know contain conv layers?
Yes. I added a comment to help explain this situation.
Hi @tjruwase, we received a request to support Yuan model AutoTP (https://huggingface.co/IEITYuan/Yuan2-102B-hf). This model has a special QKV format and also has convolution layers, both of which need special treatment in tensor parallelism. This PR addresses both model features and supports them inside DeepSpeed AutoTP. Can this PR be reviewed? Thanks!
Hi @delock - FYI, could you resolve the merge conflicts on this PR so it can be reviewed and tests can run?
This PR aims to enable Yuan model AutoTP and add conv TP.
The Yuan model uses shared QK.
For example:
q_linear_out = [q1, q2, q3, q4, q5, ... , q16]
k_linear_out = [k1, k2, k3, k4, k5, ... , k16]
after shared QK:
TP=1:
q' = [q1,q2,q3,q4, q9,q10,q11,q12, k1,k2,k3,k4, k9,k10,k11,k12]
k' = [q5,q6,q7,q8, q13,q14,q15,q16, k5,k6,k7,k8, k13,k14,k15,k16]
v' = [v1,v2,v3,v4, v5,v6,v7,v8, v9,v10,v11,v12, v13,v14,v15,v16]
TP=2:
rank0:
q'_0 = [q1,q2,q3,q4, k1,k2,k3,k4]
k'_0 = [q5,q6,q7,q8, k5,k6,k7,k8]
v'_0 = [v1,v2,v3,v4, v5,v6,v7,v8] -> v'_0 is wrong! The expected value is: [v1,v2,v3,v4, v9,v10,v11,v12]
rank1:
q'_1 = [q9,q10,q11,q12, k9,k10,k11,k12]
k'_1 = [q13,q14,q15,q16, k13,k14,k15,k16]
v'_1 = [v9,v10,v11,v12, v13,v14,v15,v16] -> v'_1 is wrong! The expected value is: [v5,v6,v7,v8, v13,v14,v15,v16]
To avoid modifying the modeling code, we adjust the value and o_proj weights to fit this QK layout.
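A minimal sketch of that value adjustment (hypothetical helper name and weight layout, not the PR's exact code): split the value projection's output rows into 2 * mp_size head blocks and give rank r blocks r and r + mp_size, which reproduces the expected shards above.

```python
import torch

def shard_value_with_share_qk_sketch(v_weight, mp_size, rank):
    # v_weight: [out_features, in_features]; output rows hold the value heads.
    blocks = torch.chunk(v_weight, 2 * mp_size, dim=0)
    # Rank r keeps blocks r and r + mp_size. For TP=2 this gives
    # rank0 -> [v1..v4, v9..v12] and rank1 -> [v5..v8, v13..v16],
    # and for TP=1 it degenerates to the full, unpermuted weight.
    return torch.cat([blocks[rank], blocks[rank + mp_size]], dim=0)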
We also added conv TP to support models that include heavy convolution computation. It is similar to the linear TP policy, as shown in the sketch after the fragment below.
if not last_conv_layer:
    ...  # shard the conv weight along output channels (column-parallel analogue)
else:
    ...  # shard along input channels; the output then needs an all-reduce
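For context, a hedged end-to-end sketch of this conv TP pattern (hypothetical function names, not the PR's code), assuming intermediate conv layers are sharded by output channels and the last conv layer by input channels, so that a single all-reduce restores the full activation, mirroring the column-then-row linear TP pattern:

```python
import torch
import torch.nn.functional as F

def conv_tp_forward_sketch(x, w1_shard, w2_shard, mp_group=None):
    # w1_shard: [oc / mp_size, ic, kh, kw] -- output channels sharded,
    # so each rank produces its own slice of the hidden channels.
    h = F.conv2d(x, w1_shard, padding=1)
    # w2_shard: [oc2, oc / mp_size, kh, kw] -- input channels sharded,
    # so each rank computes a partial sum over its channel slice.
    y = F.conv2d(h, w2_shard, padding=1)
    if torch.distributed.is_initialized():
        torch.distributed.all_reduce(y, group=mp_group)  # sum partial outputs
    return y
```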