
Commit 39e0ad1

Update on "Add rowwise scaling to Float8Inference module"
[ghstack-poisoned]
1 parent ce3baaf commit 39e0ad1

1 file changed (+8 −1)

float8_experimental/float8_python_api.py

Lines changed: 8 additions & 1 deletion
@@ -51,11 +51,18 @@ def addmm_float8_unwrapped(
         )
         output += bias
         return output
+    # Weight tensors are stored in N, K format. We call tensor_to_scale(dim=0)
+    # which produces a (N, 1) Tensor. However scaled_mm syntactically expects
+    # M X K @ K X N, and scales (M, 1) and (1, N)
+    b_inverse_scale = (
+        b_inverse_scale.T if b_inverse_scale.dim() == 2 else b_inverse_scale
+    )
+
     output = torch._scaled_mm(
         a_data,
         b_data,
         scale_a=a_inverse_scale,
-        scale_b=b_inverse_scale.T,
+        scale_b=b_inverse_scale,
         bias=bias,
         scale_result=output_scale,
         out_dtype=output_dtype,
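For context, here is a minimal, self-contained sketch of the shape logic the new comment describes. It uses a hypothetical row-wise amax scale as a stand-in for tensor_to_scale, and the dimensions M, K, N are illustrative, not the module's actual code:

import torch

# Hypothetical shapes, for illustration only.
M, K, N = 4, 8, 16
a = torch.randn(M, K)   # activation: M x K
b = torch.randn(N, K)   # weight stored in N, K format

# Stand-in for a row-wise scale like tensor_to_scale(dim=0):
# one scale per weight row, shape (N, 1).
a_inverse_scale = a.abs().amax(dim=1, keepdim=True).reciprocal()  # (M, 1)
b_inverse_scale = b.abs().amax(dim=1, keepdim=True).reciprocal()  # (N, 1)

# scaled_mm computes M x K @ K x N, so the weight's (N, 1) scale
# must be transposed to (1, N) to match the second operand.
b_inverse_scale = (
    b_inverse_scale.T if b_inverse_scale.dim() == 2 else b_inverse_scale
)
assert a_inverse_scale.shape == (M, 1)
assert b_inverse_scale.shape == (1, N)

With the transpose handled inside addmm_float8_unwrapped, callers no longer need to pass a pre-transposed weight scale, which is why the call site drops the .T from scale_b.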
