Skip to content

[Kernel] Integrate DeepGEMM dense block fp8 #13996

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

mgoin
Copy link
Member

@mgoin mgoin commented Feb 27, 2025

WIP since the performance is not better on my systems (8xH100 CUDA 12.5 and 8xH200 CUDA 12.4) compared to our CUTLASS kernel, and I'm unsure how we can distribute DeepGEMM. See #13917 for microbenchmarks

Setup

I installed DeepGEMM in a parallel directory to vLLM, like so

git clone --recursive https://github.com/deepseek-ai/DeepGEMM
cd DeepGEMM
python setup.py install
uv pip install -e .

Usage

E2E evaluations/benchmarks on GSM8k with 8xH200 and CUDA 12.4:

Default (using our existing CUTLASS Block FP8):

lm_eval --model vllm --model_args pretrained=/home/vllm-dev/DeepSeek-R1,tensor_parallel_size=8 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
...
Processed prompts: 100%|████████████████████████████████████| 1319/1319 [02:53<00:00,  7.59it/s, est. speed input: 6621.27 toks/s, output: 775.88 toks/s]
Running generate_until requests: 100%|██████████████████████| 1319/1319 [02:54<00:00,  7.57it/s]
vllm (pretrained=/home/vllm-dev/DeepSeek-R1,tensor_parallel_size=8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9492|±  | 0.006|
|     |       |strict-match    |     5|exact_match|↑  |0.9492|±  | 0.006|

DeepGEMM:

VLLM_USE_DEEPGEMM=1 lm_eval --model vllm --model_args pretrained=/home/vllm-dev/DeepSeek-R1,tensor_parallel_size=8 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
...
Processed prompts: 100%|████████████████████████████████████| 1319/1319 [02:52<00:00,  7.63it/s, est. speed input: 6654.56 toks/s, output: 778.43 toks/s]
Running generate_until requests: 100%|██████████████████████| 1319/1319 [02:53<00:00,  7.60it/s]
vllm (pretrained=/home/vllm-dev/DeepSeek-R1,tensor_parallel_size=8,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.95|±  | 0.006|
|     |       |strict-match    |     5|exact_match|↑  | 0.95|±  | 0.006|

Signed-off-by: mgoin <mgoin64@gmail.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Signed-off-by: mgoin <mgoin64@gmail.com>
@houseroad
Copy link
Collaborator

Curious, do we see perf wins here?

@mgoin
Copy link
Member Author

mgoin commented Mar 3, 2025

@houseroad it seems worse at small M but better at large M compared to our CUTLASS kernels, however this is only true for specific shapes. I need to do more careful benchmarking but in the above lm-eval gsm8k throughput test it seems to be ~same

Copy link

mergify bot commented Mar 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mgoin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 11, 2025
@li2haipeng
Copy link

Hi, do we plan to merge this PR? We did a benchmark on 8xH200 and we saw ~5% otps gain on Deepseek-r1. Input len=1600, outputlen=600 and bs=1/4/8, tp=8.

@houseroad
Copy link
Collaborator

fi we do see perf wins, we can pick this PR up. cc: @chenyang78

@chenyang78
Copy link
Contributor

fi we do see perf wins, we can pick this PR up. cc: @chenyang78

Thanks for the heads-up. I will look into this.

@mgoin mgoin closed this Jul 3, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants