
from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization.utils.fp8_quant_ops import (
-    quantize_fp8_per_group, quantize_fp8_per_tensor, quantize_fp8_per_token)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.platforms import current_platform

# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm.
_FP8_DTYPE = current_platform.fp8_dtype()
+_FP8_FINFO = torch.finfo(_FP8_DTYPE)
+_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max
+_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
+_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)


@CustomOp.register("quant_fp8")
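
For context on the new module-level range constants: a minimal sketch, assuming PyTorch exposes both FP8 e4m3 flavors, of the finfo bounds they are built from (the printed values are the defaults the comment above refers to; nothing below is part of the PR).

import torch

# finfo bounds for the two FP8 e4m3 flavors. On fnuz hardware (ROCm) the PR
# clamps to +/-224.0 instead of finfo's 240.0 to avoid accuracy issues.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0 (OCP e4m3)
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0 (fnuz)

# _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) keeps a dynamic scale from
# collapsing to zero when a token or group is all zeros.
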
@@ -92,9 +94,25 @@ def forward_native(
                                    and scale_ub.numel() == 1)

        if self.use_per_token_if_dynamic and scale is None:
-            out, scale = quantize_fp8_per_token(x, scale, scale_ub)
+            # Per-token quantization logic
+            x_max, _ = x.abs().max(dim=-1)
+            x_max = x_max.unsqueeze(-1).to(torch.float32)
+            if scale_ub is not None:
+                x_max = x_max.clamp(max=scale_ub)
+            scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
+
+            out = x.to(torch.float32) * scale.reciprocal()
+            out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
        else:
-            out, scale = quantize_fp8_per_tensor(x, scale)
+            # Per-tensor quantization logic
+            if scale is None:
+                x_max = x.abs().max().unsqueeze(-1).to(torch.float32)
+                scale = (x_max / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
+
+            # Even for dynamic per-token scales,
+            # reciprocal performs slightly better than division
+            out = x.to(torch.float32) * scale.reciprocal()
+            out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)

        # This currently generates an extra Triton kernel in compilation.
        # Fortunately, we don't use padding if compiling.
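
The per-token branch above derives one scale per row from that row's absolute maximum, optionally capped by scale_ub, and clamps the scale away from zero before scaling. A standalone sketch of the same arithmetic, assuming the e4m3fn range (448.0) and illustrative shapes; the names below are not part of the PR.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max   # 448.0
MIN_SCALE = 1.0 / (FP8_MAX * 512.0)

x = torch.randn(4, 128, dtype=torch.bfloat16)
# One dynamic scale per token (row), clamped away from zero.
x_max = x.abs().max(dim=-1, keepdim=True).values.to(torch.float32)
scale = (x_max / FP8_MAX).clamp(min=MIN_SCALE)
# Quantize, then dequantize to inspect the round-trip error.
q = (x.to(torch.float32) / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
deq = q.to(torch.float32) * scale
print((deq - x.to(torch.float32)).abs().max())

The PR multiplies by scale.reciprocal() rather than dividing; the sketch uses plain division for readability, which is numerically equivalent up to rounding.
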
@@ -118,5 +136,31 @@ def _quantize_group_cuda(

    def _quantize_group_native(
            self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-        return quantize_fp8_per_group(x, self.group_size,
-                                      self.column_major_scales)
+        orig_shape = x.shape
+        hidden_dim = x.shape[-1]
+        num_groups = (hidden_dim + self.group_size - 1) // self.group_size
+        padded_dim = num_groups * self.group_size
+
+        if padded_dim != hidden_dim:
+            padding = padded_dim - hidden_dim
+            x = F.pad(x, (0, padding), mode='constant', value=0.0)
+
+        x_grouped = x.view(-1, num_groups, self.group_size)
+        absmax = x_grouped.abs().max(dim=-1, keepdim=True)[0].float()
+        scales = (absmax / _FP8_MAX).clamp(min=_FP8_MIN_SCALING_FACTOR)
+
+        x_scaled = x_grouped / scales
+        x_quant = x_scaled.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
+
+        x_quant = x_quant.view(-1, padded_dim)
+        if padded_dim != hidden_dim:
+            x_quant = x_quant[..., :hidden_dim]
+        x_quant = x_quant.view(orig_shape)
+
+        scales = scales.squeeze(-1)
+        scales = scales.reshape(orig_shape[:-1] + (num_groups, ))
+
+        if self.column_major_scales:
+            scales = scales.transpose(-2, -1).contiguous()
+
+        return x_quant, scales
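
In the group path, the last dimension is padded up to a multiple of group_size, split into ceil(hidden_dim / group_size) groups, and each group gets its own scale; with column_major_scales the scale tensor is transposed so the group dimension leads. A rough shape check under assumed sizes (group_size=128, hidden divides evenly, so the padding branch is skipped); identifiers below are illustrative, not from the PR.

import torch

group_size, hidden = 128, 512
x = torch.randn(8, hidden)
num_groups = (hidden + group_size - 1) // group_size   # 4

x_grouped = x.view(-1, num_groups, group_size)
absmax = x_grouped.abs().amax(dim=-1, keepdim=True).float()
scales = absmax / torch.finfo(torch.float8_e4m3fn).max
q = (x_grouped / scales).to(torch.float8_e4m3fn)

print(q.view(8, hidden).shape)    # torch.Size([8, 512])
print(scales.squeeze(-1).shape)   # torch.Size([8, 4]), one scale per 128-wide group
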