
Commit 40de1ef

vllmellm and tjtanaa authored
[FEAT] [ROCm]: Add AITER Block-Scaled GEMM Feature (#14968)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
1 parent 0189a65 commit 40de1ef

3 files changed, 137 insertions(+), 32 deletions(-)


tests/model_executor/test_enabled_custom_ops.py

Lines changed: 31 additions & 0 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import pytest
+import torch

 from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.custom_op import CustomOp
@@ -16,6 +17,8 @@
 from vllm.model_executor.layers.layernorm import (
     RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
     rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform


@@ -98,6 +101,34 @@ def test_enabled_ops_invalid(env: str):
         RMSNorm(1024).enabled()


+@pytest.mark.skipif(
+    not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
+    reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
+@pytest.mark.parametrize("use_cutlass", [True, False])
+@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
+@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
+def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
+                                  use_rocm_aiter_gemm_w8a8_blockscale: str,
+                                  monkeypatch):
+
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
+    monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
+                       use_rocm_aiter_gemm_w8a8_blockscale)
+
+    use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
+        int(use_rocm_aiter_gemm_w8a8_blockscale)))
+    block_scale_func = dispatch_w8a8_blockscale_func(
+        use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
+    if use_cutlass:
+        assert block_scale_func == cutlass_scaled_mm
+    elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
+            use_rocm_aiter_gemm_w8a8_blockscale):
+        assert block_scale_func == (
+            torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
+    else:
+        assert block_scale_func == w8a8_block_fp8_matmul
+
+
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
     monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 8 additions & 0 deletions
@@ -182,6 +182,13 @@ def __init__(self, quant_config: Fp8Config):
         if current_platform.is_rocm():
             self.use_marlin = False

+        # AITER is only supported on ROCm and only for FP8_FNUZ
+        # and at the moment are MI300 series
+        self.use_aiter_and_is_supported = (current_platform.is_rocm()
+                                           and envs.VLLM_ROCM_USE_AITER
+                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
+                                           and current_platform.is_fp8_fnuz())
+
         self.block_quant = self.quant_config.weight_block_size is not None
         self.fp8_linear = Fp8LinearOp(
             # Default to using per_token quantization if cutlass is supported
@@ -402,6 +409,7 @@ def apply(self,
                 input_scale=layer.input_scale,
                 bias=bias,
                 cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
+                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
             )

         return self.fp8_linear.apply(input=x,
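The flag added above gates the AITER path behind two environment variables plus the ROCm/FP8_FNUZ platform checks. A hypothetical end-user sketch of opting in is shown below; the model id is a placeholder, and the snippet assumes a ROCm FP8_FNUZ (MI300-series) build of vLLM with the aiter package installed.

# Hypothetical usage sketch (not from this commit): opt in to the AITER
# block-scaled GEMM path before vLLM reads its environment flags.
import os

os.environ["VLLM_ROCM_USE_AITER"] = "1"         # master AITER switch
os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "1"  # AITER linear/GEMM kernels

from vllm import LLM  # assumes a ROCm build with AITER available

# Placeholder model id: an FP8 block-quantized checkpoint would route through
# apply_w8a8_block_fp8_linear and hence this dispatch.
llm = LLM(model="path/to/fp8-block-quantized-model")
print(llm.generate("Hello")[0].outputs[0].text)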

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 98 additions & 32 deletions
@@ -4,7 +4,7 @@
 import functools
 import json
 import os
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch

@@ -27,6 +27,76 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
     return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz


+def cutlass_scaled_mm(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: list[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    return ops.cutlass_scaled_mm(A,
+                                 B.T,
+                                 out_dtype=output_dtype,
+                                 scale_a=As,
+                                 scale_b=Bs.T)
+
+
+def rocm_aiter_gemm_w8a8_blockscale_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: list[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    import aiter as rocm_aiter
+
+    return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype)
+
+
+def rocm_aiter_gemm_w8a8_blockscale_fake(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: list[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+
+    m = A.shape[0]
+    n = B.shape[0]
+    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+    return Y
+
+
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_aiter_gemm_w8a8_blockscale",
+        op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
+def dispatch_w8a8_blockscale_func(
+    use_cutlass: bool, use_aiter_and_is_supported: bool
+) -> Callable[[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        list[int],
+        torch.dtype,
+], torch.Tensor]:
+    if use_cutlass:
+        return cutlass_scaled_mm
+    if (use_aiter_and_is_supported):
+        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
+    return w8a8_block_fp8_matmul
+
+
 # TODO fix ROCm->Triton custom path:
 # https://github.com/vllm-project/vllm/issues/14397
 def apply_w8a8_block_fp8_linear(
@@ -37,26 +107,23 @@ def apply_w8a8_block_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
     cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
+    use_aiter_and_is_supported: bool = False,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]

-    shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
-                                  and weight.shape[1] % 128 == 0)
-    if current_platform.is_rocm():
-        # TODO this is never used, as cutlass_block_fp8_supported is False
-        scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
-                         input_2d.shape[:-1])[::-1]
-        scale_b_shape = (weight_scale.view(-1, 1)
-                         if weight_scale.dim() <= 1 else weight_scale.T).shape
-        ar, ac = scale_a_shape
-        br, bc = scale_b_shape
-        if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
-                or br not in (1, weight.shape[0])):
-            shape_supported_by_cutlass = False
-    if cutlass_block_fp8_supported and shape_supported_by_cutlass:
+    if current_platform.is_cuda():
+        use_cutlass = cutlass_block_fp8_supported and (
+            weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
+    else:
+        use_cutlass = False
+
+    w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
+        use_cutlass, use_aiter_and_is_supported)
+
+    if use_cutlass:
         rows, cols = input_2d.shape
         # Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
         # optimal tensor core usage. Can be removed when targeting platforms
@@ -67,26 +134,22 @@ def apply_w8a8_block_fp8_linear(
             input_2d = torch.nn.functional.pad(input_2d,
                                                (0, 0, 0, 4 - (rows % 4)),
                                                value=0).contiguous()
-        q_input, x_scale = per_token_group_quant_fp8(input_2d,
-                                                     block_size[1],
-                                                     column_major_scales=True)
-        output = ops.cutlass_scaled_mm(q_input,
-                                       weight.T,
-                                       out_dtype=input.dtype,
-                                       scale_a=x_scale,
-                                       scale_b=weight_scale.T)
+
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=use_cutlass)
+
+        output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
+                                      block_size, input.dtype)
         if should_pad:
             output = output[:rows, :]
+
     else:
-        q_input, x_scale = per_token_group_quant_fp8(input_2d,
-                                                     block_size[1],
-                                                     column_major_scales=False)
-        output = w8a8_block_fp8_matmul(q_input,
-                                       weight,
-                                       x_scale,
-                                       weight_scale,
-                                       block_size,
-                                       output_dtype=input.dtype)
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=use_cutlass)
+
+        output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
+                                      block_size, input.dtype)
+
     if bias is not None:
         output = output + bias
     return output.to(dtype=input.dtype).view(*output_shape)
@@ -98,6 +161,9 @@ def apply_w8a8_block_fp8_linear_fake(
     block_size: list[int],
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
+    use_aiter_and_is_supported: bool = False,
 ) -> torch.Tensor:
     output_shape = [*input.shape[:-1], weight.shape[0]]
     return torch.empty(output_shape, dtype=input.dtype, device=input.device)
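All three dispatched kernels compute the same block-scaled product: each FP8 operand is rescaled by its per-token-group (activation) or per-block (weight) scale and the result is an (M, N) matmul, consistent with the fake op's output shape above. A dequantize-then-matmul reference is sketched below; the scale layouts (As of shape (M, K // block_k), Bs of shape (N // block_n, K // block_k)) and evenly divisible dimensions are assumptions about the Triton path's convention, not something this diff states.

# Reference (dequantize-then-matmul) sketch of the block-scaled GEMM.
# Assumed layouts: A_q (M, K) fp8 with As (M, K // block_k) per-token-group
# scales; B_q (N, K) fp8 with Bs (N // block_n, K // block_k) per-block scales.
import torch


def reference_w8a8_block_matmul(A_q: torch.Tensor, B_q: torch.Tensor,
                                As: torch.Tensor, Bs: torch.Tensor,
                                block_size: list[int],
                                output_dtype: torch.dtype = torch.float16
                                ) -> torch.Tensor:
    block_n, block_k = block_size
    # Broadcast each scale over its quantization group/block, then dequantize.
    A = A_q.to(torch.float32) * As.repeat_interleave(block_k, dim=1)
    B = B_q.to(torch.float32) * Bs.repeat_interleave(
        block_n, dim=0).repeat_interleave(block_k, dim=1)
    return (A @ B.T).to(output_dtype)  # (M, N), matching the fake op's shape

This is only a numerical reference for checking a kernel's output, not how any of the dispatched implementations are written; the real kernels fuse the rescaling into the tiled matmul instead of materializing dequantized copies.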
