Commit f5f7a17

fix FP16 bit-range overflow
1 parent 96f8374 commit f5f7a17

File tree: 1 file changed (+16 -36 lines)

1 file changed

+16
-36
lines changed

torchao/prototype/quantization/quantize_/workflows/float8/float8_semisparse_tensor.py

Lines changed: 16 additions & 36 deletions
@@ -142,7 +142,6 @@ def from_hp(
 
         # Store fp8 data in both dense and compressed formats
         fp8_data_fp16 = fp8_data.to(torch.float16)
-        from torch.sparse import to_sparse_semi_structured
 
         fp8_compressed = to_sparse_semi_structured(fp8_data_fp16)
 
@@ -180,47 +179,28 @@ def _(func, types, args, kwargs):
     )
 
     assert isinstance(weight_tensor, Float8SemiSparseTensor)
-    assert activation_tensor.shape[-1] == weight_tensor.original_shape[1], (
-        f"Shape mismatch: {activation_tensor.shape} @ {weight_tensor.original_shape}"
-    )
-
-    # Flatten batch dimensions for scale computation
-    orig_shape = activation_tensor.shape
-    if activation_tensor.dim() > 2:
-        activation_flat = activation_tensor.view(-1, orig_shape[-1])
-    else:
-        activation_flat = activation_tensor
+    assert activation_tensor.dim() == 2, "Only 2D input supported"
+    assert activation_tensor.shape[-1] == weight_tensor.original_shape[1]
 
-    # Compute dynamic scale for activation quantization
-    x_scales = _choose_qparams_affine_floatx(activation_flat, ebits=4, mbits=3)
-    x_scales = x_scales.unsqueeze(1)  # [batch, 1]
+    x_scales = _choose_qparams_affine_floatx(activation_tensor, ebits=4, mbits=3)
+    w_scales = weight_tensor.scale
 
-    # Quantize activation
-    scaled_x = activation_flat / x_scales
-    scaled_x = scaled_x.clamp(-448.0, 448.0)
+    # Global normalizer to prevent overflow
+    global_scale = (x_scales.max() * w_scales.max()).sqrt().clamp(min=0.01)
+    x_scales_adj = (x_scales.unsqueeze(1) / global_scale).to(torch.float32)
+    scaled_x = (activation_tensor.to(torch.float32) / x_scales_adj).clamp(-448.0, 448.0)
     x_vals_fp8 = scaled_x.to(torch.float8_e4m3fn)
 
-    # Dequantize both activation and weight before MatMul to avoid FP16 overflow
-    x_dequant = (x_vals_fp8.to(torch.float32) * x_scales.to(torch.float32)).to(
-        torch.float16
+    # MatMul
+    x_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(
+        x_vals_fp8.to(torch.float16)
     )
-    w_dequant = (
-        weight_tensor.qdata.to(torch.float32)
-        * weight_tensor.scale.unsqueeze(1).to(torch.float32)
-    ).to(torch.float16)
-
-    # Sparse MatMul with dequntized tensor
-    w_sparse = to_sparse_semi_structured(w_dequant)
-    row = x_dequant.shape[0]
-    x_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(x_dequant)
-
-    y = torch.matmul(x_padded, w_sparse.t())
-    y = y[:row, :]
-
-    # Reshape to original activation shape
-    if activation_tensor.dim() > 2:
-        y = y.view(*orig_shape[:-1], -1)
+    y_fp16 = torch.matmul(x_padded, weight_tensor.qdata_compressed.t())
+    y = y_fp16[: activation_tensor.shape[0], :].to(torch.float32)
 
+    # Restore scale
+    w_scales_fp32 = w_scales.to(torch.float32)
+    y = y * (x_scales_adj * w_scales_fp32.unsqueeze(0) * global_scale)
     y = y.to(activation_tensor.dtype).contiguous()
 
     if bias is not None:
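
For context on what the new path does: it keeps the quantized (scale-divided) values for the FP16 sparse matmul and folds a shared normalizer, global_scale = sqrt(x_scales.max() * w_scales.max()), into the activation scale up front, multiplying the scales back in FP32 afterwards so the intermediate FP16 values stay in range (FP16 tops out around 65504). The sketch below illustrates that rescaling idea on dense tensors; the function name and shapes are mine, and it skips the fp8 cast, the 2:4 sparse compression, and the cuSPARSELt padding, so it is a simplified illustration rather than the actual torchao code path.

import torch

def rescaled_fp16_matmul(x, w_q, x_scales, w_scales):
    # Illustrative dense emulation (hypothetical helper, not the torchao API).
    #   x:        [M, K] high-precision activation
    #   w_q:      [N, K] quantized weight values (weight / w_scales) in fp16
    #   x_scales: [M]    per-row activation scales
    #   w_scales: [N]    per-output-channel weight scales
    # Shared normalizer folded into the activation scale so the fp16 matmul
    # output stays well below the fp16 maximum (~65504).
    global_scale = (x_scales.max() * w_scales.max()).sqrt().clamp(min=0.01)
    x_scales_adj = (x_scales / global_scale).unsqueeze(1).to(torch.float32)

    # Quantize the activation into the fp8 e4m3 value range and matmul in fp16.
    x_q = (x.to(torch.float32) / x_scales_adj).clamp(-448.0, 448.0)
    y = torch.matmul(x_q.to(torch.float16), w_q.t()).to(torch.float32)

    # Multiply the scales back to recover the full-precision result.
    return y * x_scales_adj * w_scales.unsqueeze(0).to(torch.float32)

Called with, say, x of shape [8, 64] and w_q of shape [32, 64], it returns an [8, 32] FP32 result.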
