This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 146c6a1

still one failing test, but I think it's a good failure

1 parent b383662 commit 146c6a1

File tree

5 files changed: +95, -41 lines changed


float8_experimental/float8_linear.py

Lines changed: 10 additions & 2 deletions

@@ -302,9 +302,17 @@ def forward(self, x):
 
         x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
-        y = self.float8_mm(x_fp8, w_fp8, self.is_amax_initialized)
-        y = self.cast_y_to_float8_in_bw(y)
+        # y = self.float8_mm(x_fp8, w_fp8, self.is_amax_initialized)
+        if self.emulate:
+            y = self.float8_mm(x_fp8, w_fp8, self.is_amax_initialized)
+        else:
+            orig_shape = x_fp8.shape
+            x_fp8 = x_fp8.reshape(-1, orig_shape[-1])
+            y = torch.mm(x_fp8, w_fp8.t())
+            y = y.reshape(*orig_shape[:-1], y.shape[-1])
 
+        y = self.cast_y_to_float8_in_bw(y)
+        # breakpoint()
         if self.bias is not None:
             y = y + self.bias.to(x_fp8._orig_dtype)
 
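Note: the new non-emulated branch flattens the input to 2D before calling torch.mm and then restores the leading dimensions, since torch.mm only accepts 2D operands. A minimal sketch of that shape handling, using plain fp32 tensors x and w as stand-ins for x_fp8 and w_fp8:

import torch

x = torch.randn(2, 3, 16)  # stand-in for x_fp8: (batch, seq, in_features)
w = torch.randn(32, 16)    # stand-in for w_fp8: (out_features, in_features)

orig_shape = x.shape
x2d = x.reshape(-1, orig_shape[-1])           # (6, 16): collapse leading dims
y = torch.mm(x2d, w.t())                      # (6, 32): torch.mm is 2D-only
y = y.reshape(*orig_shape[:-1], y.shape[-1])  # (2, 3, 32): restore leading dims
print(y.shape)  # torch.Size([2, 3, 32])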

float8_experimental/float8_ops.py

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+from typing import Any, Dict
+
+import torch
+
+from float8_experimental.float8_python_api import mm_float8_unwrapped
+from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_utils import (is_row_major, tensor_to_amax,
+                                              to_fp8_saturated)
+
+aten = torch.ops.aten
+FLOAT8_OPS_TABLE: Dict[Any, Any] = {}
+
+
+def implements(aten_ops):
+    """Register aten ops to the float8 op table"""
+
+    def decorator(func):
+        for op in aten_ops:
+            FLOAT8_OPS_TABLE[op] = func
+        return func
+
+    return decorator
+
+
+@implements(
+    [
+        aten.view.default,
+        aten._unsafe_view.default,
+        aten.t.default,
+        aten.as_strided.default,
+        aten.clone.default,
+        aten.detach.default,
+    ]
+)
+def float8_desugar_op(aten_op, args, kwargs=None):
+    # assert is_fake(args[0]), "Float8Tensor.__torch_dispatch__ for user code is not supported"
+    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
+    return Float8Tensor(new_data, args[0]._scale, args[0]._orig_dtype)
+
+
+@implements([aten.mm.default])
+def float8_mm(aten_op, args, kwargs=None):
+    assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
+    a = args[0]
+    b = args[1]
+    a_data = a._data
+    a_scale = a._scale
+    b_data = b._data
+
+    if not is_row_major(a_data.stride()):
+        a_data = a_data.contiguous()
+    if is_row_major(b_data.stride()):
+        b_data = b_data.t().contiguous().t()
+
+    b_scale = b._scale
+    output_dtype = a._orig_dtype
+    tensor_out, amax = mm_float8_unwrapped(
+        a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None
+    )
+    return tensor_out
+
+
+@implements([aten.is_same_size.default])
+def float8_is_same_size(aten_op, args, kwargs=None):
+    return args[0].shape == args[1].shape
+
+
+@implements([aten._to_copy.default])
+def autocast_to_copy(aten_op, args, kwargs=None):
+    # This is needed for auto cast behavior
+    # TODO Also feels kind of sketch....
+    assert isinstance(args[0], Float8Tensor)
+    assert len(kwargs) == 1 and "dtype" in kwargs, "Only support dtype kwarg for autocast"
+    return Float8Tensor(args[0]._data, args[0]._scale, kwargs["dtype"])
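Note: the @implements decorator fills FLOAT8_OPS_TABLE at import time, so the registrations can be sanity-checked by probing the table directly. A quick sketch, assuming the package at this commit is importable; it only inspects the table and does not construct any Float8Tensor:

import torch
from float8_experimental.float8_ops import FLOAT8_OPS_TABLE, float8_mm

aten = torch.ops.aten
assert FLOAT8_OPS_TABLE[aten.mm.default] is float8_mm  # registered by @implements
assert aten.t.default in FLOAT8_OPS_TABLE              # handled by float8_desugar_op
print(len(FLOAT8_OPS_TABLE), "aten ops registered")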

float8_experimental/float8_python_api.py

Lines changed: 2 additions & 3 deletions

@@ -9,7 +9,6 @@
 
 import float8_experimental.float8_aten_api
 import torch
-from float8_experimental.float8_tensor import Float8Tensor
 
 
 def mm_float8_unwrapped(
@@ -45,8 +44,8 @@ def mm_float8_unwrapped(
 # For a,b going from fp8 -> fp32 we multiple by the inverse of the scale
 # For output going from fp32 -> fp8 we multiply by the scale
 def mm_float8(
-    a: Float8Tensor,  # input 1
-    b: Float8Tensor,  # input 2
+    a: "Float8Tensor",  # input 1
+    b: "Float8Tensor",  # input 2
     output_dtype: torch.dtype,  # output dtype
     output_scale: Optional[torch.Tensor] = None,  # output scale, precomputed
     emulate: bool = False,  # whether to emulate the operation using fp32
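Note: quoting the annotations turns them into forward references that are never evaluated at import time, which is what lets this module drop its import of Float8Tensor and break the circular dependency. A tiny, self-contained illustration with a hypothetical type name ("Widget" is not from the repo):

# "Widget" is never imported or defined here; the string annotation is only
# metadata, so defining and calling the function still works.
def describe(w: "Widget") -> str:
    return f"got {w!r}"

print(describe(42))                   # got 42
print(describe.__annotations__["w"])  # 'Widget'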

float8_experimental/float8_tensor.py

Lines changed: 6 additions & 36 deletions

@@ -1,45 +1,11 @@
-from typing import Any, Dict
+from typing import Dict
 
 import torch
+
 from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated
-from torch._subclasses.fake_tensor import is_fake
 
 aten = torch.ops.aten
 
-FLOAT8_OPS_TABLE: Dict[Any, Any] = {}
-
-
-def implements(aten_ops):
-    """Register aten ops to the float8 op table"""
-
-    def decorator(func):
-        for op in aten_ops:
-            FLOAT8_OPS_TABLE[op] = func
-        return func
-
-    return decorator
-
-
-@implements(
-    [
-        aten.view.default,
-        aten._unsafe_view.default,
-        aten.t.default,
-        aten.as_strided.default,
-        aten.clone.default,
-        aten.detach.default,
-    ]
-)
-def float8_desugar_op(aten_op, args, kwargs=None):
-    assert is_fake(args[0]), "Float8Tensor.__torch_dispatch__ for user code is not supported"
-    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
-    return Float8Tensor(new_data, args[0]._scale, args[0]._orig_dtype)
-
-
-@implements([aten.is_same_size.default])
-def float8_is_same_size(aten_op, args, kwargs=None):
-    return args[0].shape == args[1].shape
-
 
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
@@ -166,6 +132,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
         # PT2.0, so we explicitly disallow it here for callsites from user code.
        # 2. We do need to handle a couple of ops in order for
         # TorchDynamo tracing to succeed.
+
+        # Lazy import to avoid circular dependency
+        from float8_experimental.float8_ops import FLOAT8_OPS_TABLE
+
         if func in FLOAT8_OPS_TABLE:
             return FLOAT8_OPS_TABLE[func](func, args, kwargs)
         raise NotImplementedError(f"attempting to run {func}, this is not supported")
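Note: the FLOAT8_OPS_TABLE import now happens inside __torch_dispatch__ rather than at module top level. A function-local import only runs when the function is first called, which is a common way to break an import cycle. A toy, self-contained illustration of that timing, using the stdlib fractions module in place of float8_ops:

import sys

def to_ratio(x):
    # Not imported until the function is actually called.
    from fractions import Fraction
    return Fraction(x).limit_denominator(100)

print("fractions" in sys.modules)  # typically False before the first call
print(to_ratio(0.25))              # 1/4
print("fractions" in sys.modules)  # True after the call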

float8_experimental/float8_utils.py

Lines changed: 3 additions & 0 deletions

@@ -82,3 +82,6 @@ def compute_error(x, y):
     Ps = torch.norm(x)
     Pn = torch.norm(x - y)
     return 20 * torch.log10(Ps / Pn)
+
+def is_row_major(stride):
+    return stride[0] > stride[1] and stride[1] == 1
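Note: a quick check of the new helper on 2D strides (the only case the diff exercises): a freshly allocated contiguous matrix is row-major, its transpose is not. Assumes the package at this commit is importable:

import torch

from float8_experimental.float8_utils import is_row_major  # added above

a = torch.empty(4, 8)
print(a.stride(), is_row_major(a.stride()))          # (8, 1) True
print(a.t().stride(), is_row_major(a.t().stride()))  # (1, 8) False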
