@@ -142,7 +142,6 @@ def from_hp(
 
         # Store fp8 data in both dense and compressed formats
         fp8_data_fp16 = fp8_data.to(torch.float16)
-        from torch.sparse import to_sparse_semi_structured
 
         fp8_compressed = to_sparse_semi_structured(fp8_data_fp16)
 
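For context on the compressed format used in this hunk, here is a minimal, illustrative sketch of compressing a 2:4-sparse fp16 tensor with torch.sparse.to_sparse_semi_structured, the same call the code above relies on. The shapes and the mm call below are assumptions for the example only, and running it requires a CUDA device with cuSPARSELt/CUTLASS support; it is not part of the PR.

    import torch
    from torch.sparse import to_sparse_semi_structured

    # Build a tensor that already satisfies the 2:4 sparsity pattern
    # (two non-zeros in every group of four elements along a row).
    dense = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()  # [128, 128]

    compressed = to_sparse_semi_structured(dense)  # compressed 2:4 representation
    rhs = torch.randn(128, 128, dtype=torch.float16, device="cuda")

    # Matmul dispatches to the sparse kernel; result matches the dense product.
    out = torch.mm(compressed, rhs)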
@@ -180,47 +179,28 @@ def _(func, types, args, kwargs):
     )
 
     assert isinstance(weight_tensor, Float8SemiSparseTensor)
-    assert activation_tensor.shape[-1] == weight_tensor.original_shape[1], (
-        f"Shape mismatch: {activation_tensor.shape} @ {weight_tensor.original_shape}"
-    )
-
-    # Flatten batch dimensions for scale computation
-    orig_shape = activation_tensor.shape
-    if activation_tensor.dim() > 2:
-        activation_flat = activation_tensor.view(-1, orig_shape[-1])
-    else:
-        activation_flat = activation_tensor
+    assert activation_tensor.dim() == 2, "Only 2D input supported"
+    assert activation_tensor.shape[-1] == weight_tensor.original_shape[1]
 
-    # Compute dynamic scale for activation quantization
-    x_scales = _choose_qparams_affine_floatx(activation_flat, ebits=4, mbits=3)
-    x_scales = x_scales.unsqueeze(1)  # [batch, 1]
+    x_scales = _choose_qparams_affine_floatx(activation_tensor, ebits=4, mbits=3)
+    w_scales = weight_tensor.scale
 
-    # Quantize activation
-    scaled_x = activation_flat / x_scales
-    scaled_x = scaled_x.clamp(-448.0, 448.0)
+    # Global normalizer to prevent overflow
+    global_scale = (x_scales.max() * w_scales.max()).sqrt().clamp(min=0.01)
+    x_scales_adj = (x_scales.unsqueeze(1) / global_scale).to(torch.float32)
+    scaled_x = (activation_tensor.to(torch.float32) / x_scales_adj).clamp(-448.0, 448.0)
     x_vals_fp8 = scaled_x.to(torch.float8_e4m3fn)
 
-    # Dequantize both activation and weight before MatMul to avoid FP16 overflow
-    x_dequant = (x_vals_fp8.to(torch.float32) * x_scales.to(torch.float32)).to(
-        torch.float16
+    # MatMul
+    x_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(
+        x_vals_fp8.to(torch.float16)
     )
-    w_dequant = (
-        weight_tensor.qdata.to(torch.float32)
-        * weight_tensor.scale.unsqueeze(1).to(torch.float32)
-    ).to(torch.float16)
-
-    # Sparse MatMul with dequntized tensor
-    w_sparse = to_sparse_semi_structured(w_dequant)
-    row = x_dequant.shape[0]
-    x_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(x_dequant)
-
-    y = torch.matmul(x_padded, w_sparse.t())
-    y = y[:row, :]
-
-    # Reshape to original activation shape
-    if activation_tensor.dim() > 2:
-        y = y.view(*orig_shape[:-1], -1)
+    y_fp16 = torch.matmul(x_padded, weight_tensor.qdata_compressed.t())
+    y = y_fp16[: activation_tensor.shape[0], :].to(torch.float32)
 
+    # Restore scale
+    w_scales_fp32 = w_scales.to(torch.float32)
+    y = y * (x_scales_adj * w_scales_fp32.unsqueeze(0) * global_scale)
     y = y.to(activation_tensor.dtype).contiguous()
 
     if bias is not None:
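For readers following the new scale handling in the second hunk, here is a self-contained sketch (not the PR's code) of the per-row dynamic quantization and scale-restoration algebra it builds on: quantize the activation per row, run the matmul on the quantized values, then multiply by the outer product of activation and weight scales. The names x, w, and FP8_MAX below are assumptions for the example; the PR additionally divides the activation scales by a global normalizer and routes the matmul through the cuSPARSELt-compressed weight.

    import torch

    FP8_MAX = 448.0  # largest normal value representable in float8_e4m3fn

    x = torch.randn(8, 64)   # activation [batch, in_features]
    w = torch.randn(32, 64)  # weight     [out_features, in_features]

    # Per-row scales so each row fits the fp8 range.
    x_scales = x.abs().amax(dim=1, keepdim=True) / FP8_MAX  # [batch, 1]
    w_scales = w.abs().amax(dim=1, keepdim=True) / FP8_MAX  # [out_features, 1]

    x_q = (x / x_scales).clamp(-FP8_MAX, FP8_MAX)
    w_q = (w / w_scales).clamp(-FP8_MAX, FP8_MAX)

    # Matmul on quantized values, then undo both scales:
    # (x_q * x_scales) @ (w_q * w_scales).T == (x_q @ w_q.T) * x_scales * w_scales.T
    y = (x_q @ w_q.t()) * x_scales * w_scales.t()

    print(torch.allclose(y, x @ w.t(), atol=1e-4))  # True up to float rounding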