@@ -4,7 +4,7 @@
 import functools
 import json
 import os
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch

@@ -27,6 +27,76 @@ def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
     return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz


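+# Thin wrapper that adapts the shared (A, B, As, Bs, block_size) blockscale
+# GEMM signature to ops.cutlass_scaled_mm; block_size is unused here but kept
+# so every dispatched kernel has the same calling convention.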
+def cutlass_scaled_mm(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: list[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    return ops.cutlass_scaled_mm(A,
+                                 B.T,
+                                 out_dtype=output_dtype,
+                                 scale_a=As,
+                                 scale_b=Bs.T)
+
+
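+# Real implementation backed by AITER's CK blockscale GEMM; aiter is imported
+# lazily so this module still loads on platforms without the package.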
+def rocm_aiter_gemm_w8a8_blockscale_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: list[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    import aiter as rocm_aiter
+
+    return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype)
+
+
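+# Shape-only fake used for meta/torch.compile tracing: allocates the (M, N)
+# output without running the kernel.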
+def rocm_aiter_gemm_w8a8_blockscale_fake(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: list[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    m = A.shape[0]
+    n = B.shape[0]
+    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+    return Y
+
+
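+# Register the AITER GEMM as a vLLM custom op on ROCm so it can be invoked
+# as torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale.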
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_aiter_gemm_w8a8_blockscale",
+        op_func=rocm_aiter_gemm_w8a8_blockscale_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
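+# Select one of the three kernels: CUTLASS takes precedence, then the ROCm
+# AITER custom op, with the Triton w8a8_block_fp8_matmul as the fallback.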
+def dispatch_w8a8_blockscale_func(
+    use_cutlass: bool, use_aiter_and_is_supported: bool
+) -> Callable[[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        list[int],
+        torch.dtype,
+], torch.Tensor]:
+    if use_cutlass:
+        return cutlass_scaled_mm
+    if use_aiter_and_is_supported:
+        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
+    return w8a8_block_fp8_matmul
+
+
 # TODO fix ROCm->Triton custom path:
 # https://github.com/vllm-project/vllm/issues/14397
 def apply_w8a8_block_fp8_linear(
@@ -37,26 +107,23 @@ def apply_w8a8_block_fp8_linear(
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
     cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
+    use_aiter_and_is_supported: bool = False,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]

-    shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
-                                  and weight.shape[1] % 128 == 0)
-    if current_platform.is_rocm():
-        # TODO this is never used, as cutlass_block_fp8_supported is False
-        scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
-                         input_2d.shape[:-1])[::-1]
-        scale_b_shape = (weight_scale.view(-1, 1)
-                         if weight_scale.dim() <= 1 else weight_scale.T).shape
-        ar, ac = scale_a_shape
-        br, bc = scale_b_shape
-        if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
-                or br not in (1, weight.shape[0])):
-            shape_supported_by_cutlass = False
-    if cutlass_block_fp8_supported and shape_supported_by_cutlass:
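+    # CUTLASS is only considered on CUDA, and only when both weight dims are
+    # multiples of the 128x128 block.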
+    if current_platform.is_cuda():
+        use_cutlass = cutlass_block_fp8_supported and (
+            weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
+    else:
+        use_cutlass = False
+
+    w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
+        use_cutlass, use_aiter_and_is_supported)
+
+    if use_cutlass:
         rows, cols = input_2d.shape
         # Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
         # optimal tensor core usage. Can be removed when targeting platforms
@@ -67,26 +134,22 @@ def apply_w8a8_block_fp8_linear(
             input_2d = torch.nn.functional.pad(input_2d,
                                                (0, 0, 0, 4 - (rows % 4)),
                                                value=0).contiguous()
-        q_input, x_scale = per_token_group_quant_fp8(input_2d,
-                                                     block_size[1],
-                                                     column_major_scales=True)
-        output = ops.cutlass_scaled_mm(q_input,
-                                       weight.T,
-                                       out_dtype=input.dtype,
-                                       scale_a=x_scale,
-                                       scale_b=weight_scale.T)
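+        # CUTLASS consumes column-major activation scales, so quantization
+        # emits them in that layout only on this path.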
+
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=use_cutlass)
+
+        output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
+                                      block_size, input.dtype)
         if should_pad:
             output = output[:rows, :]
+
     else:
-        q_input, x_scale = per_token_group_quant_fp8(input_2d,
-                                                     block_size[1],
-                                                     column_major_scales=False)
-        output = w8a8_block_fp8_matmul(q_input,
-                                       weight,
-                                       x_scale,
-                                       weight_scale,
-                                       block_size,
-                                       output_dtype=input.dtype)
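+        # use_cutlass is False on this branch, so scales stay row-major for
+        # the Triton/AITER kernels.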
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=use_cutlass)
+
+        output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
+                                      block_size, input.dtype)
+
     if bias is not None:
         output = output + bias
     return output.to(dtype=input.dtype).view(*output_shape)
@@ -98,6 +161,9 @@ def apply_w8a8_block_fp8_linear_fake(
     block_size: list[int],
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+    cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
+    use_aiter_and_is_supported: bool = False,
 ) -> torch.Tensor:
     output_shape = [*input.shape[:-1], weight.shape[0]]
     return torch.empty(output_shape, dtype=input.dtype, device=input.device)
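
A minimal usage sketch of the new dispatcher (not part of the diff; the import path and flag values are illustrative assumptions):

    # Hypothetical illustration; assumes this diff edits vLLM's fp8_utils module.
    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        cutlass_scaled_mm, dispatch_w8a8_blockscale_func)

    # CUDA with 128-aligned weight dims selects the CUTLASS wrapper...
    func = dispatch_w8a8_blockscale_func(use_cutlass=True,
                                         use_aiter_and_is_supported=False)
    assert func is cutlass_scaled_mm

    # ...while clearing both flags falls back to the Triton kernel,
    # w8a8_block_fp8_matmul.
    func = dispatch_w8a8_blockscale_func(use_cutlass=False,
                                         use_aiter_and_is_supported=False)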