77import torch
88
99from tests .kernels .quantization .nvfp4_utils import (
10- FLOAT4_E2M1_MAX ,
11- FLOAT8_E4M3_MAX ,
1210 dequantize_nvfp4_to_dtype ,
11+ get_nvfp4_global_scale ,
1312)
1413from vllm .platforms import current_platform
1514from vllm .utils import round_up
@@ -171,13 +170,12 @@ def test_flashinfer_trtllm_decode_with_baseline(
171170 output = torch .empty (ref_query .shape , dtype = dtype )
172171 wrapper .run (ref_query , ref_kv_cache , out = output )
173172 o_scale = 1.0
174- o_sf_scale = None
173+ o_sf_scale_float = None
175174 if o_quant_dtype == FP8_DTYPE :
176175 _ , o_scale = to_float8 (output )
177176 elif o_quant_dtype == FP4_DTYPE :
178- o_sf_scale = (
179- (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX ) / torch .amax (output .flatten (), dim = - 1 )
180- ).to (torch .float32 )
177+ o_sf_scale = get_nvfp4_global_scale (output )
178+ o_sf_scale_float = o_sf_scale .item ()
181179
182180 # TRTLLM Decode
183181 if o_quant_dtype == FP4_DTYPE :
@@ -204,7 +202,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
204202 bmm1_scale = q_scale * k_scale * sm_scale ,
205203 bmm2_scale = v_scale / o_scale ,
206204 window_left = window_left ,
207- o_sf_scale = o_sf_scale ,
205+ o_sf_scale = o_sf_scale_float ,
208206 out = output_trtllm ,
209207 )
210208 if o_quant_dtype == FP8_DTYPE :
@@ -361,13 +359,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
361359 output = torch .empty (ref_query .shape , dtype = dtype )
362360 wrapper .run (ref_query , ref_kv_cache , out = output )
363361 o_scale = 1.0
364- o_sf_scale = None
362+ o_sf_scale_float = None
365363 if o_quant_dtype == FP8_DTYPE :
366364 _ , o_scale = to_float8 (output )
367365 elif o_quant_dtype == FP4_DTYPE :
368- o_sf_scale = (
369- (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX ) / torch .amax (output .flatten (), dim = - 1 )
370- ).to (torch .float32 )
366+ o_sf_scale = get_nvfp4_global_scale (output )
367+ o_sf_scale_float = o_sf_scale .item ()
371368
372369 # TRTLLM Prefill
373370 if o_quant_dtype == FP4_DTYPE :
@@ -398,7 +395,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
398395 cum_seq_lens_q = q_indptr ,
399396 cum_seq_lens_kv = kv_indptr ,
400397 window_left = window_left ,
401- o_sf_scale = o_sf_scale ,
398+ o_sf_scale = o_sf_scale_float ,
402399 out = output_trtllm ,
403400 )
404401 if o_quant_dtype == FP8_DTYPE :
0 commit comments