[Hardware][AMD] integrate aiter into vllm #17710

Closed
8 changes: 8 additions & 0 deletions vllm/envs.py
@@ -81,6 +81,7 @@
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
Collaborator: Wondering what's the difference between VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_MHA?

Contributor (author):
> Wondering what's the difference between VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_MHA?

VLLM_ROCM_USE_AITER is the main switch; VLLM_ROCM_USE_AITER_MHA is the submodule switch that gates only the aiter MHA ops.
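For illustration, a minimal sketch of how the two flags are meant to combine (mirroring the check this PR adds in rocm.py further down); the helper name `aiter_mha_enabled` is hypothetical:

```python
import vllm.envs as envs

def aiter_mha_enabled() -> bool:
    # VLLM_ROCM_USE_AITER is the global aiter switch; VLLM_ROCM_USE_AITER_MHA
    # only gates the MHA ops. Both must be truthy for the aiter MHA path.
    return envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
```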

VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
@@ -581,6 +582,13 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_ROCM_USE_AITER_MLA":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in
("true", "1")),

# Whether to use aiter MHA ops.
# Enabled by default.
"VLLM_ROCM_USE_AITER_MHA":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
("true", "1")),

Collaborator: Do we want to override #16828 by default?

# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
76 changes: 51 additions & 25 deletions vllm/model_executor/layers/layernorm.py
@@ -8,6 +8,7 @@
import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def is_rocm_aiter_rmsnorm_enabled() -> bool:
@@ -42,46 +43,71 @@ def fused_add_rms_norm(
return x, residual


def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
if is_rocm_aiter_rmsnorm_enabled():

import aiter as rocm_aiter
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)
def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:

return rocm_aiter.rms_norm(x, weight, variance_epsilon)
import aiter as rocm_aiter
if x.dim() > 2:
x_original_shape = x.shape
x = x.reshape(-1, x_original_shape[-1])
x = rocm_aiter.rms_norm(x, weight, variance_epsilon)
return x.reshape(x_original_shape)

return rocm_aiter.rms_norm(x, weight, variance_epsilon)

def rocm_aiter_fused_add_rms_norm(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
def rocm_aiter_rms_norm_fake(input: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
return input.clone()

import aiter as rocm_aiter
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)

residual_out = torch.empty_like(residual)
output = torch.empty_like(x)
rocm_aiter.rmsnorm2d_fwd_with_add(
output, # output
x, # input
residual, # residual input
residual_out, # residual output
weight,
variance_epsilon,
def rocm_aiter_fused_add_rms_norm_impl(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:

import aiter as rocm_aiter
residual_out = torch.empty_like(residual)
output = torch.empty_like(x)
rocm_aiter.rmsnorm2d_fwd_with_add(
output, # output
x, # input
residual, # residual input
residual_out, # residual output
weight,
variance_epsilon,
)
return output, residual_out

def rocm_aiter_fused_add_rms_norm_fake(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
return x.clone(), residual.clone()

direct_register_custom_op(
op_name="rocm_aiter_fused_add_rms_norm",
op_func=rocm_aiter_fused_add_rms_norm_impl,
mutates_args=[],
fake_impl=rocm_aiter_fused_add_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)
return output, residual_out


def dispatch_cuda_rmsnorm_func(add_residual: bool):
if add_residual:
if is_rocm_aiter_rmsnorm_enabled():
return rocm_aiter_fused_add_rms_norm
return torch.ops.vllm.rocm_aiter_fused_add_rms_norm
return fused_add_rms_norm

if is_rocm_aiter_rmsnorm_enabled():
return rocm_aiter_rms_norm
return torch.ops.vllm.rocm_aiter_rms_norm
return rms_norm
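For readers unfamiliar with the pattern this diff introduces, here is a minimal, self-contained sketch of the same idea using plain torch.library (PyTorch >= 2.4) instead of vLLM's direct_register_custom_op helper: the real implementation is registered as a custom op, and a shape-only fake impl lets torch.compile trace it. The op name demo::rms_norm and the reference math are illustrative, not the aiter kernel.

```python
import torch

@torch.library.custom_op("demo::rms_norm", mutates_args=())
def demo_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Reference RMSNorm math standing in for the aiter kernel.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

@demo_rms_norm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Shape/dtype-only implementation used during tracing/compilation,
    # analogous to rocm_aiter_rms_norm_fake above.
    return torch.empty_like(x)

x = torch.randn(4, 8)
w = torch.ones(8)
# Dispatched through the op registry, like torch.ops.vllm.rocm_aiter_rms_norm
# in dispatch_cuda_rmsnorm_func above.
out = torch.ops.demo.rms_norm(x, w, 1e-6)
```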


12 changes: 9 additions & 3 deletions vllm/platforms/rocm.py
@@ -181,9 +181,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if envs.VLLM_USE_V1:
logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
and on_mi250_mi300():
logger.info("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"rocm_aiter_fa.AiterFlashAttentionBackend")
else:
logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.
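To make the new selection logic concrete, a hedged usage sketch (assuming a ROCm build of vLLM with aiter installed; the model name is illustrative): with the defaults in this PR, MI250/MI300 users on the V1 engine get the aiter FlashAttention backend, and either flag can be set to "0" to fall back to the Triton backend.

```python
import os

# Flags must be set before vLLM reads them; both aiter flags default to enabled.
# os.environ["VLLM_ROCM_USE_AITER"] = "0"      # disable all aiter paths
# os.environ["VLLM_ROCM_USE_AITER_MHA"] = "0"  # disable only the aiter MHA/FA backend
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM  # requires a ROCm build of vLLM with aiter installed

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model
# The startup log then shows either "Using Flash Attention backend on V1 engine."
# or "Using Triton Attention backend on V1 engine.", indicating which path was picked.
```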