This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 3d9231d

Merge pull request #4 from pytorch-labs/real_dtypes
switch from emulated to real float8 dtypes
2 parents 00a649c + dcc039d commit 3d9231d

5 files changed: +70 -359 lines changed
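The common thread in the diffs below is that the emulated float8 helpers (float32_to_float8 / float8_to_float32 plus the E4M3 / E5M2 "flavor" constants) are replaced by PyTorch's native torch.float8_e4m3fn and torch.float8_e5m2 dtypes, so conversions become plain .to(dtype) and .float() casts. A minimal round-trip sketch, assuming a PyTorch build that ships the native float8 dtypes (tensors and values here are illustrative only):

import torch

# Illustrative only; requires a PyTorch build with native float8 dtypes.
x = torch.randn(4, 4)

# "forward" flavor: e4m3 (more mantissa bits, less range)
x_e4m3 = x.to(torch.float8_e4m3fn)
# "gradient" flavor: e5m2 (fewer mantissa bits, more exponent range)
g_e5m2 = torch.randn(4, 4).to(torch.float8_e5m2)

# casting back replaces the old float8_to_float32 helper
x_roundtrip = x_e4m3.float()
print(x_e4m3.dtype, g_e5m2.dtype, x_roundtrip.dtype)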

float8_playground/float8_aten_api.py

Lines changed: 24 additions & 31 deletions
@@ -7,52 +7,51 @@
 from torch.library import Library
 
 from float8_utils import (
-    float32_to_float8,
-    float8_to_float32,
-    E4M3,
-    E5M2,
     tensor_to_scale,
 )
 
 
-def mm_float8(m1, s1, flavor1, m2, s2, flavor2, s3, flavor3):
+def mm_float8(m1, s1, m2, s2, s3, dtype3):
     # naive implementation: dq -> op -> q
     # TODO(future): hook up to real kernel
-    m1_fp32 = float8_to_float32(m1, flavor1) / s1
-    m2_fp32 = float8_to_float32(m2, flavor2) / s2
+    m1_fp32 = m1.float() / s1
+    m2_fp32 = m2.float() / s2
     m3_fp32 = torch.mm(m1_fp32, m2_fp32)
     # TODO(future): switch to delayed scaling
-    s3.fill_(tensor_to_scale(m3_fp32, flavor3))
+    s3.fill_(tensor_to_scale(m3_fp32, dtype3))
     m3_fp32_scaled = m3_fp32 * s3
-    return float32_to_float8(m3_fp32_scaled, flavor3)
+    if dtype3 == torch.float8_e4m3fn:
+        return m3_fp32_scaled.to(torch.float8_e4m3fn)
+    else:
+        return m3_fp32_scaled.to(torch.float8_e5m2)
 
 def add_float8_e5m2(m1, s1, m2, s2, s3):
     # for now this is only implemented for e5m2 because we only care about
     # this for adding gradients
     # naive implementation: dq -> op -> q
     # TODO(future): hook up to real kernel
-    # TODO(future): make this more accurate, accuracy is pretty low,
-    # can probably just calculate s3 dynamically since this is an edge case
-    # unlikely to affect e2e performance
-    m1_float32 = float8_to_float32(m1, E5M2) / s1
-    m2_float32 = float8_to_float32(m2, E5M2) / s2
+    m1_float32 = m1.float() / s1
+    m2_float32 = m2.float() / s2
     m3_float32 = m1_float32 + m2_float32
-    return float32_to_float8(m3_float32 * s3, E5M2)
+    s3_val = tensor_to_scale(m3_float32, torch.float8_e5m2)
+    s3.fill_(s3_val)
+    return (m3_float32 * s3).to(torch.float8_e5m2)
 
 # TODO naming of these vars is weird
-def addmm_float8(
-    inp1, inp_s1, inp_flavor1, m1, s1, flavor1, m2, s2, flavor2,
-    s3, flavor3):
+def addmm_float8(inp1, inp_s1, m1, s1, m2, s2, s3, dtype3):
     # naive implementation: dq -> op -> q
     # TODO(future): hook up to real kernel
-    inp1_fp32 = float8_to_float32(inp1, inp_flavor1) / inp_s1
-    m1_fp32 = float8_to_float32(m1, flavor1) / s1
-    m2_fp32 = float8_to_float32(m2, flavor2) / s2
+    inp1_fp32 = inp1.float() / inp_s1
+    m1_fp32 = m1.float() / s1
+    m2_fp32 = m2.float() / s2
     m3_fp32 = torch.addmm(inp1_fp32, m1_fp32, m2_fp32)
     # TODO(future): switch to delayed scaling
-    s3.fill_(tensor_to_scale(m3_fp32, flavor3))
+    s3.fill_(tensor_to_scale(m3_fp32, dtype3))
     m3_fp32_scaled = m3_fp32 * s3
-    return float32_to_float8(m3_fp32_scaled, flavor3)
+    if dtype3 == torch.float8_e4m3fn:
+        return m3_fp32_scaled.to(torch.float8_e4m3fn)
+    else:
+        return m3_fp32_scaled.to(torch.float8_e5m2)
 
 
 #
@@ -65,17 +64,11 @@ def addmm_float8(
 
 # For now register on CPU,
 # TODO(future) add GPU and test there
-lib.define("float32_to_float8(Tensor t, int flavor) -> Tensor")
-lib.impl("float32_to_float8", float32_to_float8, "CPU")
-
-lib.define("float8_to_float32(Tensor t, int flavor) -> Tensor")
-lib.impl("float8_to_float32", float8_to_float32, "CPU")
-
-lib.define("mm_float8(Tensor m1, Tensor s1, int flavor1, Tensor m2, Tensor s2, int flavor2, Tensor s3, int flavor3) -> Tensor")
+lib.define("mm_float8(Tensor m1, Tensor s1, Tensor m2, Tensor s2, Tensor s3, int dtype3) -> Tensor")
 lib.impl("mm_float8", mm_float8, "CPU")
 
 lib.define("add_float8_e5m2(Tensor m1, Tensor s1, Tensor m2, Tensor s2, Tensor s3) -> Tensor")
 lib.impl("add_float8_e5m2", add_float8_e5m2, "CPU")
 
-lib.define("addmm_float8(Tensor inp1, Tensor inp_s1, int inp_flavor1, Tensor m1, Tensor s1, int flavor1, Tensor m2, Tensor s2, int flavor2, Tensor s3, int flavor3) -> Tensor")
+lib.define("addmm_float8(Tensor inp1, Tensor inp_s1, Tensor m1, Tensor s1, Tensor m2, Tensor s2, Tensor s3, int dtype3) -> Tensor")
 lib.impl("addmm_float8", addmm_float8, "CPU")

float8_playground/float8_linear.py

Lines changed: 26 additions & 28 deletions
@@ -11,7 +11,7 @@
 
 import float8_aten_api
 
-from float8_utils import E4M3, E5M2, tensor_to_scale
+from float8_utils import tensor_to_scale
 from float8_tensor import Float8Tensor
 
 class float8_linear(torch.autograd.Function):
@@ -33,19 +33,18 @@ def forward(
         ctx.save_for_backward(
             x_fp8, w_fp8, b_fp8, fp8_s_dL_dX, fp8_s_dL_dW, fp8_s_dL_dY)
         if b_fp8 is not None:
-            # TODO add this
             res_bits = torch.ops.aten.addmm_float8(
-                b_fp8._data, b_fp8._scale, b_fp8._flavor,
-                x_fp8._data, x_fp8._scale, x_fp8._flavor,
-                w_fp8._data.t(), w_fp8._scale, w_fp8._flavor,
-                fp8_s_out, E4M3)
+                b_fp8._data, b_fp8._scale,
+                x_fp8._data, x_fp8._scale,
+                w_fp8._data.t(), w_fp8._scale,
+                fp8_s_out, torch.float8_e4m3fn)
         else:
             res_bits = torch.ops.aten.mm_float8(
-                x_fp8._data, x_fp8._scale, x_fp8._flavor,
-                w_fp8._data.t(), w_fp8._scale, w_fp8._flavor,
-                fp8_s_out, E4M3)
+                x_fp8._data, x_fp8._scale,
+                w_fp8._data.t(), w_fp8._scale,
+                fp8_s_out, torch.float8_e4m3fn)
 
-        res = Float8Tensor(res_bits, fp8_s_out, E4M3)
+        res = Float8Tensor(res_bits, fp8_s_out)
         # scale update would also happen here, for now no-op
         return res
 
@@ -56,25 +55,24 @@ def backward(ctx, go):
 
         if not isinstance(go, Float8Tensor):
             # TODO(future): switch to delayed scaling
-            fp8_s_dL_dY.fill_(tensor_to_scale(go, E5M2))
+            fp8_s_dL_dY.fill_(tensor_to_scale(go, torch.float8_e5m2))
             go_fp8 = Float8Tensor(
-                torch.ops.aten.float32_to_float8(go * fp8_s_dL_dY, E5M2),
-                fp8_s_dL_dY,
-                E5M2)
+                (go * fp8_s_dL_dY).to(torch.float8_e5m2),
+                fp8_s_dL_dY)
         else:
             go_fp8 = go
 
         dL_dX_bits = torch.ops.aten.mm_float8(
-            go_fp8._data, go_fp8._scale, go_fp8._flavor,
-            w_fp8._data, w_fp8._scale, w_fp8._flavor,
-            fp8_s_dL_dX, E5M2)
-        dL_dX_fp8 = Float8Tensor(dL_dX_bits, fp8_s_dL_dX, E5M2)
+            go_fp8._data, go_fp8._scale,
+            w_fp8._data, w_fp8._scale,
+            fp8_s_dL_dX, torch.float8_e5m2)
+        dL_dX_fp8 = Float8Tensor(dL_dX_bits, fp8_s_dL_dX)
 
         dL_dW_bits = torch.ops.aten.mm_float8(
-            x_fp8._data.t(), x_fp8._scale, x_fp8._flavor,
-            go_fp8._data, go_fp8._scale, go_fp8._flavor,
-            fp8_s_dL_dW, E5M2).t()
-        dL_dW_fp8 = Float8Tensor(dL_dW_bits, fp8_s_dL_dW, E5M2)
+            x_fp8._data.t(), x_fp8._scale,
+            go_fp8._data, go_fp8._scale,
+            fp8_s_dL_dW, torch.float8_e5m2).t()
+        dL_dW_fp8 = Float8Tensor(dL_dW_bits, fp8_s_dL_dW)
 
         # scale update would also happen here, for now no-op
         if b_fp8 is not None:
@@ -106,18 +104,18 @@ def __init__(self, *args, **kwargs):
     def forward(self, x):
         if not isinstance(x, Float8Tensor):
             # TODO(future): switch to delayed scaling
-            self.fp8_s_in.fill_(tensor_to_scale(x, E4M3))
-            x_fp8 = Float8Tensor.from_float32(x, self.fp8_s_in, E4M3)
+            self.fp8_s_in.fill_(tensor_to_scale(x, torch.float8_e4m3fn))
+            x_fp8 = Float8Tensor.from_float32(x, self.fp8_s_in, torch.float8_e4m3fn)
         else:
             x_fp8 = x
 
         # TODO(future): switch to delayed scaling
-        self.fp8_s_weight.fill_(tensor_to_scale(self.weight, E4M3))
-        w_fp8 = Float8Tensor.from_float32(self.weight, self.fp8_s_weight, E4M3)
+        self.fp8_s_weight.fill_(tensor_to_scale(self.weight, torch.float8_e4m3fn))
+        w_fp8 = Float8Tensor.from_float32(self.weight, self.fp8_s_weight, torch.float8_e4m3fn)
         maybe_b_fp8 = None
         if self.bias is not None:
-            self.fp8_s_bias.fill_(tensor_to_scale(self.bias, E4M3))
-            maybe_b_fp8 = Float8Tensor.from_float32(self.bias, self.fp8_s_bias, E4M3)
+            self.fp8_s_bias.fill_(tensor_to_scale(self.bias, torch.float8_e4m3fn))
+            maybe_b_fp8 = Float8Tensor.from_float32(self.bias, self.fp8_s_bias, torch.float8_e4m3fn)
 
         y_fp8 = float8_linear.apply(
             x_fp8, w_fp8, maybe_b_fp8, self.fp8_s_out, self.fp8_s_dL_dX,
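The pattern in forward() above is dynamic scaling: compute a fresh scale from the incoming tensor, fill the pre-allocated scale buffer, then wrap the scaled-and-cast payload in a Float8Tensor. torch.float8_e4m3fn is used for activations, weights, and bias on the forward path, while the backward path uses torch.float8_e5m2 for gradients. A stripped-down sketch of that pattern; the weight tensor and scale buffer names here are illustrative stand-ins, not the module's actual attributes:

import torch
from float8_utils import tensor_to_scale
from float8_tensor import Float8Tensor

w = torch.randn(8, 16)          # stand-in for self.weight
fp8_s_weight = torch.empty(1)   # stand-in for the module's scale buffer

# dynamic scaling, as in the source (the TODO is to switch to delayed scaling)
fp8_s_weight.fill_(tensor_to_scale(w, torch.float8_e4m3fn))
w_fp8 = Float8Tensor.from_float32(w, fp8_s_weight, torch.float8_e4m3fn)

# after this commit, the payload dtype doubles as the old "flavor"
assert w_fp8._data.dtype == torch.float8_e4m3fn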

float8_playground/float8_tensor.py

Lines changed: 12 additions & 18 deletions
@@ -2,8 +2,6 @@
 import torch
 from torch.utils._pytree import tree_map
 
-from float8_utils import E4M3, E5M2
-
 aten = torch.ops.aten
 
 class Float8ConstrFunc(torch.autograd.Function):
@@ -12,15 +10,15 @@ class Float8ConstrFunc(torch.autograd.Function):
     TODO(future): split into two for cleaner code
     """
     @staticmethod
-    def forward(ctx, tensor, scale: float=None, flavor=E4M3):
+    def forward(ctx, tensor, scale: float=None, dtype=torch.float8_e4m3fn):
         if isinstance(tensor, Float8Tensor):
             ctx.inp_is_float8 = True
-            return torch.ops.aten.float8_to_float32(tensor._data, tensor._flavor) / tensor._scale
+            return tensor._data.to(torch.float32) / tensor._scale
         else:
             ctx.inp_is_float8 = False
             tensor_scaled = tensor * scale
-            bits_fp8 = torch.ops.aten.float32_to_float8(tensor_scaled, flavor)
-            return Float8Tensor(bits_fp8, scale, flavor)
+            bits_fp8 = tensor_scaled.to(dtype)
+            return Float8Tensor(bits_fp8, scale)
 
     @staticmethod
     def backward(ctx, g):
@@ -41,7 +39,6 @@ class Float8Tensor(torch.Tensor):
     * `_scale`: the scale used to scale the original fp32 tensor. We multiply
       by scale to go from fp32 range to fp8 range, and divide by scale to go
      from fp8 range to fp32 range.
-    * `_flavor`: either E4M3 or E5M2
 
     The current purpose of this object is 99% to bundle raw data + fp8 metadata
     together for easy passing through PyTorch systems, and 1% to implement
@@ -57,11 +54,9 @@ class Float8Tensor(torch.Tensor):
     to fp32 for them.
     """
 
-    def __new__(cls, data, scale, flavor):
+    def __new__(cls, data, scale):
         # This is a non-differentiable constructor!
         assert not data.requires_grad
-        # TODO(future): make bits8 easier to work with and switch to using it
-        # assert data.dtype == torch.bits8
        assert scale.dtype == torch.float32
         assert scale.nelement() == 1
 
@@ -77,19 +72,18 @@ def __new__(cls, data, scale, flavor):
         )
         self._data = data
         self._scale = scale
-        self._flavor = flavor
 
         return self
 
     def __repr__(self):
-        return f"Float8Tensor(flavor={self._flavor}, scale={self._scale}, as_float32={self.to_float32()}"
+        return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, as_float32={self.to_float32()}"
 
     def to_float32(self):
         return Float8ConstrFunc.apply(self)
 
     @classmethod
-    def from_float32(cls, tensor, scale, flavor):
-        return Float8ConstrFunc.apply(tensor, scale, flavor)
+    def from_float32(cls, tensor, scale, dtype):
+        return Float8ConstrFunc.apply(tensor, scale, dtype)
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -113,14 +107,14 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
             and isinstance(args[1], Float8Tensor)
         ):
             x1_fp8, x2_fp8 = args[0], args[1]
-            assert x1_fp8._flavor == E5M2 and x2_fp8._flavor == E5M2
-            # naive scale calculation: max of incoming two scales
-            x3_scale = torch.max(x1_fp8._scale, x2_fp8._scale)
+            assert x1_fp8._data.dtype == torch.float8_e5m2 and x2_fp8._data.dtype == torch.float8_e5m2
+            # scale will be filled in by the kernel, not using delayed scaling
+            x3_scale = torch.empty(1)
             res_bits = torch.ops.aten.add_float8_e5m2(
                 x1_fp8._data, x1_fp8._scale,
                 x2_fp8._data, x2_fp8._scale,
                 x3_scale)
-            res = Float8Tensor(res_bits, x3_scale, x1_fp8._flavor)
+            res = Float8Tensor(res_bits, x3_scale)
             return res
 
     # for all other ops, fall back to fp32
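With `_flavor` gone, the element dtype of `_data` is the single source of truth for which float8 format a Float8Tensor holds. A small round-trip sketch under that assumption; the scale value is chosen arbitrarily for illustration:

import torch
from float8_tensor import Float8Tensor

x = torch.randn(4, 4)
scale = torch.tensor([1.0])  # must be a 1-element float32 tensor per __new__

x_fp8 = Float8Tensor.from_float32(x, scale, torch.float8_e4m3fn)
print(x_fp8._data.dtype)          # torch.float8_e4m3fn, no separate flavor field

x_back = x_fp8.to_float32()
print((x - x_back).abs().max())   # small, nonzero quantization error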
