
[FEAT] [ROCm]: Add AITER Block-Scaled GEMM Feature #14968


Merged: 25 commits merged into vllm-project:main on May 14, 2025

Conversation

@vllmellm (Contributor) commented on Mar 17, 2025

Description

This PR integrates the Block-Scaled GEMM functionality from AITER into vLLM, allowing upcoming optimizations in the AITER kernels to be used and evaluated directly within the vLLM framework.

Implementation

The gemm_a8w8_blockscale kernel from AITER has been added to vllm/model_executor/layers/quantization/utils/fp8_utils.py. This kernel is:

  • Enabled only if VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_LINEAR are both set to 1 (a sketch of this gating follows the list).
  • Suitable for DeepSeek models.
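
For concreteness, here is a minimal sketch of how this environment-variable gating could select between the AITER kernel and the existing Triton path. The helper names, flag defaults, and the AITER call signature are assumptions for illustration, not the PR's literal code:

```python
# Illustrative sketch only: helper names, defaults, and the AITER call
# signature are assumptions, not copied from the PR.
import os

import torch

# Existing Triton block-FP8 fallback in vLLM (assumed to be in scope,
# since this logic lives in the same fp8_utils.py module).
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    w8a8_block_fp8_matmul)


def use_aiter_blockscale() -> bool:
    # Both flags must be set to "1" for the AITER kernel to be used.
    return (os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
            and os.environ.get("VLLM_ROCM_USE_AITER_LINEAR", "0") == "1")


def blockscale_gemm(input: torch.Tensor, weight: torch.Tensor,
                    input_scale: torch.Tensor, weight_scale: torch.Tensor,
                    block_size: list[int]) -> torch.Tensor:
    if use_aiter_blockscale():
        # AITER's block-scaled FP8 GEMM; the output keeps the input
        # dtype. (Argument order and keyword name are assumptions.)
        from aiter import gemm_a8w8_blockscale
        return gemm_a8w8_blockscale(input, weight, input_scale,
                                    weight_scale, output_dtype=input.dtype)
    # Otherwise fall back to the existing Triton implementation.
    return w8a8_block_fp8_matmul(input, weight, input_scale, weight_scale,
                                 block_size, output_dtype=input.dtype)
```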

Testing

The integration has been verified through:

  • High-level integration tests with various models
  • Kernel-function dispatch testing to ensure the correct operation is selected (a pytest-style sketch follows)
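
As an illustration of the dispatch testing, a pytest-style check of the gating helper sketched above might look like the following (names are assumptions; this is not the PR's actual test code):

```python
# Hedged sketch: exercises the assumed use_aiter_blockscale() helper
# from the Implementation sketch above.
import pytest


@pytest.mark.parametrize("aiter_flag,linear_flag,expected", [
    ("1", "1", True),   # both flags on -> AITER kernel selected
    ("1", "0", False),  # linear flag off -> Triton fallback
    ("0", "1", False),  # AITER disabled globally -> Triton fallback
])
def test_blockscale_dispatch(monkeypatch, aiter_flag, linear_flag, expected):
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", aiter_flag)
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", linear_flag)
    assert use_aiter_blockscale() is expected
```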

[Updated] Performance

V1 Engine
Summary of Improvements
When comparing the performance with and without AITER Blockscaled GEMM FP8, the following improvements were observed:

| Metric | Without AITER | With AITER | Improvement |
|---|---|---|---|
| Benchmark duration | 142.67 s | 119.71 s | 16.1% faster |
| Request throughput | 3.50 req/s | 4.18 req/s | 19.4% higher |
| Output token throughput | 889.13 tok/s | 1054.40 tok/s | 18.6% higher |
| Total token throughput | 4346.57 tok/s | 5175.02 tok/s | 19.1% higher |
| Mean inter-token latency | 87.52 ms | 73.41 ms | 16.1% lower |
| Median inter-token latency | 55.92 ms | 45.66 ms | 18.3% lower |

Key Observations

  1. Overall Speed: The benchmark completed 16.1% faster with AITER Blockscaled GEMM FP8.
  2. Token Generation: Output token throughput increased by 18.6%, showing significantly faster token generation.
  3. Latency Improvements: Both mean and median inter-token latency decreased, resulting in smoother token generation.
  4. P99 Latency: The P99 inter-token latency improved from 2634.65ms to 2445.57ms, showing better worst-case performance.

[Updated] LM Eval accuracy

V1 Engine

Without AITER Block Scaled GEMM FP8

vllm (pretrained=deepseek-ai/DeepSeek-V3,tensor_parallel_size=8,max_model_len=32768,block_size=1,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9477 | ± 0.0061 |
| gsm8k | 3 | strict-match | 5 | exact_match | 0.9484 | ± 0.0061 |

With AITER Block Scaled GEMM FP8

vllm (pretrained=deepseek-ai/DeepSeek-V3,tensor_parallel_size=8,max_model_len=32768,block_size=1,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9469 | ± 0.0062 |
| gsm8k | 3 | strict-match | 5 | exact_match | 0.9477 | ± 0.0061 |
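
For reference, the lm-eval configuration line above corresponds to an invocation along these lines (reconstructed from the reported settings; the flag spellings follow the lm-eval-harness CLI and are not copied from the PR):

```bash
lm_eval --model vllm \
  --model_args pretrained=deepseek-ai/DeepSeek-V3,tensor_parallel_size=8,max_model_len=32768,block_size=1,trust_remote_code=True \
  --tasks gsm8k \
  --num_fewshot 5 \
  --batch_size auto
```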
Old info

Throughput

ShareGPT dataset

  • With AITER: Throughput: 5.42 requests/s, 2274.81 total tokens/s, 1087.69 output tokens/s
  • Without AITER: Throughput: 5.23 requests/s, 2196.01 total tokens/s, 1050.01 output tokens/s

A gain of about 3.5% on the ShareGPT dataset.

Accuracy Test GSM8K

vllm (pretrained=deepseek-ai/DeepSeek-V3,tensor_parallel_size=8,max_model_len=30000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9484 | ± 0.0061 |
| gsm8k | 3 | strict-match | 5 | exact_match | 0.9492 | ± 0.0060 |

Environment Settings

Updates in Dockerfile.rocm_base

Added the AITER package.

Additional Notes

  • When setting up AITER, it is crucial to clone with git clone --recursive, because the package depends on a third-party package (Composable Kernel).
  • To build and install the AITER Python package, use the PREBUILD_KERNELS=1 flag together with python3 setup.py develop. This ensures that all kernels in the AITER package are built successfully (see the command sketch below).
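
Putting those notes together, the setup could look like the following sketch; the repository URL is an assumption, since the exact source is not stated above:

```bash
# Clone recursively so the Composable Kernel submodule is included.
# NOTE: the repository URL is an assumption.
git clone --recursive https://github.com/ROCm/aiter.git
cd aiter
# PREBUILD_KERNELS=1 ensures all kernels are built during install.
PREBUILD_KERNELS=1 python3 setup.py develop
```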

The following branches were used as references for this integration:
https://github.com/ROCm/vllm/tree/aiter_upstream
https://github.com/ROCm/vllm/tree/aiter_integration_final
https://github.com/ROCm/vllm/tree/deepseek_v3_dev

This PR is part of a larger effort to integrate AITER kernels into vLLM for improved performance on the ROCm platform.

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of tests to catch errors quickly. You can run the remaining CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@vllmellm vllmellm marked this pull request as ready for review March 18, 2025 05:33

mergify bot commented Mar 24, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 24, 2025
@mergify mergify bot removed the needs-rebase label Mar 24, 2025
@SageMoore (Contributor) left a comment:


This generally looks fine. What models are you all using this kernel with? If there are any models we would like to claim this kernel supports, please include lm_eval results in a comment on this PR.


mergify bot commented Mar 31, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 31, 2025
@mergify mergify bot removed the needs-rebase label Apr 16, 2025

mergify bot commented Apr 22, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 22, 2025
tjtanaa added 2 commits April 22, 2025 13:13
@mergify mergify bot removed the needs-rebase label Apr 22, 2025
Review comment on a diff hunk in fp8_utils.py:

```python
    weight_scale,
    block_size,
    output_dtype=input.dtype)
# TODO is_shape_supported_by_cutlass is never used,
```
@houseroad (Collaborator) commented:

Can we add a GitHub ID or create an issue for this, for tracking purposes?

Contributor replied:

@houseroad We have cross-checked with main. It seems the logic has since been implemented there and the comment removed, so we have removed the TODO comment as well.

tjtanaa and others added 5 commits April 23, 2025 01:56
@SageMoore (Contributor) left a comment:


In general this looks fine, but let's get this "is cutlass supported" logic ironed out before landing.

Review thread on another diff hunk in fp8_utils.py:

```python
        weight_scale: torch.Tensor,
        input_2d: torch.Tensor) -> bool:
    if current_platform.is_rocm():
        # TODO this is never used, as cutlass_block_fp8_supported is False
```
@SageMoore (Contributor) commented:
If that's the case, can we just return False? Or move the current_platform.is_rocm() check to apply_w8a8_block_fp8_linear?

@tjtanaa (Contributor) commented on May 13, 2025:
@SageMoore I think we could.

Collaborator commented:
Just make sure you have a look at #14397 as it describes how this is currently a bug.
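
For illustration, the simplification discussed in this thread could look like the following sketch (names follow the diff context above; the final code in the PR may differ):

```python
# Sketch of the suggested early return: on ROCm the CUTLASS block-FP8
# path is never taken, so the shape check can short-circuit.
import torch

from vllm.platforms import current_platform


def is_shape_supported_by_cutlass(weight: torch.Tensor,
                                  weight_scale: torch.Tensor,
                                  input_2d: torch.Tensor) -> bool:
    if current_platform.is_rocm():
        # cutlass_block_fp8_supported is always False on ROCm.
        return False
    ...  # CUDA shape checks unchanged
```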


mergify bot commented May 13, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 13, 2025
@mergify mergify bot removed the needs-rebase label May 13, 2025
@SageMoore (Contributor) left a comment:


Looks reasonable. Thanks for the contribution!

@DarkLight1337 added the ready label (ONLY add when PR is ready to merge/full CI is needed) May 13, 2025
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) May 13, 2025 14:32
@vllm-bot vllm-bot merged commit 40de1ef into vllm-project:main May 14, 2025
84 of 93 checks passed
@tjtanaa tjtanaa deleted the aiter-block-gemm-integration branch May 16, 2025 16:28
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Labels: ci/build, ready