tensor parallel MOE implementation #2293
Conversation
Running into some weird torch.sort issues during CUDA graph capture...
Hi @scv119, thanks for addressing my comments! I haven't actually completed the review yet. Will add more tonight or tomorrow morning.
@WoosukKwon just to let you know, the Triton grouped matmul returns different results from the torch reference implementation for large matrix multiplications, which is likely caused by triton-lang/triton#1190 (comment), but that's purely my speculation. We might need to use https://github.com/imoneoi/cutlass_grouped_gemm if it matters.
1089dd8
grouped_w1_out = grouped_matmul(expanded_hidden_states,
                                cum_experts_range, w1s, "silu")
grouped_w3_out = grouped_matmul(expanded_hidden_states,
                                cum_experts_range, w3s)
Can we merge w1s and w3s just like what we do for LlamaMLP? Merging the two weights will be highly efficient given the cost of grouped GEMM.
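A minimal sketch of what that merge could look like, assuming the grouped_matmul(hidden_states, cum_experts_range, weights, activation) signature and the w1s/w3s tensors from the diff above, with the per-expert output dimension last (the layout is an assumption, not the PR's actual shapes):

import torch
import torch.nn.functional as F

# Hypothetical: concatenate the two expert weight tensors along the per-expert
# output dimension so a single grouped GEMM computes both projections,
# analogous to LlamaMLP's merged gate_up_proj.
w13s = torch.cat([w1s, w3s], dim=-1)

grouped_w13_out = grouped_matmul(expanded_hidden_states, cum_experts_range, w13s)
w1_out, w3_out = grouped_w13_out.chunk(2, dim=-1)

# SwiGLU epilogue, previously handled via the "silu" argument to grouped_matmul.
grouped_w2_in = F.silu(w1_out) * w3_out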
self,
expanded_hidden_states: torch.Tensor,  # [batch_size * top_k_experts, hidden_size]
reverse_indices,  # [batch_size * top_k_experts]
Suggested change:
-    reverse_indices, # [batch_size * top_k_experts]
+    reverse_indices: torch.Tensor, # [batch_size * top_k_experts]
set_weight_attrs(self.w1s, {
    "weight_loader": self.weight_loader,
    "tp_type": "column"
})
set_weight_attrs(self.w2s, {
    "weight_loader": self.weight_loader,
    "tp_type": "row"
})
set_weight_attrs(self.w3s, {
    "weight_loader": self.weight_loader,
    "tp_type": "column"
})
nit: Can we make this compatible with other parallel linear layers by tagging input_dim and output_dim instead of tp_type?
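A rough sketch of that suggestion, assuming the expert weights are laid out as [num_experts, input_size, output_size]; the dim indices are illustrative, not the PR's actual layout:

set_weight_attrs(self.w1s, {
    "weight_loader": self.weight_loader,
    "output_dim": 2,  # column-parallel: shard the per-expert output dimension
})
set_weight_attrs(self.w2s, {
    "weight_loader": self.weight_loader,
    "input_dim": 1,  # row-parallel: shard the per-expert input dimension
})
set_weight_attrs(self.w3s, {
    "weight_loader": self.weight_loader,
    "output_dim": 2,  # column-parallel: shard the per-expert output dimension
})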
set_weight_attrs(self.w1s, {
    "weight_loader": self.weight_loader,
    "tp_type": "column"
})
set_weight_attrs(self.w2s, {
    "weight_loader": self.weight_loader,
    "tp_type": "row"
})
set_weight_attrs(self.w3s, {
    "weight_loader": self.weight_loader,
    "tp_type": "column"
})
nit: Can we make this more similar to other parallel linear layers by tagging input_dim and output_dim instead of tp_type?
expert_params_mapping = [
    # (param_name, weight_name, expert_id)
    (f"{weight_name}s", f"experts.{expert_id}.{weight_name}.weight",
     expert_id) for expert_id in range(self.config.num_local_experts)
    for weight_name in ["w1", "w2", "w3"]
]
Here, do we assume that the expert linear layers don't have bias terms?
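For reference, given the loop order above, the mapping expands to entries like these (first two experts shown, illustrative only):

[("w1s", "experts.0.w1.weight", 0),
 ("w2s", "experts.0.w2.weight", 0),
 ("w3s", "experts.0.w3.weight", 0),
 ("w1s", "experts.1.w1.weight", 1),
 ("w2s", "experts.1.w2.weight", 1),
 ("w3s", "experts.1.w3.weight", 1),
 ...]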
Thanks @scv119 for the updates! The PR looks good to me overall. For the grouped GEMM stuff, I think we can investigate the Cutlass implementation later. I actually spent some time on it last weekend, but found it a bit difficult to understand. For now, I think the Triton kernel is acceptable, and it is actually needed for AMD GPUs anyway.
Any insights into how quantized models will be managed, please? There's a challenge regarding the weights: it may not be possible to concatenate them due to differences between experts. For instance, GPTQ might employ a distinct activation order and AWQ might use varying scales. Thank you.
linear_method=None)

self.w1s = nn.Parameter(
    torch.empty(self.num_total_experts,
If there are many experts, as in DeepSeekMoE, it is easy to OOM in this function. Any ideas to improve memory utilization?
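As a back-of-envelope illustration of the footprint this allocates (assuming fp16 and Mixtral-like shapes; these are not measured numbers):

# Assumed shapes for illustration: 8 experts, hidden_size=4096,
# intermediate_size=14336, tensor-parallel size 8, fp16 (2 bytes per param).
num_experts, hidden, inter, tp = 8, 4096, 14336, 8
bytes_per_param = 2
# Three weight tensors per expert (w1, w2, w3), each sharded over tp ranks.
per_rank = num_experts * 3 * hidden * (inter // tp) * bytes_per_param
print(f"{per_rank / 2**30:.2f} GiB per rank")  # ~0.33 GiB here; it scales
# linearly with the expert count, hence the concern for many-expert models.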
Thanks @WoosukKwon, will do another pass. Also, we noticed some poor performance on H100, so we probably need to tune the kernel parameters a bit.
On H100s, changing the number of SMs to 256 brought the best improvement in terms of throughput for me (but still not quite matching current master). For reference, current master measured 28600 tok/s; all numbers have an error of about +/- 200 tok/s. It is quite possible that by tuning more / autotuning we can get even better results here -- I'd love to learn about it if anybody has better parameters :)
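One way to search these launch parameters instead of hard-coding them is Triton's autotuner; a hedged sketch follows (the block/NUM_SM values mirror Triton's grouped-GEMM tutorial that this kernel appears to follow, and the key argument name is hypothetical):

import triton

# Candidate launch configurations; NUM_SM=256 is the value reported above to
# work best on H100, the others are plausible alternatives to sweep.
configs = [
    triton.Config({
        "BLOCK_SIZE_M": bm,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 32,
        "NUM_SM": num_sm,
    }, num_warps=4, num_stages=3)
    for bm in (64, 128)
    for num_sm in (84, 128, 256)
]

# These would decorate the grouped-GEMM kernel, re-tuning when the problem
# size changes, e.g. (argument name hypothetical):
# @triton.autotune(configs=configs, key=["group_size"])
# @triton.jit
# def grouped_matmul_kernel(...): ...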
I think one overhead of this PR is the many small elementwise operations that are not fused, according to my profile.
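For anyone who wants to reproduce that kind of profile, a minimal sketch with torch.profiler (the forward call is hypothetical, standing in for whatever workload is being measured):

import torch
from torch.profiler import profile, ProfilerActivity

# Capture CPU and CUDA kernel timings around the MoE forward pass; the many
# small unfused elementwise kernels show up as separate rows in the table.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    output = model(input_ids)  # hypothetical: your model / benchmark step here
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))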
This PR implements tensor-parallel MoE by sharding each expert across all ranks.
Concretely, it shards each expert's w1/w3 weights column-wise and w2 row-wise across the tensor-parallel ranks (see the set_weight_attrs tags above) and runs the expert projections through a grouped matmul kernel.
Benchmark results:
A100 80G * 8, input_len=32, output_len=128
baseline:
batch_size 1: 2.1385579633448892 seconds
batch_size 8: 2.428515106982862 seconds
batch_size 32: 2.9776507209753618 seconds
batch_size 64: 3.7744668100300864 seconds
this PR (latency, with percentage of baseline in parentheses):
batch_size 1: 1.6442222506545174 seconds (77%)
batch_size 8: 2.3404843776564426 seconds (96%)
batch_size 32: 3.0149446266586892 seconds (101%)
batch_size 64: 3.878694705994955 seconds (103%)
A100 80G * 4, input_len=32, output_len=128
baseline:
batch_size 1: 2.9904473346929685 seconds
batch_size 8: 3.2857296433260976 seconds
batch_size 32: 3.917926660312029 seconds
batch_size 64: 4.401127053638144 seconds
this PR:
batch_size 1: 1.6416094843492222 seconds (55%)
batch_size 8: 2.9794040496732728 seconds (91%)
batch_size 32: 3.631852053649103 seconds (93%)
batch_size 64: 4.388253151012274 seconds (100%)
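A rough reproduction sketch using vLLM's Python API (the model name and dummy-prompt setup are assumptions; the numbers above come from the latency benchmark script):

from vllm import LLM, SamplingParams

# Assumed Mixtral checkpoint; any MoE model exercising this path would do.
llm = LLM(model="mistralai/Mixtral-8x7B-v0.1", tensor_parallel_size=8)
params = SamplingParams(max_tokens=128, ignore_eos=True)

# batch_size=8 dummy prompts of input_len=32 tokens, mirroring the setup above.
prompt_token_ids = [[0] * 32 for _ in range(8)]
outputs = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=params)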