
[Kernel] change benchmark script so that result can be directly used; tune moe kernel in A100/H100 with tp=2,4,8 #3389

Merged: 11 commits merged into vllm-project:main from the moe-kernel-tuning branch on Mar 14, 2024

Conversation

youkaichao (Member) opened this pull request.

@youkaichao changed the title from "[Kernel] change benchmark script so that result can be directly used" to "[Kernel][WIP] change benchmark script so that result can be directly used" on Mar 13, 2024
@youkaichao changed the title from "[Kernel][WIP] change benchmark script so that result can be directly used" to "[Kernel] change benchmark script so that result can be directly used; tune moe kernel in A100/H100 with tp=2,4,8" on Mar 14, 2024
@youkaichao marked this pull request as ready for review on March 14, 2024 at 00:38
pcmoritz (Collaborator) commented on Mar 14, 2024:

Thanks for the refactorings! While you are touching this code, one thing that would be wonderful to do is to keep track of the timings for the best configuration for each batch size. This could e.g. be done by writing them to a separate file. This would allow you to decide if a new configuration is better than the old one.

Also note that running the script as-is will likely not produce optimal results in some settings, since there is a bunch of parameter pruning going on at the moment (e.g. for the batch size). Sometimes it is important to look at the values found and then expand the search space if it runs into the boundaries :)
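
As a rough illustration of the suggestion above, recording the winning config and its measured latency per batch size in a separate file could look something like this (the file name and dict layout are assumptions for the sketch, not something from this PR):

```python
import json

def save_best_timings(best_configs, best_times_us, path="moe_tuning_timings.json"):
    """Persist the winning config and its measured time for each batch size,
    so a later tuning run can be compared against the previous best."""
    record = {
        str(batch_size): {
            "config": best_configs[batch_size],
            "time_us": best_times_us[batch_size],
        }
        for batch_size in best_configs
    }
    with open(path, "w") as f:
        json.dump(record, f, indent=4)

def is_improvement(batch_size, new_time_us, path="moe_tuning_timings.json"):
    """Return True if the new timing beats the previously recorded best."""
    try:
        with open(path) as f:
            old = json.load(f)
    except FileNotFoundError:
        return True  # no previous record, so accept the new config
    prev = old.get(str(batch_size))
    return prev is None or new_time_us < prev["time_us"]
```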

youkaichao (Member, Author):

@pcmoritz This manual kernel tuning is kind of temporary. Going forward, we plan to use triton.autotune to automatically tune these configs. So we don't need to invest too much time here.
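
For context, a minimal sketch of what the triton.autotune route could look like; the kernel name and the candidate config list here are illustrative, not the actual vLLM fused MoE kernel:

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32,
                       "GROUP_SIZE_M": 8}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32,
                       "GROUP_SIZE_M": 16}, num_warps=4, num_stages=4),
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32,
                       "GROUP_SIZE_M": 1}, num_warps=8, num_stages=4),
    ],
    key=["M", "N", "K"],  # re-benchmark the candidates whenever the problem shape changes
)
@triton.jit
def moe_matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                      BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                      BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    # Kernel body elided; the point is that autotune times each candidate config
    # on the first launch for a given (M, N, K) and reuses the winner afterwards.
    pid = tl.program_id(axis=0)
```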

"2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
"3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
"4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
"1": {
WoosukKwon (Collaborator) commented on Mar 14, 2024:

The configuration here is quite different from the one we have right now. Could you compare the old and new ones by benchmarking the end-to-end performance (e.g., using benchmark_throughput.py on Mixtral)?

Collaborator:

Also, it'd be nice if you could benchmark the other configs as well, if not all of them.

youkaichao (Member, Author):

IIRC, LLaMA models don't use MoE. Do you mean Mixtral models?

Collaborator:

Yep, I mean Mixtral, not LLaMA.

youkaichao (Member, Author):

Do you have a common setting I should use for the throughput test? Otherwise I'm blindly running python benchmarks/benchmark_throughput.py --input-len 100 --output-len 100, and I'm not sure that 100 input tokens and 100 output tokens is the case people care about most.

youkaichao (Member, Author):

I will use python benchmarks/benchmark_throughput.py --model=mistralai/Mixtral-8x7B-Instruct-v0.1 --input-len 1000 --output-len 50 from #2293 (comment).

pcmoritz (Collaborator):

> @pcmoritz This manual kernel tuning is kind of temporary. Going forward, we plan to use triton.autotune to automatically tune these configs. So we don't need to invest too much time here.

In my experience triton.autotune is far too slow to be useful (unless the configs have already been run / are cached) :)

youkaichao (Member, Author):

Command: python benchmarks/benchmark_throughput.py --model=mistralai/Mixtral-8x7B-Instruct-v0.1 --input-len 1000 --output-len 50 -tp 2 (with -tp 4 and -tp 8 for the corresponding rows below)

H100 GPU:

| benchmark | no config | w/o this PR | w/ this PR |
| --- | --- | --- | --- |
| tp=2 | 11.73 requests/s, 12313.00 tokens/s | 14.11 requests/s, 14816.20 tokens/s | 15.69 requests/s, 16477.71 tokens/s |
| tp=4 | 18.07 requests/s, 18974.30 tokens/s | same as no config | 22.06 requests/s, 23166.51 tokens/s |
| tp=8 | 24.80 requests/s, 26039.89 tokens/s | same as no config | 28.17 requests/s, 29577.75 tokens/s |

A100 GPU: TODO (I don't have 8×A100 GPUs at hand right now)

@WoosukKwon the benchmarking results are quite promising!

youkaichao (Member, Author):

> In my experience triton.autotune is far too slow to be useful (unless the configs have already been run / are cached) :)

Will definitely try to cache tuned configs!
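
For illustration, one way the caching could work is to persist the tuned config per (device, batch size) and only fall back to the slow tuning path on a miss; the file name and function names below are assumptions, not this PR's implementation:

```python
import json
import os

import torch

_CACHE_PATH = "tuned_moe_configs.json"  # hypothetical on-disk cache

def get_or_tune_config(num_tokens, tune_fn):
    """Return a cached config for this (GPU, num_tokens) pair if present,
    otherwise run the expensive tuning function once and persist the result."""
    device = torch.cuda.get_device_name().replace(" ", "_")
    key = f"{device}:{num_tokens}"

    cache = {}
    if os.path.exists(_CACHE_PATH):
        with open(_CACHE_PATH) as f:
            cache = json.load(f)

    if key not in cache:
        cache[key] = tune_fn(num_tokens)  # slow path: benchmark candidate configs
        with open(_CACHE_PATH, "w") as f:
            json.dump(cache, f, indent=4)

    return cache[key]
```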

WoosukKwon (Collaborator):

@youkaichao Awesome! Could you 1) update the PR with the current main and 2) fix the lint error by running ./format.sh? You will need to run pip install -r requirements-dev.txt first.

youkaichao (Member, Author):

@WoosukKwon lint is good now 👌

@WoosukKwon enabled auto-merge (squash) on March 14, 2024 at 07:51
WoosukKwon (Collaborator) left a review:

LGTM! Thanks for the PR! Excited about the performance improvement!

@WoosukKwon enabled auto-merge (squash) on March 14, 2024 at 07:52
@WoosukKwon merged commit 8fe8386 into vllm-project:main on Mar 14, 2024
24 checks passed
starmpcc pushed a commit to starmpcc/vllm that referenced this pull request Mar 14, 2024
@youkaichao deleted the moe-kernel-tuning branch on March 14, 2024 at 16:19
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024