[FEAT] [ROCm] Enabling AITER Kernel #14007
Conversation
Maybe consolidate with #13975
Nice work @tjtanaa!
```python
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.fused_moe.fused_moe import (
    dispatch_fused_experts_func, dispatch_topk_func, rocm_aiter_fused_experts,
    rocm_aiter_topk_softmax, torch_vllm_inplace_fused_experts,
```
Will this cause an import error for non-ROCm platforms?
No, the rocm_aiter_* functions are wrappers that are only called when the ROCm platform is detected and the AITER-specific environment variables are set to True.
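For illustration, here is a minimal sketch of that dispatch pattern. The helper bodies and the ROCm check are hypothetical stand-ins; only the rocm_aiter_topk_softmax / dispatch_topk_func names and the environment variables come from this PR.

```python
# Minimal sketch of the env-gated dispatch pattern (not the actual vLLM code).
import os

import torch


def torch_topk_softmax(scores: torch.Tensor, k: int):
    # Default PyTorch path, available on every platform.
    return torch.topk(torch.softmax(scores, dim=-1), k, dim=-1)


def rocm_aiter_topk_softmax(scores: torch.Tensor, k: int):
    # Wrapper around the AITER kernel. The aiter import is deferred into the
    # function body, so importing this module never pulls in aiter on
    # non-ROCm platforms. (The real AITER call is not shown here.)
    import aiter  # noqa: F401  -- only reachable on ROCm with AITER enabled
    raise NotImplementedError("call the AITER top-k softmax kernel here")


def dispatch_topk_func():
    # The AITER wrapper is only returned when running on ROCm *and* the
    # AITER switches are enabled via environment variables.
    on_rocm = torch.version.hip is not None  # crude ROCm check for the sketch
    use_aiter = (on_rocm
                 and os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
                 and os.environ.get("VLLM_ROCM_USE_AITER_MOE", "1") == "1")
    return rocm_aiter_topk_softmax if use_aiter else torch_topk_softmax
```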
A few initial comments
```python
        and kv_cache.dtype.itemsize == 1
        and not self.aiter_kv_scales_initialized
        and kv_cache.shape != torch.Size([0])):
    num_blocks = kv_cache.shape[1]
```
A brief comment about why AITER needs special handling here would be nice
@ProExpertProg We have added a brief comment:

```python
# Reshaping kv tensors is required for the AITER paged attention kernel
# because it works on a different tensor shape
# when the size of one element is one byte (int8/fp8 dtypes).
# This reshaping is only required on the first forward call,
# and the kv cache must not be empty.
```
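For context, a self-contained sketch of the kind of guard that comment sits above. The condition and the aiter_kv_scales_initialized flag follow the diff shown earlier; the class and the elided reshape are illustrative only.

```python
import torch


class AiterPagedAttentionSketch:
    """Illustrative skeleton only, not the actual vLLM backend class."""

    def __init__(self) -> None:
        # Tracks whether the AITER-specific kv-cache preparation has run.
        self.aiter_kv_scales_initialized = False

    def maybe_prepare_kv_cache(self, kv_cache: torch.Tensor) -> None:
        # Reshaping kv tensors is required for the AITER paged attention
        # kernel because it works on a different tensor shape when the size
        # of one element is one byte (int8/fp8 dtypes). This is done only on
        # the first forward call, and the kv cache must not be empty.
        if (kv_cache.dtype.itemsize == 1
                and not self.aiter_kv_scales_initialized
                and kv_cache.shape != torch.Size([0])):
            num_blocks = kv_cache.shape[1]
            # ... allocate per-block k/v scales and reshape the cache into
            # the layout the AITER kernel expects (elided in this sketch) ...
            _ = num_blocks
            self.aiter_kv_scales_initialized = True
```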
```python
    v_scale: float,
) -> torch.Tensor:
    output = torch.empty_like(query)
    context_attention_fwd(
```
What happens if this is not imported when HAS_TRITON is false? Or should we just always import it?
@ProExpertProg We can always import it, since HAS_TRITON will always be true: the ROCm platform supports Triton and the triton package is a required dependency. Note, however, that this module is a replica of PagedAttention in vllm/attention/ops/paged_attn.py, so it follows the same logic except that two functions differ, forward_decode and write_to_paged_cache; the forward_prefix function follows the same logic.
```
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# isort: skip_file
```
Why does this have to be skipped?
@ProExpertProg This was due to the import path of a module from the AITER package. It is fixed now and isort is no longer skipped.

Initially we used the import path below so that we could rename the function to a more specific name, making it clear that the function comes from AITER:

from aiter.ops.shuffle import (shuffle_weight as rocm_aiter_shuffle_weight)

It has now been changed to the following import path:

from aiter.ops.shuffle import shuffle_weight

Since the import is within the function scope, just above where it is invoked, it should be clear to developers that the function comes from AITER.
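For illustration, a minimal sketch of that function-scoped import pattern. The wrapper function itself is hypothetical; only the aiter.ops.shuffle import path comes from this PR, and the exact shuffle_weight argument list may differ.

```python
import torch


def rocm_aiter_prepare_moe_weight(weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper showing the function-scoped AITER import."""
    # Importing here (rather than at module level) keeps non-ROCm platforms
    # from ever importing the aiter package, and the call site directly
    # below makes it obvious that shuffle_weight comes from AITER.
    from aiter.ops.shuffle import shuffle_weight

    # The real shuffle_weight signature may take extra arguments; this
    # sketch just forwards the weight tensor.
    return shuffle_weight(weight)
```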
```python
) -> Callable[..., Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
    if not add_residual:
        return rms_norm
    if current_platform.is_rocm_aiter_rmsnorm_enabled():
```
Just curious: is there a fused rmsnorm+quant (fp8/int8) AITER kernel?
Yes, there are such kernels: rmsnorm2d_fwd_with_add_dynamicquant and rmsnorm2d_fwd_with_add_smoothquant. There are two different quantization methods, "dynamic" and "smooth". You can find how to use those kernels here.
```python
def fused_add_rms_norm(
        *, x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
```
I think this should use Tuple instead of tuple for the typing hint.
@ProExpertProg Thank you. We have updated it to Tuple.
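For reference, the signature from the diff above with the typing-module hint applied (body elided; nothing beyond the shown signature is from the PR):

```python
from typing import Tuple

import torch


def fused_add_rms_norm(
        *, x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
    ...  # body elided; see the diff above
```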
```python
def shape_supported_by_cutlass(weight: torch.Tensor, block_size: List[int],
                               weight_scale: torch.Tensor,
                               input_2d: torch.Tensor) -> bool:
    if current_platform.is_rocm():
        scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
                         input_2d.shape[:-1])[::-1]
        scale_b_shape = (weight_scale.view(-1, 1)
                         if weight_scale.dim() <= 1 else weight_scale.T).shape
        ar, ac = scale_a_shape
        br, bc = scale_b_shape
        return ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0]) \
            or br not in (1, weight.shape[0])

    return weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
```
I think this is currently not working correctly, please read the issue in the TODO that was removed here. I am working on a fix and that might have to land before this PR. Or, at least restore the TODO here for now.
@ProExpertProg do you have a pull request for this fix? If not, can you give us a timeline for when you will be able to open the PR?
@ProExpertProg We have restored the TODO here.
Yes, the PR will be there tomorrow.
```
@@ -159,124 +304,39 @@ def apply(
        # TODO(luka) remove this parameter in favor of __init__
        use_per_token_if_dynamic: Optional[bool] = None
) -> torch.Tensor:
```
I am not sure that this was properly merged after #14390; please at least restore the comments here.
Thanks for pointing this out. Indeed, the code hadn't been merged properly, but this should be resolved after our last commit. We have also added the comments back in.
```
@@ -138,8 +151,7 @@ def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

-        return F.linear(x, layer.weight, bias)
+        return dispatch_unquantized_linear_func()(x, layer.weight, bias)
```
So this only applies to Marlin, and not any other linear method?
For clarity, dispatch_unquantized_linear_func is called within UnquantizedLinearMethod.apply. If you'd like to see the changes in more detail, expanding the file might be helpful, as the UI here makes it seem like it's being called in adjust_marlin_shard.
Oh yeah that's my bad, thanks!
```
@@ -555,6 +556,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                  requires_grad=False)
            if current_platform.is_rocm_aiter_fp8_block_scaled_moe_enabled():
```
This feels like it's adding too much additional complexity here. I think better abstractions (a subclass maybe?) would be good, but I'll let @robertgshaw2-redhat and @dsikka chime in here.
@ProExpertProg Yes, totally agree with this refactoring. Fp8MoEMethod could be a base class with abstract create_weights and process_weights_after_loading functions, so that these functions would be implemented in subclasses based on the quantization method: for instance, block quantization, in-place (float16 and bfloat16) weights, and fp8 weights. Basically, this breaks down the if/else statements in process_weights_after_loading into different subclasses.
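A rough sketch of the split being proposed, for discussion only. All class names here are hypothetical; only Fp8MoEMethod, create_weights, and process_weights_after_loading come from the existing code.

```python
from abc import ABC, abstractmethod

import torch


class Fp8MoEMethodBase(ABC):
    """Hypothetical base class for the refactor sketched above."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, **kwargs) -> None:
        ...

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        ...


class BlockQuantFp8MoEMethod(Fp8MoEMethodBase):
    # Would own the block-quantized branch of the current if/else chain,
    # including any AITER-specific weight shuffling.
    def create_weights(self, layer: torch.nn.Module, **kwargs) -> None:
        ...

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        ...


class PerTensorFp8MoEMethod(Fp8MoEMethodBase):
    # Would own the per-tensor fp8 branch; the fp16/bf16 in-place case could
    # get its own subclass in the same way.
    def create_weights(self, layer: torch.nn.Module, **kwargs) -> None:
        ...

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        ...
```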
@tjtanaa can we close this PR since it has been split into several smaller PRs?
Kernels Integrated from AITER (AI Tensor Engine for ROCm)

Linear Layer

The tgemm kernel from AITER has been integrated for the unquantized linear method in /vllm/model_executor/layers/linear.py and for per-tensor-weight and per-tensor-activation quantization FP8 Scaled GEMM in /vllm/model_executor/layers/quantization/utils/fp8_utils.py. This feature is enabled by default when the environment variable VLLM_ROCM_USE_AITER=1 is set. It can be specifically enabled or disabled using its dedicated environment variable VLLM_ROCM_USE_AITER_LINEAR.
RMS Norm Layer

The rmsnorm2d_fwd_with_add kernel from AITER has been integrated for the ROCm RMS norm forward pass in /vllm/model_executor/layers/layernorm.py. This feature is enabled by default when VLLM_ROCM_USE_AITER=1 is set. It can be specifically enabled or disabled using its dedicated environment variable VLLM_ROCM_USE_AITER_RMSNORM.
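As a usage sketch of these switches (the environment variable names come from this PR; the model name is a placeholder, and the variables must be set before vLLM is imported):

```python
# Sketch: enabling the AITER linear and RMSNorm paths on a ROCm machine.
import os

os.environ["VLLM_ROCM_USE_AITER"] = "1"          # parent switch
os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "1"   # on by default under the parent switch
os.environ["VLLM_ROCM_USE_AITER_RMSNORM"] = "1"  # on by default under the parent switch

from vllm import LLM  # imported after the switches are set

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```

Setting a dedicated switch to 0 while keeping the parent switch on disables just that kernel.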
W8A8 Block GEMM

The gemm_a8w8_blockscale kernel from AITER has been integrated in /vllm/model_executor/layers/quantization/utils/fp8_utils.py. Unlike the above features, this is disabled by default even when the parent switch (VLLM_ROCM_USE_AITER=1) is enabled. To use this kernel, both the parent switch and its dedicated environment variable VLLM_ROCM_USE_AITER_W8A8_BLOCK_GEMM must be enabled. This kernel is suitable for DeepSeek models.

Fused MoE Kernels
Several fused MoE kernels have been integrated for different scenarios:
- The ck_moe kernel from AITER is integrated for unquantized model weights. It is enabled by default when VLLM_ROCM_USE_AITER=1 is set. It can be specifically enabled or disabled using the dedicated environment variable VLLM_ROCM_USE_AITER_MOE. This is suitable for MoE models such as Mixtral.
- The asm_moe kernel from AITER is integrated for dynamic per-tensor quantization model weights. It is enabled by default when VLLM_ROCM_USE_AITER=1 is set. It can be specifically enabled or disabled using the dedicated environment variable VLLM_ROCM_USE_AITER_MOE. This is suitable for MoE models such as Mixtral for fp8 quantization.
- The fmoe_fp8_block_scaled kernel from AITER is integrated for the block fp8 quantization method. Unlike the above features, this is disabled by default even when the parent switch (VLLM_ROCM_USE_AITER=1) is enabled. To use this kernel, both the parent switch and its dedicated environment variable VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE must be enabled. This kernel is suitable for DeepSeek models.

These MoE kernels are integrated in /vllm/model_executor/layers/fused_moe/fused_moe.py. The necessary processing steps required for these kernels are included in their respective MoE methods, for both unquantized (UnquantizedMoEMethod) in /vllm/model_executor/layers/fused_moe/layer.py and FP8 quantized (FP8MoEMethod) in /vllm/model_executor/layers/quantization/fp8.py.
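As a usage sketch for the MoE path (environment variable names from this PR; the model choice is a placeholder):

```python
# Sketch: opting into the AITER fused-MoE path for a Mixtral-style model.
import os

os.environ["VLLM_ROCM_USE_AITER"] = "1"      # parent switch
os.environ["VLLM_ROCM_USE_AITER_MOE"] = "1"  # ck_moe / asm_moe path

from vllm import LLM, SamplingParams  # imported after the switches are set

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1")  # placeholder model
params = SamplingParams(max_tokens=64)
print(llm.generate(["The capital of France is"], params)[0].outputs[0].text)
```

The fmoe_fp8_block_scaled path additionally requires VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE=1, as described above.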
Paged Attention

The pa_fwd_asm kernel from AITER is integrated as a new paged attention op in /vllm/attention/ops/rocm_aiter_paged_attn.py and implemented into the ROCm attention backend in /vllm/attention/backends/rocm_flash_attn.py. This feature is disabled by default, even when the parent switch (VLLM_ROCM_USE_AITER=1) is enabled. To use this kernel, both the parent switch and its dedicated environment variable VLLM_ROCM_USE_AITER_PAGED_ATTN must be enabled.

Note: kv_cache_dtypes: int8, "fp8", "fp8_e4m3", bfloat16, float16. For the float16 and bfloat16 kv_cache_dtype, the module currently does not support decoding of models with more than 1 kv_head.
Performance Improvement Tables
Experiment setup
The following experiment results are obtained from throughput_benchmark.py. Five cases are tested, and the performance difference is summarized as a range of percentage gains.
5 Cases:
DeepSeekV3 Throughput
Note: The block scaled GEMM from AITER (which relates to VLLM_ROCM_USE_AITER_BLOCK_GEMM=1) is still under development (tuning). There is internal tuning work that boosts the performance of the block scaled GEMM, so we should expect it to become faster in upcoming AITER updates.
DeepSeekV3 Latency
Mixtral-8x7B (with FP8 per-tensor dynamic quantization)
Note: The Fused MoE performs best on the ShareGPT dataset (75%) and less well at long input lengths (-14%).
Mixtral-8x7B-FP16
Note: The Fused MoE performs best on the ShareGPT dataset (2%) and less well at long input lengths (-11%).
Mixtral-8x22B (with FP8 per-tensor dynamic quantization)
Note: The Fused MoE performs best on the ShareGPT dataset (41%) and less well at long input lengths (4%).
Mixtral-8x22B-FP16
Note: The Fused MoE performs best on the ShareGPT dataset (19%) and less well at long input lengths (2%).
Llama-3.1-8B-Instruct (with FP8 per-tensor dynamic quantization)
Llama-3.1-8B-Instruct-BF16
Llama-3.1-70B-Instruct (with FP8 per-tensor dynamic quantization)
Llama-3.1-70B-Instruct-BF16
AITER Operations Testing Overview
1. High-Level Integration Tests
The integration of AITER ops is tested at a higher module level in the following files under /tests/models/decoder_only/language:
- test_models.py
- test_phimoe.py
- test_mistral.py
- test_granite.py
These tests involve running various models to ensure overall functionality.
2. AITER MoE Specific Test
/tests/kernels/test_moe.py
3. Quantization Testing
/tests/quantization/test_fp8.py
4. Kernel Function Dispatch Testing
/tests/model_executor/test_enabled_custom_ops.py
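As an illustration of what such a dispatch test might look like: the dispatch and wrapper names are taken from the diff shown earlier in this PR, while the monkeypatch-based environment handling is an assumption rather than the actual test.

```python
# Sketch of a kernel-dispatch test in the spirit of test_enabled_custom_ops.py.
import pytest

from vllm.model_executor.layers.fused_moe.fused_moe import (
    dispatch_topk_func, rocm_aiter_topk_softmax)
from vllm.platforms import current_platform


@pytest.mark.skipif(not current_platform.is_rocm(), reason="AITER is ROCm-only")
def test_topk_dispatch_prefers_aiter(monkeypatch):
    # Assumes the dispatcher reads these switches at call time; the real
    # test may wire the environment differently.
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    monkeypatch.setenv("VLLM_ROCM_USE_AITER_MOE", "1")
    assert dispatch_topk_func() is rocm_aiter_topk_softmax
```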
Environment Settings
Several packages have been upgraded in Dockerfile.rocm_base based on https://github.com/ROCm/vllm/blob/aiter_upstream/Dockerfile.rocm_base:

Updated Packages:
- PYTORCH_BRANCH: 6c0e7463
- PYTORCH_VISION_BRANCH: v0.21.0

Added AITER Package:
- AITER_BRANCH: e1ec015
Note:

When setting up AITER, it is crucial to use the command git clone --recursive, because the package depends on a third-party package (Composable Kernel).

For building and installing the AITER Python package, you must use the PREBUILD_KERNELS=1 flag along with the command python3 setup.py develop. This ensures that all kernels in the AITER package are built successfully.

The following branches were used as references for this integration: