This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit b8a68fb

Merge pull request #2 from pytorch-labs/add_bias
add bias support to float8linear
2 parents 30574b7 + 9ca16b8

3 files changed: +66 −18

float8_playground/float8_aten_api.py

Lines changed: 19 additions & 0 deletions

@@ -39,6 +39,22 @@ def add_float8_e5m2(m1, s1, m2, s2, s3):
     m3_float32 = m1_float32 + m2_float32
     return float32_to_float8(m3_float32 * s3, E5M2)
 
+# TODO naming of these vars is weird
+def addmm_float8(
+        inp1, inp_s1, inp_flavor1, m1, s1, flavor1, m2, s2, flavor2,
+        s3, flavor3):
+    # naive implementation: dq -> op -> q
+    # TODO(future): hook up to real kernel
+    inp1_fp32 = float8_to_float32(inp1, inp_flavor1) / inp_s1
+    m1_fp32 = float8_to_float32(m1, flavor1) / s1
+    m2_fp32 = float8_to_float32(m2, flavor2) / s2
+    m3_fp32 = torch.addmm(inp1_fp32, m1_fp32, m2_fp32)
+    # TODO(future): switch to delayed scaling
+    s3.fill_(tensor_to_scale(m3_fp32, flavor3))
+    m3_fp32_scaled = m3_fp32 * s3
+    return float32_to_float8(m3_fp32_scaled, flavor3)
+
+
 #
 # ATen op placeholders
 #

@@ -60,3 +76,6 @@ def add_float8_e5m2(m1, s1, m2, s2, s3):
 
 lib.define("add_float8_e5m2(Tensor m1, Tensor s1, Tensor m2, Tensor s2, Tensor s3) -> Tensor")
 lib.impl("add_float8_e5m2", add_float8_e5m2, "CPU")
+
+lib.define("addmm_float8(Tensor inp1, Tensor inp_s1, int inp_flavor1, Tensor m1, Tensor s1, int flavor1, Tensor m2, Tensor s2, int flavor2, Tensor s3, int flavor3) -> Tensor")
+lib.impl("addmm_float8", addmm_float8, "CPU")

float8_playground/float8_linear.py

Lines changed: 27 additions & 10 deletions

@@ -25,25 +25,34 @@ def forward(
         ctx,
         x_fp8,
         w_fp8,
+        b_fp8,
         fp8_s_out,
         fp8_s_dL_dX,
         fp8_s_dL_dW,
         fp8_s_dL_dY,
     ):
-        ctx.save_for_backward(x_fp8, w_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY)
-
-        res_bits = torch.ops.aten.mm_float8(
-            x_fp8._data, x_fp8._scale, x_fp8._flavor,
-            w_fp8._data.t(), w_fp8._scale, w_fp8._flavor,
-            fp8_s_out, E4M3)
+        ctx.save_for_backward(
+            x_fp8, w_fp8, b_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY)
+        if b_fp8 is not None:
+            # TODO add this
+            res_bits = torch.ops.aten.addmm_float8(
+                b_fp8._data, b_fp8._scale, b_fp8._flavor,
+                x_fp8._data, x_fp8._scale, x_fp8._flavor,
+                w_fp8._data.t(), w_fp8._scale, w_fp8._flavor,
+                fp8_s_out, E4M3)
+        else:
+            res_bits = torch.ops.aten.mm_float8(
+                x_fp8._data, x_fp8._scale, x_fp8._flavor,
+                w_fp8._data.t(), w_fp8._scale, w_fp8._flavor,
+                fp8_s_out, E4M3)
 
         res = Float8Tensor(res_bits, fp8_s_out, E4M3)
         # scale update would also happen here, for now no-op
         return res
 
     @staticmethod
     def backward(ctx, go):
-        x_fp8, w_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY = \
+        x_fp8, w_fp8, b_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY = \
             ctx.saved_tensors
 
         if not isinstance(go, Float8Tensor):

@@ -69,7 +78,10 @@ def backward(ctx, go):
         dL_dW_fp8 = Float8Tensor(dL_dW_bits, fp8_s_dL_dW, E5M2)
 
         # scale update would also happen here, for now no-op
-        return dL_dX_fp8, dL_dW_fp8, None, None, None, None
+        if b_fp8 is not None:
+            return dL_dX_fp8, dL_dW_fp8, go_fp8, None, None, None, None
+        else:
+            return dL_dX_fp8, dL_dW_fp8, None, None, None, None, None
 
 
 class Float8Linear(torch.nn.Linear):

@@ -86,6 +98,7 @@ def __init__(self, *args, **kwargs):
         # or PTQ calibration.
         self.register_buffer('fp8_s_in', torch.tensor(1.0))
         self.register_buffer('fp8_s_weight', torch.tensor(1.0))
+        self.register_buffer('fp8_s_bias', torch.tensor(1.0))
         self.register_buffer('fp8_s_out', torch.tensor(1.0))
         self.register_buffer('fp8_s_dL_dX', torch.tensor(1.0))
         self.register_buffer('fp8_s_dL_dW', torch.tensor(1.0))

@@ -102,9 +115,13 @@ def forward(self, x):
         # TODO(future): switch to delayed scaling
         self.fp8_s_weight.fill_(tensor_to_scale(self.weight, E4M3))
         w_fp8 = Float8Tensor.from_float32(self.weight, self.fp8_s_weight, E4M3)
+        maybe_b_fp8 = None
+        if self.bias is not None:
+            self.fp8_s_bias.fill_(tensor_to_scale(self.bias, E4M3))
+            maybe_b_fp8 = Float8Tensor.from_float32(self.bias, self.fp8_s_bias, E4M3)
 
         y_fp8 = float8_linear_no_bias.apply(
-            x_fp8, w_fp8, self.fp8_s_out, self.fp8_s_dL_dX,
+            x_fp8, w_fp8, maybe_b_fp8, self.fp8_s_out, self.fp8_s_dL_dX,
             self.fp8_s_dL_dW, self.fp8_s_dL_dY)
 
         # For now, hardcode returning Float8Tensor (propagate as much as we can).

@@ -116,7 +133,7 @@ def from_float(cls, mod):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
         """
-        assert mod.bias is None, 'bias support not implemented yet'
         new_mod = cls(mod.in_features, mod.out_features, bias=False)
         new_mod.weight = mod.weight
+        new_mod.bias = mod.bias
         return new_mod
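With the assert removed from from_float, a biased module now converts directly. A hedged usage sketch follows; the import path is assumed from the repo layout and may differ.

import copy
import torch
import torch.nn as nn
from float8_playground.float8_linear import Float8Linear  # assumed path

# Convert a regular biased Linear; the bias is carried over and scaled
# into fp8 via the new fp8_s_bias buffer on each forward.
m_ref = nn.Linear(3, 4, bias=True)
m_fp8 = Float8Linear.from_float(copy.deepcopy(m_ref))

y = m_fp8(torch.randn(2, 3))
y.sum().backward()
print(m_fp8.fp8_s_bias)  # updated away from its initial value of 1.0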

float8_playground/test.py

Lines changed: 20 additions & 8 deletions

@@ -154,13 +154,10 @@ def test_add(self):
         self.assertTrue(sqnr >= 10.0)
 
 class Float8LinearUnitTest(unittest.TestCase):
+    def _test_linear_impl(self, x, m_ref):
 
-    def test_e2e(self):
-        m_ref = nn.Linear(4, 4, bias=False)
         m_fp8 = Float8Linear.from_float(copy.deepcopy(m_ref))
 
-        x = torch.randn(4, 4)
-
         y_fp8 = m_fp8(x)
         y_fp8.sum().backward()
         y_ref = m_ref(x)

@@ -170,23 +167,38 @@ def test_e2e(self):
         g_sqnr = compute_error(m_ref.weight.grad, m_fp8.weight.grad)
 
         # verify sqnr is reasonable
-        self.assertTrue(y_sqnr >= 27.0)
-        self.assertTrue(g_sqnr >= 27.0)
+        self.assertTrue(y_sqnr >= 24.0)
+        self.assertTrue(g_sqnr >= 24.0)
+        if m_ref.bias is not None:
+            torch.testing.assert_close(m_ref.bias.grad, m_fp8.bias.grad)
 
         # verify all of the scales got updated
-        for buffer_name in (
+        buffer_names = [
             'fp8_s_in',
             'fp8_s_weight',
             'fp8_s_out',
             'fp8_s_dL_dX',
             'fp8_s_dL_dW',
             'fp8_s_dL_dY',
-        ):
+        ]
+        if m_ref.bias is not None:
+            buffer_names.append('fp8_s_bias')
+        for buffer_name in buffer_names:
             buffer_value = getattr(m_fp8, buffer_name)
             self.assertTrue(
                 torch.ne(buffer_value, torch.tensor(1.0)),
                 f"{buffer_name} not filled")
 
+    def test_linear_nobias(self):
+        x = torch.randn(2, 3)
+        m_ref = nn.Linear(3, 4, bias=False)
+        self._test_linear_impl(x, m_ref)
+
+    def test_linear_bias(self):
+        x = torch.randn(2, 3)
+        m_ref = nn.Linear(3, 4, bias=True)
+        self._test_linear_impl(x, m_ref)
+
 
 if __name__ == '__main__':
     unittest.main()
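The 24.0 thresholds above are in dB of SQNR (signal-to-quantization-noise ratio). Below is a minimal sketch of a compute_error-style metric, assuming the repo's helper is SQNR-based; the exact implementation may differ.

import torch

def compute_error_sketch(ref, approx):
    # SQNR in dB; higher means approx tracks ref more closely.
    signal_power = ref.pow(2).mean()
    noise_power = (ref - approx).pow(2).mean()
    return 10 * torch.log10(signal_power / noise_power)

ref = torch.randn(128)
approx = ref + 0.01 * torch.randn(128)
print(compute_error_sketch(ref, approx))  # typically well above 24 dB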
