Optimized fused MoE Kernel, take 2 #2979
Merged: WoosukKwon merged 30 commits from pcmoritz:tune-fused-moe-kernel into vllm-project:main on Feb 26, 2024.
30 commits (all by pcmoritz):
- 86035f8 Autotuned fused_moe kernel
- 8b991f5 update
- 09b89d0 logging
- d29e87b fix
- d61bb01 update
- a4050b5 update
- 070d825 update
- 145014b log once
- e60033d whitespace
- f15f4b9 update
- 0b891f9 update
- bdbe42a update
- 93c5486 lint
- 2495c6c Merge branch 'main' into tune-fused-moe-kernel
- b03e039 update
- 7a07033 fixes
- 3fc61e5 whitespace
- 8afd132 update
- 5802ebe update and simplify
- 2b56ec0 fix
- dbad1fb fix bug
- 389052d lint
- 35d4c9e add benchmarking script
- f986b70 lint
- 1ef5078 update
- 62af490 more batchsizes
- 5715887 Update vllm/model_executor/layers/fused_moe/fused_moe.py
- 183fa4d make override clearer
- 816a1e3 yapf
- c5040c3 yapf 2
New file: benchmark_mixtral_moe.py (+172 lines):
import json
import os
import sys

# Pin the benchmark to a single GPU; must be set before CUDA is initialized.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from vllm.model_executor.layers.fused_moe import fused_moe
import torch
import torch.nn.functional as F
import triton


def main():
    method = fused_moe
    for bs in [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
    ]:
        run_grid(bs, method=method)


def run_grid(bs, method):
    # Mixtral-8x7B shapes under tensor parallelism 2.
    d_model = 4096
    num_total_experts = 8
    top_k = 2
    tp_size = 2
    model_intermediate_size = 14336
    num_layers = 32
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    configs = []
    # Small batches cannot fill large M-blocks, so restrict BLOCK_SIZE_M.
    if bs <= 16:
        BLOCK_SIZES_M = [16]
    elif bs <= 32:
        BLOCK_SIZES_M = [16, 32]
    elif bs <= 64:
        BLOCK_SIZES_M = [16, 32, 64]
    elif bs <= 128:
        BLOCK_SIZES_M = [16, 32, 64, 128]
    else:
        BLOCK_SIZES_M = [16, 32, 64, 128, 256]

    for block_size_n in [32, 64, 128, 256]:
        for block_size_m in BLOCK_SIZES_M:
            for block_size_k in [64, 128, 256]:
                for group_size_m in [1, 16, 32, 64]:
                    for num_warps in [4, 8]:
                        configs.append({
                            "BLOCK_SIZE_M": block_size_m,
                            "BLOCK_SIZE_N": block_size_n,
                            "BLOCK_SIZE_K": block_size_k,
                            "GROUP_SIZE_M": group_size_m,
                            "num_warps": num_warps,
                            "num_stages": 4,
                        })

    best_config = None
    best_time_us = 1e20

    for config in configs:
        print(f'{tp_size=} {bs=}')
        print(f'{config}')

        # Warmup; skip configs that exceed the GPU's resources.
        print('warming up')
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=config,
                )
        except triton.runtime.autotuner.OutOfResources:
            continue

        # Timed trials.
        print('benchmarking')
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=config,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

            print(
                f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f} {bs=} {tp_size=} {top_k=} {num_total_experts=} {d_model=} {model_intermediate_size=} {num_layers=}'
            )

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    # Append the winning config for this batch size as one JSON line.
    filename = "/tmp/config.jsonl"
    print(f"writing config to file {filename}")
    with open(filename, "a") as f:
        f.write(json.dumps({str(bs): best_config}) + "\n")


def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config) -> float:
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.bfloat16,
    )

    # w1 holds the fused gate and up projections for each expert.
    ws = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    # w2 holds the down projection for each expert.
    w2s = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    # One router output per call so each iteration routes independently.
    gating_output = F.softmax(torch.rand(
        (num_calls, bs, num_total_experts),
        device=hidden_states.device,
        dtype=torch.float32,
    ),
                              dim=-1)

    # Time num_calls back-to-back kernel invocations with CUDA events.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=ws,
            w2=w2s,
            gating_output=gating_output[i],
            topk=top_k,
            renormalize=True,
            inplace=True,
            override_config=config,
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())
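The script appends one JSON line per batch size to /tmp/config.jsonl. To turn that output into a single config file of the shape used in the configs/ directory below, the lines could be merged in a short post-processing step (a hypothetical sketch, not part of this PR; the output filename follows the naming convention described in the configs README):

    import json

    merged = {}
    with open("/tmp/config.jsonl") as f:
        for line in f:
            # Each line maps one batch size (as a string) to its best config.
            merged.update(json.loads(line))

    with open("E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json", "w") as f:
        json.dump(merged, f, indent=4)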
New file: vllm/model_executor/layers/fused_moe/__init__.py (+5 lines):
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

__all__ = [
    "fused_moe",
]
New file: vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json (+20 lines):
{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
}
New file: vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json (+24 lines):
{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 4},
    "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
    "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "80": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
    "208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
    "216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "3072": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}
}
New file (README for the configs directory, +10 lines):
This directory contains tuned configurations for different settings of the
fused_moe kernel. For each combination of
- E (number of experts),
- N (intermediate size), and
- device_name (torch.cuda.get_device_name()),
the JSON file contains a mapping from M (batch size) to the chosen configuration.

The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
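As an illustration, a config file following this naming convention could be located and read like so. This is a hypothetical sketch: load_moe_config is not part of this PR, and the actual lookup logic in fused_moe.py may differ.

    import json
    import os
    import torch

    def load_moe_config(E: int, N: int) -> dict:
        # Build the filename from the convention described above:
        # E=<experts>,N=<intermediate size>,device_name=<GPU name>.json
        # (assuming spaces in the device name map to underscores).
        device_name = torch.cuda.get_device_name().replace(" ", "_")
        filename = f"E={E},N={N},device_name={device_name}.json"
        path = os.path.join(os.path.dirname(__file__), filename)
        with open(path) as f:
            # Keys are batch sizes M as strings; values are kernel configs.
            return {int(k): v for k, v in json.load(f).items()}

    # Example: Mixtral with TP2 on an H100 -> E=8, N=7168. One plausible
    # policy is to pick the config tuned for the closest batch size:
    # configs = load_moe_config(E=8, N=7168)
    # config = configs[min(configs, key=lambda m: abs(m - batch_size))]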
Review comments:
Could you provide a script to tune these parameters? No worries otherwise. I can implement it.
Yes, happy to add the script. The process of tuning is not fully automatic at the moment and requires some manual modifications, but I will contribute what I have 😊
I added the script benchmark_mixtral_moe.py now to do the search. In practice I wasn't using exactly this script but made some modifications to it as I searched through the different batch sizes; still, the script should be a good way to get started :)

Maybe by doing a more exhaustive search we could improve on these parameters even further, but I think the gap to anything better would be pretty small :)
Thanks!