@@ -14,7 +14,7 @@

 import vllm.envs as envs
 from vllm.platforms import current_platform
-from vllm.utils import has_deep_gemm
+from vllm.utils import cdiv, has_deep_gemm


 @functools.cache
@@ -37,7 +37,7 @@ def is_blackwell_deep_gemm_used() -> bool:
         return False

     _lazy_init()
-    if _per_block_cast_impl is None:
+    if _fp8_gemm_nt_impl is None:
         return False

     return (current_platform.is_cuda()
@@ -63,18 +63,15 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
 _fp8_gemm_nt_impl: Callable[..., Any] | None = None
 _grouped_impl: Callable[..., Any] | None = None
 _grouped_masked_impl: Callable[..., Any] | None = None
-_per_block_cast_impl: Callable[..., Any] | None = None


 def _lazy_init() -> None:
     """Import deep_gemm and resolve symbols on first use."""
-    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl, \
-        _per_block_cast_impl
+    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl

     # fast path
     if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
-            or _grouped_masked_impl is not None
-            or _per_block_cast_impl is not None):
+            or _grouped_masked_impl is not None):
         return

     if not has_deep_gemm():
@@ -90,14 +87,6 @@ def _lazy_init() -> None:
     _grouped_masked_impl = _resolve_symbol(
         _dg, "fp8_m_grouped_gemm_nt_masked",
         "m_grouped_gemm_fp8_fp8_bf16_nt_masked")
-    # Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
-    try:
-        _math_mod = importlib.import_module(
-            "deep_gemm.utils.math")  # type: ignore
-        _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8",
-                                       None)
-    except ModuleNotFoundError:
-        _per_block_cast_impl = None


 def fp8_gemm_nt(*args, **kwargs):
@@ -121,13 +110,37 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
     return _grouped_masked_impl(*args, **kwargs)


-def per_block_cast_to_fp8(x, *args, **kwargs):
-    _lazy_init()
-    if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
-        return _per_block_cast_impl(x, use_ue8m0=True)
-    # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
-    from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf
-    return _pbcf(x, *args, **kwargs)
+def _ceil_to_ue8m0(x: torch.Tensor):
+    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
+
+
+def _align(x: int, y: int) -> int:
+    return cdiv(x, y) * y
+
+
+DEFAULT_BLOCK_SIZE = [128, 128]
+
+
+# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
+# TODO(wentao): optimize this function, using triton or cuda kernel
+def per_block_cast_to_fp8(
+        x: torch.Tensor,
+        block_size: list[int] = DEFAULT_BLOCK_SIZE,
+        use_ue8m0: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.dim() == 2
+    m, n = x.shape
+    block_m, block_n = block_size
+    x_padded = torch.zeros((_align(m, block_m), _align(n, block_n)),
+                           dtype=x.dtype,
+                           device=x.device)
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    sf = x_amax / 448.0
+    sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
+    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
+        x_view.size(0), x_view.size(2))


 def calc_diff(x: torch.Tensor, y: torch.Tensor):
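
A minimal usage sketch for the new `per_block_cast_to_fp8` helper (the shapes, the import path, and the dequantization check below are illustrative assumptions, not part of this diff):

import torch
from vllm.utils.deep_gemm import per_block_cast_to_fp8  # hypothetical import path

# Quantize a 2D tensor in 128x128 blocks. 448.0 is the max finite value of
# float8_e4m3fn, so each block's amax is mapped onto the representable range;
# with use_ue8m0=True the scales are rounded up to powers of two, matching
# the UE8M0 scale format DeepGEMM expects on Blackwell.
x = torch.randn(300, 500)
x_fp8, scales = per_block_cast_to_fp8(x, use_ue8m0=True)
assert x_fp8.shape == x.shape      # padding is cropped away
assert scales.shape == (3, 4)      # cdiv(300, 128) x cdiv(500, 128)

# Dequantize by broadcasting each per-block scale back over its 128x128 tile.
s = scales.repeat_interleave(128, 0)[:300].repeat_interleave(128, 1)[:, :500]
error = (x_fp8.to(torch.float32) * s - x).abs().max()  # small block-quant error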