[FEAT] [ROCm] Enabling AITER Kernel #14007

Closed · tjtanaa wants to merge 39 commits

Conversation

tjtanaa (Contributor) commented on Feb 28, 2025

Kernels Integrated from AITER (AI Tensor Engine for ROCm)

Linear Layer

The tgemm kernel from AITER has been integrated for the unquantized linear method in /vllm/model_executor/layers/linear.py and for the per-tensor-weight and per-tensor-activation quantized FP8 scaled GEMM in /vllm/model_executor/layers/quantization/utils/fp8_utils.py. This feature is enabled by default when the environment variable VLLM_ROCM_USE_AITER=1 is set. It can be specifically enabled or disabled using its dedicated environment variable VLLM_ROCM_USE_AITER_LINEAR.
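
The dispatch pattern discussed later in this thread (dispatch_unquantized_linear_func called inside UnquantizedLinearMethod.apply) can be pictured with the minimal sketch below. The env-flag helpers and the AITER wrapper are simplified assumptions; in particular the aiter import path and the tgemm call are illustrative, not the exact PR code.

```python
import os
from typing import Callable, Optional

import torch
import torch.nn.functional as F


def rocm_aiter_linear_enabled() -> bool:
    # Parent switch plus the layer-specific override, as described above.
    return (os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
            and os.environ.get("VLLM_ROCM_USE_AITER_LINEAR", "1") == "1")


def rocm_aiter_tgemm_linear(x: torch.Tensor, weight: torch.Tensor,
                            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Hypothetical wrapper: the real tgemm kernel lives in the AITER package
    # and its import path / signature may differ.
    from aiter import tgemm  # assumed import path
    out = tgemm.mm(x, weight.t())
    return out if bias is None else out + bias


def dispatch_unquantized_linear_func() -> Callable[..., torch.Tensor]:
    # Pick the AITER tgemm path only when the env flags enable it;
    # otherwise fall back to the default torch linear.
    if rocm_aiter_linear_enabled():
        return rocm_aiter_tgemm_linear
    return F.linear
```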

RMS Norm Layer

The rmsnorm2d_fwd_with_add kernel from AITER has been integrated for the ROCm RMS norm forward pass in /vllm/model_executor/layers/layernorm.py. This feature is enabled by default when VLLM_ROCM_USE_AITER=1 is set. It can be specifically enabled or disabled using its dedicated environment variable VLLM_ROCM_USE_AITER_RMSNORM.
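
For reference, the fused residual-add + RMS norm that rmsnorm2d_fwd_with_add accelerates can be written in plain PyTorch as below. This is a minimal reference sketch matching the keyword-only helper quoted later in the review thread, not the AITER kernel itself.

```python
from typing import Tuple

import torch


def fused_add_rms_norm_ref(
        *, x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
    # The residual add happens first; its result is both returned
    # (as the new residual) and normalized.
    x = x + residual
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + variance_epsilon) * weight
    return out, x
```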

W8A8 Block GEMM

The gemm_a8w8_blockscale kernel from AITER has been integrated in /vllm/model_executor/layers/quantization/utils/fp8_utils.py. Unlike the above features, this is disabled by default even when the parent switch (VLLM_ROCM_USE_AITER=1) is enabled. To use this kernel, both the parent switch and its dedicated environment variable VLLM_ROCM_USE_AITER_W8A8_BLOCK_GEMM must be enabled. This kernel is suitable for DeepSeek models.
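
A minimal sketch of enabling this opt-in path from Python, assuming the flags are read from the process environment when the engine starts (set them before importing vLLM); the model name is only an example.

```python
import os

# Parent switch for all AITER integrations.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
# Opt-in flag for the block-scaled W8A8 GEMM; it stays off by default
# even when the parent switch is on.
os.environ["VLLM_ROCM_USE_AITER_W8A8_BLOCK_GEMM"] = "1"

from vllm import LLM  # import after the flags are set

llm = LLM(model="deepseek-ai/DeepSeek-V3", trust_remote_code=True)
```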

Fused MoE Kernels

Several fused MoE kernels have been integrated for different scenarios:

  1. The ck_moe kernel from AITER is integrated for unquantized model weights. It is enabled by default when VLLM_ROCM_USE_AITER=1 is set and can be specifically enabled or disabled using the dedicated environment variable VLLM_ROCM_USE_AITER_MOE. This is suitable for MoE models such as Mixtral.

  2. The asm_moe kernel from AITER is integrated for dynamically per-tensor-quantized model weights. It is enabled by default when VLLM_ROCM_USE_AITER=1 is set and can be specifically enabled or disabled using the dedicated environment variable VLLM_ROCM_USE_AITER_MOE. This is suitable for FP8-quantized MoE models such as Mixtral.

  3. The fmoe_fp8_block_scaled kernel from AITER is integrated for the block-scaled FP8 quantization method. Unlike the kernels above, this one is disabled by default even when the parent switch (VLLM_ROCM_USE_AITER=1) is enabled. To use it, both the parent switch and its dedicated environment variable VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE must be enabled. This kernel is suitable for DeepSeek models.

These MoE kernels are integrated in /vllm/model_executor/layers/fused_moe/fused_moe.py. The necessary processing steps required for these kernels are included in their respective MoE Methods for both Unquantized (UnquantizedMoEMethod) in /vllm/model_executor/layers/fused_moe/layer.py and FP8 quantized (FP8MoEMethod) in /vllm/model_executor/layers/quantization/fp8.py.
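
The selection among these kernels can be summarized with the sketch below. The helper and the returned labels are hypothetical and only illustrate how the flags described above combine; they are not vLLM's actual dispatch code in fused_moe.py.

```python
import os


def _flag(name: str, default: str = "0") -> bool:
    return os.environ.get(name, default) == "1"


def select_rocm_aiter_moe_kernel(weights_quant: str) -> str:
    """Return which AITER MoE kernel would handle the given weights.

    weights_quant is an illustrative label ("none", "fp8_per_tensor",
    or "fp8_block"), not one of vLLM's actual config values.
    """
    if not _flag("VLLM_ROCM_USE_AITER"):
        return "default_triton_fused_moe"

    if weights_quant == "fp8_block":
        # Opt-in: needs its dedicated flag in addition to the parent switch.
        if _flag("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE"):
            return "aiter.fmoe_fp8_block_scaled"
        return "default_triton_fused_moe"

    # The unquantized and FP8 per-tensor paths share one flag and are
    # on by default under the parent switch.
    if _flag("VLLM_ROCM_USE_AITER_MOE", default="1"):
        return "aiter.ck_moe" if weights_quant == "none" else "aiter.asm_moe"
    return "default_triton_fused_moe"
```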

Paged Attention

The pa_fwd_asm kernel from AITER is integrated as a new paged attention op in /vllm/attention/ops/rocm_aiter_paged_attn.py and implemented into the ROCM attention backend in /vllm/attention/backends/rocm_flash_attn.py.

This feature is disabled by default, even when the parent switch (VLLM_ROCM_USE_AITER=1) is enabled. To use this kernel, both the parent switch and its dedicated environment variable VLLM_ROCM_USE_AITER_PAGED_ATTN must be enabled.

Note:

  • The AITER paged attention module supports the following kv_cache_dtypes:
    • int8
    • fp8
    • fp8_e4m3
    • bfloat16
    • float16
  • However, for the float16 and bfloat16 kv_cache_dtypes, the module currently does not support decoding models with more than one KV head.
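
A small sketch of the kind of guard these constraints imply; it is not the PR's actual check, and the function and argument names are hypothetical.

```python
# kv_cache_dtype values the AITER paged attention path accepts,
# per the notes above.
AITER_PA_SUPPORTED_KV_CACHE_DTYPES = {
    "int8", "fp8", "fp8_e4m3", "bfloat16", "float16"
}


def aiter_paged_attn_usable(kv_cache_dtype: str, num_kv_heads: int) -> bool:
    """Hypothetical helper mirroring the constraints listed above."""
    if kv_cache_dtype not in AITER_PA_SUPPORTED_KV_CACHE_DTYPES:
        return False
    # For 16-bit KV caches, decoding is currently limited to one KV head.
    if kv_cache_dtype in ("bfloat16", "float16") and num_kv_heads > 1:
        return False
    return True
```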

Performance Improvement Tables

Experiment setup

The following experiment results were obtained with throughput_benchmark.py.
Five cases were tested, and the performance difference is summarized as a range of percentage gain.

5 Cases:

  • 128/128
  • 128/2048
  • 2048/128
  • 2048/2048
  • ShareGPT dataset

DeepSeekV3 Throughput

| Summary | Performance Improvement Over No AITER |
| --- | --- |
| fmoe_fp8_block_scaled (VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE=1) | 8~26.7% |

Note: The block-scaled GEMM from AITER (controlled by VLLM_ROCM_USE_AITER_W8A8_BLOCK_GEMM=1) is still under development (tuning). Internal tuning work is boosting the performance of the block-scaled GEMM, so we expect it to become faster in upcoming AITER updates.

DeepSeekV3 Latency

| Summary | SpeedUp in TPOT | SpeedUp in TTFT |
| --- | --- | --- |
| fmoe_fp8_block_scaled (VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE=1) | -2% | 41% |

Mixtral-8x7B (with FP8 per-tensor dynamic quantization)

Note: The fused MoE performs best on the ShareGPT dataset (75%) and worst at long input lengths (-14%).

| Summary | Performance Improvement Over No AITER |
| --- | --- |
| All except BLOCK_GEMM and PAGED_ATTN | -15~73% |
| Fused MoE only | -14~75% |
| Paged attention only | -3.4~1.9% |
| All except BLOCK_GEMM | -15~79% |

Mixtral-8x7B-FP16

Note: The fused MoE performs best on the ShareGPT dataset (2%) and worst at long input lengths (-11%).

| Summary | Performance Improvement Over No AITER |
| --- | --- |
| Fused MoE only | -11~2% |

Mixtral-8x22B (with FP8 per-tensor dynamic quantization)

Note: The fused MoE performs best on the ShareGPT dataset (41%) and least well at long input lengths (4%).

| Summary | Performance Improvement Over No AITER |
| --- | --- |
| All except BLOCK_GEMM and PAGED_ATTN | 4.8~39% |
| Fused MoE only | 4~40% |
| Paged attention only | -0.2~2% |
| All except BLOCK_GEMM | 4.8~41% |

Mixtral-8x22B-FP16

Note: The fused MoE performs best on the ShareGPT dataset (19%) and least well at long input lengths (2%).

| Summary | Performance Improvement Over No AITER |
| --- | --- |
| Fused MoE only | 2~19% |
| All except block-scaled MoE, block GEMM, and ASM paged attention | 3~19% |

Llama-3.1-8B-Instruct (with FP8 per-tensor dynamic quantization)

| Summary | Performance Improvement Over No AITER |
| --- | --- |
| Linear + RMS norm | -0.7~1.5% |
| Linear only | -2~2% |
| RMS norm only | -1.1~0.8% |
| Paged attention only | 1.5~7.9% |
| Linear + RMS norm + paged attention | 0.2~4.8% |

Llama-3.1-8B-Instruct-BF16

| Summary | Performance Improvement Over No AITER |
| --- | --- |
| Linear + RMS norm | -5~5% |
| Linear only | -5~4% |
| RMS norm only | 0.5~3.9% |

Llama-3.1-70B-Instruct (with FP8 per-tensor dynamic quantization)

| Summary | Performance Improvement Over No AITER |
| --- | --- |
| Linear + RMS norm | 1~5% |
| Linear only | -1~0.5% |
| RMS norm only | -0.02~2% |
| Paged attention only | -2~3% |
| Linear + RMS norm + paged attention | 0.6~6% |

Llama-3.1-70B-Instruct-BF16

| Summary | Performance Improvement Over No AITER |
| --- | --- |
| Linear + RMS norm | 1.6~4% |
| Linear only | -0.3~0.2% |
| RMS norm only | -0.12~1.2% |

AITER Operations Testing Overview

1. High-Level Integration Tests

The integration of AITER ops is tested at a higher module level in the following files under /tests/models/decoder_only/language:

  • test_models.py
  • test_phimoe.py
  • test_mistral.py
  • test_granite.py

These tests involve running various models to ensure overall functionality.

2. AITER MoE Specific Test

  • The AITER Mixture of Experts (MoE) is specifically tested for the Mixtral model in:
    /tests/kernels/test_moe.py

3. Quantization Testing

  • Quantization methods for AITER-enabled modules are tested in:
    /tests/quantization/test_fp8.py

4. Kernel Function Dispatch Testing

  • The correct dispatching of kernel functions (AITER-enabled or not) is verified in:
    /tests/model_executor/test_enabled_custom_ops.py

Environment Settings

Several packages have been upgraded in Dockerfile.rocm_base, based on https://github.com/ROCm/vllm/blob/aiter_upstream/Dockerfile.rocm_base:

Updated Packages:

Added AITER Package:

  • AITER_BRANCH: e1ec015

Note:

  • When setting up AITER, it is crucial to clone with git clone --recursive, because the package depends on a third-party package (Composable Kernel).
  • To build and install the AITER Python package, use the PREBUILD_KERNELS=1 flag together with python3 setup.py develop. This ensures that all kernels in the AITER package are built successfully.

The following branches were used as references for this integration:


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Feb 28, 2025
hongxiayang (Collaborator) commented:

Maybe consolidate with #13975

@hongxiayang hongxiayang added the rocm Related to AMD ROCm label Mar 1, 2025
vllmellm and others added 8 commits March 3, 2025 04:07

mergify bot commented Mar 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 5, 2025
vllmellm added 8 commits March 5, 2025 10:37
hongxiayang (Collaborator) commented:

Nice work @tjtanaa !
cc @houseroad

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.fused_moe.fused_moe import (
dispatch_fused_experts_func, dispatch_topk_func, rocm_aiter_fused_experts,
rocm_aiter_topk_softmax, torch_vllm_inplace_fused_experts,
Collaborator:

Will this cause an import error on non-ROCm platforms?

vllmellm (Contributor) replied on Mar 7, 2025:

No, the rocm_aiter_* functions are wrappers that are only called when the ROCm platform is detected and the AITER-specific env vars are set to True.

Collaborator:

@vllmellm @tjtanaa Please rebase to resolve the conflict, and then we will add the "ready" label to finalize the review. Thank you!

vllmellm added 3 commits March 7, 2025 15:53
@mergify mergify bot removed the needs-rebase label Mar 8, 2025
vllmellm added 4 commits March 8, 2025 05:08
ProExpertProg (Collaborator) left a comment:

A few initial comments

and kv_cache.dtype.itemsize == 1
and not self.aiter_kv_scales_initialized
and kv_cache.shape != torch.Size([0])):
num_blocks = kv_cache.shape[1]
Collaborator:

A brief comment about why AITER needs special handling here would be nice

Contributor Author:

@ProExpertProg We have added a brief comment:

+ # Reshaping kv tensors is required for the AITER paged attention kernel
+ # because it works on a different tensor shape
+ # when the size of one element is one byte (int8/fp8 dtypes).
+ # This reshaping is only required on the first forward call,
+ # and the kv cache must not be empty.

v_scale: float,
) -> torch.Tensor:
output = torch.empty_like(query)
context_attention_fwd(
Collaborator:

What happens if this is not imported when HAS_TRITON is false? Or should we just always import it?

Contributor:

@ProExpertProg We can always import it, since HAS_TRITON will always be true: the ROCm platform supports Triton and the triton package is a required dependency. Note that this module is a replica of PagedAttention in vllm/attention/ops/paged_attn.py and follows the same logic, except that two functions differ: forward_decode and write_to_paged_cache. The forward_prefix function follows the same logic.

@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# isort: skip_file
Collaborator:

Why does this have to be skipped?

Contributor:

@ProExpertProg This was due to the import path of a module from the AITER package. It is fixed now and isort is no longer skipped.
The import was originally
from aiter.ops.shuffle import (shuffle_weight as rocm_aiter_shuffle_weight)
so that the renamed function made it clear that it comes from AITER.
It has now been changed to
from aiter.ops.shuffle import shuffle_weight
since the import is within the function scope, just above where it is invoked, so it is already clear to developers that it comes from AITER.
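
A minimal sketch of the function-scoped import pattern described above; the enclosing function name and the shuffle_weight call signature are assumptions for illustration.

```python
import torch


def rocm_aiter_shuffle_moe_weight(weight: torch.Tensor) -> torch.Tensor:
    # Importing inside the function keeps the AITER dependency optional
    # (only needed when the ROCm AITER path is enabled) and makes the
    # origin of shuffle_weight obvious at the call site.
    from aiter.ops.shuffle import shuffle_weight

    return shuffle_weight(weight)  # assumed call signature
```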

) -> Callable[..., Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
if not add_residual:
return rms_norm
if current_platform.is_rocm_aiter_rmsnorm_enabled():
Collaborator:

Just curious: is there a fused rmsnorm+quant (fp8/int8) AITER kernel?

Contributor:

Yes, there are such kernels: rmsnorm2d_fwd_with_add_dynamicquant and rmsnorm2d_fwd_with_add_smoothquant, covering two different quantization methods, "dynamic" and "smooth". You can find how to use those kernels here.


def fused_add_rms_norm(
*, x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
Collaborator:

I think this should use Tuple instead of tuple for the typing hint

Contributor Author:

@ProExpertProg Thank you, we have updated it to Tuple.

Comment on lines 36 to 49
def shape_supported_by_cutlass(weight: torch.Tensor, block_size: List[int],
weight_scale: torch.Tensor,
input_2d: torch.Tensor) -> bool:
if current_platform.is_rocm():
scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
input_2d.shape[:-1])[::-1]
scale_b_shape = (weight_scale.view(-1, 1)
if weight_scale.dim() <= 1 else weight_scale.T).shape
ar, ac = scale_a_shape
br, bc = scale_b_shape
return ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0]) \
or br not in (1, weight.shape[0])

return weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
Collaborator:

I think this is currently not working correctly; please read the issue in the TODO that was removed here. I am working on a fix and that might have to land before this PR. Or, at least, restore the TODO here for now.

Contributor:

@ProExpertProg Do you have a pull request for this fix? If not, can you give us a timeline for when you will be able to open the PR?

Contributor Author:

@ProExpertProg We have restored the TODO here.

Collaborator:

Yes, the PR will be there tomorrow.

@@ -159,124 +304,39 @@ def apply(
# TODO(luka) remove this parameter in favor of __init__
use_per_token_if_dynamic: Optional[bool] = None
) -> torch.Tensor:
Collaborator:

I am not sure that this was properly merged after #14390; please at least restore the comments here.

vllmellm (Contributor) replied on Mar 11, 2025:

Thanks for pointing this out. Indeed, the code hadn't been merged properly, but this should be resolved after our last commit. We have also added the comments back in.

@@ -138,8 +151,7 @@ def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

return F.linear(x, layer.weight, bias)
return dispatch_unquantized_linear_func()(x, layer.weight, bias)
Collaborator:

So this only applies to Marlin, and not any other linear method?

Contributor:

For clarity, dispatch_unquantized_linear_func is called within UnquantizedLinearMethod.apply. If you'd like to see the changes in more detail, expanding the file might be helpful, as the UI here makes it seem like it is being called in adjust_marlin_shard.

Collaborator:

Oh yeah that's my bad, thanks!

@@ -555,6 +556,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
requires_grad=False)
if current_platform.is_rocm_aiter_fp8_block_scaled_moe_enabled():
Collaborator:

This feels like it's adding too much additional complexity here. I think better abstractions (a subclass maybe?) would be good, but I'll let @robertgshaw2-redhat and @dsikka chime in here.

Contributor:

@ProExpertProg Yes, totally agree with this refactoring. Fp8MoEMethod could be a base class with abstract create_weights and process_weights_after_loading methods that are implemented in subclasses based on the quantization method: for instance, block quantization, in-place (float16 and bfloat16) weights, and weights that are already FP8. Basically, this breaks the if/else statements in process_weights_after_loading into different subclasses, as sketched below.
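
A rough sketch of the suggested split, with hypothetical class names and no real weight processing; it only illustrates the structure being proposed, not code from this PR.

```python
from abc import ABC, abstractmethod

import torch


class Fp8MoEMethodBase(ABC):
    """Shared FP8 MoE logic; weight handling differs per quantization scheme."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, **kwargs) -> None:
        ...

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        ...


class BlockScaledFp8MoEMethod(Fp8MoEMethodBase):
    """Block-quantized checkpoints (e.g. DeepSeek-style block scales)."""

    def create_weights(self, layer: torch.nn.Module, **kwargs) -> None:
        ...  # allocate block-scaled weights and their per-block scales

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        ...  # block-scale-specific post-processing (e.g. AITER weight shuffle)


class PerTensorFp8MoEMethod(Fp8MoEMethodBase):
    """Per-tensor FP8 scales, including dynamic quantization of fp16/bf16."""

    def create_weights(self, layer: torch.nn.Module, **kwargs) -> None:
        ...

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        ...  # per-tensor scale handling and optional AITER preparation
```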


mergify bot commented Mar 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 12, 2025
@mergify mergify bot removed the needs-rebase label Mar 12, 2025

mergify bot commented Mar 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 12, 2025

mergify bot commented Mar 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 26, 2025
hongxiayang (Collaborator) commented:

@tjtanaa Can we close this PR since it has been split into several smaller PRs?

@tjtanaa tjtanaa closed this Mar 27, 2025
@tjtanaa tjtanaa deleted the aiter-integration branch May 16, 2025 16:28