-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
[ Kernel ] AWQ Fused MoE #6422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
[ Kernel ] AWQ Fused MoE #6422
Changes from all commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
d40fd4d
added files
f1d5836
format
16baf11
stash
03d9d8e
torch library
54d6a87
fixed another torch library
524a94c
first end to end run with tp=1
febb027
loaded but not running at fp16
8bca009
correctness end-to-end!
8527d6e
formatted
36d1d82
updared the weight loading logic
6943e80
stash
71e5129
fixed fp8
703e792
Merge branch 'main' into fused-moe-awq
5b73064
merged
2ef2c92
formatting
db33c3f
better comments
f6f60cd
added
d9def7e
formatted
16eacd0
stash
0674d2f
Merge branch 'main' into fused-moe-awq
dsikka d6a032e
clean-up, fix tests
dsikka 8d52ae5
normalize weights to prevent illegal memory
dsikka c08a5da
all MoE tests working
dsikka 7325e78
revert to reproduce error
dsikka 0538dcc
update to comply with main
dsikka 0ba00ab
PR comments
dsikka 419eb7d
fix tpu forward pass; use kwargs
dsikka 5666fcb
fix triton import
dsikka 8013ad4
further fix imports
dsikka be34dc0
fix
dsikka 6e7bbf9
fix fp8
dsikka File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
"""Fused MoE utilities for AWQ.""" | ||
import torch | ||
|
||
from vllm import _custom_ops as ops | ||
from vllm.logger import init_logger | ||
from vllm.model_executor.layers.fused_moe.fused_moe import ( | ||
fused_experts, moe_align_block_size) | ||
|
||
logger = init_logger(__name__) | ||
|
||
NAIVE_THRESHOLD = 1024 | ||
|
||
|
||
def fused_experts_awq( | ||
hidden_states: torch.Tensor, | ||
w1: torch.Tensor, | ||
w2: torch.Tensor, | ||
w1_scales: torch.Tensor, | ||
w2_scales: torch.Tensor, | ||
w1_qzeros: torch.Tensor, | ||
w2_qzeros: torch.Tensor, | ||
topk_weights: torch.Tensor, | ||
topk_ids: torch.Tensor, | ||
pack_factor: int, | ||
) -> torch.Tensor: | ||
""" | ||
This function computes an AWQ fused_expert. | ||
|
||
Parameters: | ||
- hidden_states (torch.Tensor): The input tensor to the MoE layer. | ||
- w1 (torch.Tensor): The first set of expert weights. | ||
- w2 (torch.Tensor): The second set of expert weights. | ||
- w1_scales (torch.Tensor): scale to be used for w1. | ||
- w2_scales (torch.Tensor): scale to be used for w2. | ||
- w1_qzeros (torch.Tensor): zero point to be used for w1. | ||
- w2_qzeros (torch.Tensor): zero point to be used for w2. | ||
- pack_factor (int): Weight packing factor (int4 in int32 == 8) | ||
|
||
Returns: | ||
- torch.Tensor: The output tensor after applying the MoE layer. | ||
""" | ||
|
||
# If large seq_len prefill, dequantize and use the fp16 MoE kernel. | ||
do_naive_dequant = hidden_states.shape[:-1].numel() >= NAIVE_THRESHOLD | ||
if do_naive_dequant: | ||
# NOTE: not contiguous because of the permutation operation | ||
dequant_w1 = ops.awq_dequantize(w1, w1_scales, w1_qzeros, 0, 0, | ||
0).permute(0, 2, 1).contiguous() | ||
dequant_w2 = ops.awq_dequantize(w2, w2_scales, w2_qzeros, 0, 0, | ||
0).permute(0, 2, 1).contiguous() | ||
|
||
return fused_experts(hidden_states, dequant_w1, dequant_w2, | ||
topk_weights, topk_ids) | ||
|
||
(sorted_token_ids, expert_ids, | ||
num_tokens_post_padded) = moe_align_block_size(topk_ids, 16, w1.shape[0]) | ||
|
||
x = hidden_states.view(hidden_states.shape[0], 1, *hidden_states.shape[1:]) | ||
|
||
gate_up = ops.awq_fused_moe(input=x, | ||
qweight=w1, | ||
scales=w1_scales, | ||
qzeros=w1_qzeros, | ||
topk_weights=topk_weights, | ||
sorted_token_ids=sorted_token_ids, | ||
expert_ids=expert_ids, | ||
num_tokens_post_padded=num_tokens_post_padded, | ||
mul_weights=False, | ||
pack_factor=pack_factor) | ||
|
||
out = torch.empty((gate_up.shape[:-1] + (gate_up.shape[-1] // 2, )), | ||
dtype=hidden_states.dtype, | ||
device=hidden_states.device) | ||
ops.silu_and_mul(out, gate_up) | ||
|
||
out = ops.awq_fused_moe(input=out, | ||
qweight=w2, | ||
scales=w2_scales, | ||
qzeros=w2_qzeros, | ||
topk_weights=topk_weights, | ||
sorted_token_ids=sorted_token_ids, | ||
expert_ids=expert_ids, | ||
num_tokens_post_padded=num_tokens_post_padded, | ||
mul_weights=True, | ||
pack_factor=pack_factor) | ||
|
||
return torch.sum(out, dim=1) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems a bit high and it is worth commenting how it was calibrated (what model, benchmark, GPU used)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@robertgshaw2-neuralmagic do we know why this is 1024 specifically?