This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit b518e2a

drisspg authored and facebook-github-bot committed
Use mm in subclass (#128)
Summary:
We use the dispatching mechanism to route `mm` to the float8 implementation.

## TODO
- [x] Hook the amax_buffer onto float8_tensor so that it is filled under dispatch
- [x] Update emulate path

# Note
Vasiliy has already started this here: #28. Some things have changed since then, though: we now output in higher precision by default. However, I still need to replicate the amax_buffer filling here and store it on the float8_tensor passed in.

Corresponding core changes to get as far as possible in compile for aot_eager: pytorch/pytorch#111735

```shell
Checking against fake_mode=<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f4c13cd1bd0>
attr=_data attr_fake_mode=<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f4c13c34d00>
attr=_scale attr_fake_mode=<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f4c13c34d00>
```

### Current Compile Progress
- backend = "eager_only", full_graph = False: ✅
- backend = "eager_only", full_graph = True: ❌
```shell
E torch._dynamo.exc.Unsupported: call_function UserDefinedObjectVariable(to_float8) [TensorVariable(), TensorVariable(), ConstantVariable(dtype), TensorVariable()] {}
```
- backend = "aot_eager", full_graph = False: ❌
```shell
File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 4187, in convert
    assert all(getattr(x, attr).fake_mode is fake_mode for attr in attrs)
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised: AssertionError:
```

Pull Request resolved: #128

Reviewed By: bdhirsh, y-sq

Differential Revision: D50901900

Pulled By: drisspg

fbshipit-source-id: 64626bc652b70bfbabff2ab26e999324d1463e1d
1 parent 791b2bd commit b518e2a
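
For context, here is a minimal sketch of the dispatch idea this commit leans on. It is not the repo's actual `Float8Tensor`; the `ScaledTensor`, `OPS_TABLE`, and `scaled_mm` names below are illustrative only. A wrapper tensor subclass overrides `__torch_dispatch__`, looks the incoming aten op up in a registration table, and reroutes `aten.mm` to a custom handler:

```python
import torch

# Toy op table, mirroring the FLOAT8_OPS_TABLE idea in float8_ops.py below.
OPS_TABLE = {}


def implements(aten_ops):
    def decorator(func):
        for op in aten_ops:
            OPS_TABLE[op] = func
        return func

    return decorator


class ScaledTensor(torch.Tensor):
    """Toy wrapper subclass holding raw data plus a scale; reroutes aten ops."""

    @staticmethod
    def __new__(cls, data, scale):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.size(), dtype=data.dtype, device=data.device
        )

    def __init__(self, data, scale):
        self._data = data
        self._scale = scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in OPS_TABLE:
            return OPS_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"no handler registered for {func}")


@implements([torch.ops.aten.mm.default])
def scaled_mm(aten_op, args, kwargs):
    a, b = args[0], args[1]
    # Toy handler: dequantize and run a plain mm. The commit's float8_mm
    # instead calls the float8 mm kernel on the raw bits and scales.
    return torch.mm(a._data * a._scale, b._data * b._scale)


x = ScaledTensor(torch.randn(4, 8), torch.tensor(0.5))
w = ScaledTensor(torch.randn(8, 2), torch.tensor(2.0))
out = torch.mm(x, w)  # dispatch routes aten.mm.default to scaled_mm
print(out.shape)  # torch.Size([4, 2])
```

The `float8_ops.py` file added in this commit uses the same table/`implements` pattern, but dispatches to the real float8 mm kernels instead of dequantizing.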


10 files changed: +220 -216 lines changed


float8_experimental/float8_linear.py

Lines changed: 37 additions & 112 deletions
```diff
@@ -18,7 +18,7 @@
 )
 
 from float8_experimental.float8_python_api import mm_float8
-from float8_experimental.float8_tensor import Float8Tensor, to_float8
+from float8_experimental.float8_tensor import Float8Tensor
 
 from float8_experimental.float8_utils import (
     amax_history_to_scale,
@@ -44,10 +44,12 @@ def forward(
         fp8_scale_dL_dY,
         scale_fn_name,
         is_amax_initialized,
+        emulate: bool,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
+        ctx.emulate = emulate
         return tensor
 
     @staticmethod
@@ -69,99 +71,11 @@ def backward(ctx, go):
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))
         go_scaled = go * fp8_scale_dL_dY
         bits_fp8 = to_fp8_saturated(go_scaled, torch.float8_e5m2)
-        empty_grads = None, None, None, None, None
-        res = Float8Tensor(bits_fp8, fp8_scale_dL_dY, go.dtype)
+        empty_grads = None, None, None, None, None, None
+        res = Float8Tensor(bits_fp8, fp8_scale_dL_dY, go.dtype, emulate=ctx.emulate)
         return res, *empty_grads
 
 
-class float8_linear(torch.autograd.Function):
-    """
-    Like F.linear, but with X and W in float8
-    """
-
-    @staticmethod
-    def forward(
-        ctx,
-        x_fp8,
-        w_fp8,
-        is_amax_initialized,
-        scale_fn_name,
-        emulate: bool,
-    ):
-        ctx.save_for_backward(x_fp8, w_fp8)
-        ctx.scale_fn_name = scale_fn_name
-        ctx.emulate = emulate
-        orig_shape = x_fp8._data.shape
-        x_fp8_reshaped = Float8Tensor(
-            x_fp8._data.reshape(-1, orig_shape[-1]), x_fp8._scale, x_fp8._orig_dtype
-        )
-        ctx.is_amax_initialized = is_amax_initialized
-
-        w_fp8_t = Float8Tensor(w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype)
-
-        res_bits, _output_amax = mm_float8(
-            x_fp8_reshaped, w_fp8_t, output_dtype=x_fp8._orig_dtype, emulate=emulate
-        )
-        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
-        return res_bits
-
-    @staticmethod
-    def backward(ctx, go_fp8):
-        x_fp8, w_fp8 = ctx.saved_tensors
-        scale_fn_name = ctx.scale_fn_name
-        emulate = ctx.emulate
-        is_amax_initialized = ctx.is_amax_initialized
-
-        go_fp8_orig_shape = go_fp8._data.shape
-        go_fp8_reshaped = Float8Tensor(
-            go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
-            go_fp8._scale,
-            go_fp8._orig_dtype,
-        )
-
-        w_fp8_t_c_t = Float8Tensor(
-            w_fp8._data.t().contiguous().t(), w_fp8._scale, w_fp8._orig_dtype
-        )
-
-        #
-        # calculate dL/dX
-        #
-        dL_dX, _dL_dX_amax = mm_float8(
-            go_fp8_reshaped,
-            w_fp8_t_c_t,
-            output_dtype=x_fp8._orig_dtype,
-            emulate=emulate,
-        )
-        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
-
-        x_fp8_orig_shape = x_fp8._data.shape
-        x_fp8_reshaped_t_c = Float8Tensor(
-            x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
-            x_fp8._scale,
-            x_fp8._orig_dtype,
-        )
-
-        go_fp8_reshaped_t_c_t = Float8Tensor(
-            go_fp8_reshaped._data.t().contiguous().t(),
-            go_fp8_reshaped._scale,
-            go_fp8_reshaped._orig_dtype,
-        )
-
-        #
-        # calculate dL/dW
-        #
-        dL_dW, _dL_dW_amax = mm_float8(
-            x_fp8_reshaped_t_c,
-            go_fp8_reshaped_t_c_t,
-            output_dtype=x_fp8._orig_dtype,
-            emulate=emulate,
-        )
-        dL_dW = dL_dW.t()
-
-        empty_grads = None, None, None, None, None, None, None, None, None
-        return dL_dX, dL_dW, *empty_grads
-
-
 @dataclasses.dataclass
 class DelayedScalingRecipe:
     # Controls the history length of amax buffers
@@ -221,13 +135,17 @@ def __init__(self, *args, **kwargs):
         # will access the scale when it has ensured that it is on GPU.
         self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)
 
-    def cast_x_to_float8(self, x, is_amax_initialized):
+    def cast_x_to_float8(
+        self, x: torch.Tensor, is_amax_initialized: bool
+    ) -> torch.Tensor:
         # Duplicate the autocast logic for F.linear, so that the output
         # of our module has the right original precision
         if torch.is_autocast_enabled():
             # For now, hardcode to GPU's autocast dtype
             # if we need CPU support in the future, we can add it
-            x = x.to(torch.get_autocast_gpu_dtype())
+            autocast_dtype = torch.get_autocast_gpu_dtype()
+            x = x.to(autocast_dtype)
+            self.bias_dtype = autocast_dtype
 
         scale_fn_name = self.recipe.scale_fn_name
         _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -239,10 +157,14 @@ def cast_x_to_float8(self, x, is_amax_initialized):
             torch.float8_e4m3fn,
             is_amax_initialized,
         )
-        x_fp8 = to_float8(x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x)
+        x_fp8 = Float8Tensor.to_float8(
+            x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x, self.emulate
+        )
         return x_fp8
 
-    def cast_w_to_float8(self, w, is_amax_initialized):
+    def cast_w_to_float8(
+        self, w: torch.Tensor, is_amax_initialized: bool
+    ) -> torch.Tensor:
         scale_fn_name = self.recipe.scale_fn_name
         _maybe_initialize_amaxes_scales_for_float8_cast(
             w,
@@ -253,10 +175,14 @@ def cast_w_to_float8(self, w, is_amax_initialized):
             torch.float8_e4m3fn,
             is_amax_initialized,
         )
-        w_fp8 = to_float8(w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w)
+        w_fp8 = Float8Tensor.to_float8(
+            w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w, self.emulate
+        )
         return w_fp8
 
-    def cast_y_to_float8_in_bw(self, y):
+    def cast_y_to_float8_in_bw(
+        self, y: torch.Tensor, emulate: bool = False
+    ) -> torch.Tensor:
         scale_fn_name = self.recipe.scale_fn_name
         y = NoopFwToFloat8E5M2Bw.apply(
             y,
@@ -265,13 +191,7 @@ def cast_y_to_float8_in_bw(self, y):
             self.fp8_scale_dL_dY,
             scale_fn_name,
             self.is_amax_initialized,
-        )
-        return y
-
-    def float8_mm(self, x_fp8, w_fp8, is_amax_initialized):
-        scale_fn_name = self.recipe.scale_fn_name
-        y = float8_linear.apply(
-            x_fp8, w_fp8, is_amax_initialized, scale_fn_name, self.emulate
+            emulate,
         )
         return y
 
@@ -292,6 +212,11 @@ def float8_post_forward(self):
         self.is_amax_initialized = True
         self.amax_and_scale_synced = False
 
+    def add_weight_tag(self):
+        # We add a tag to the weight nn.Parameter in order to signal
+        # To FSDP that this param is a weight
+        self.weight._is_fp8_weight = True
+
 
 class Float8Linear(Float8LinearMixin, torch.nn.Linear):
     """
@@ -311,11 +236,14 @@ def forward(self, x):
             w_fp8 = self._w_fp8
         else:
             w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
-        y = self.float8_mm(x_fp8, w_fp8, self.is_amax_initialized)
-        y = self.cast_y_to_float8_in_bw(y)
+
+        y = torch.matmul(x_fp8, w_fp8.t())
+
+        # Cast gradY to float8_e5m2 during backward
+        y = self.cast_y_to_float8_in_bw(y, self.emulate)
 
         if self.bias is not None:
-            y = y + self.bias.to(x_fp8._orig_dtype)
+            y = y + self.bias.to(self.bias_dtype)
 
         self.float8_post_forward()
         return y
@@ -336,16 +264,13 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
+        if mod.bias is not None:
+            new_mod.bias_dtype = mod.bias.dtype
         # I think its okay to send all params and buffers to device
         new_mod.to(mod.weight.device)
         new_mod.add_weight_tag()
         return new_mod
 
-    def add_weight_tag(self):
-        # We add a tag to the weight nn.Parameter in order to signal
-        # To FSDP that this param is a weight
-        self.weight._is_fp8_weight = True
-
 
 def swap_linear_with_float8_linear(
     model,
```
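
A hedged end-to-end usage sketch for the module path changed above (assumes a CUDA device; `emulate=True` avoids requiring hardware float8 mm support):

```python
import torch
from float8_experimental.float8_linear import Float8Linear

# Sketch only: swap a plain Linear for Float8Linear, then run forward/backward.
m = torch.nn.Linear(32, 16).cuda()
m_fp8 = Float8Linear.from_float(m, emulate=True)

x = torch.randn(4, 32, device="cuda")
y = m_fp8(x)  # x and w are cast to float8; torch.matmul dispatches to the fp8 mm op
y.sum().backward()  # gradY is cast to float8_e5m2 on the way back (NoopFwToFloat8E5M2Bw)
```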

float8_experimental/float8_ops.py

Lines changed: 84 additions & 0 deletions
```diff
@@ -0,0 +1,84 @@
+from typing import Any, Dict
+
+import torch
+from float8_experimental.float8_python_api import mm_float8_unwrapped
+from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_utils import is_row_major
+
+aten = torch.ops.aten
+FLOAT8_OPS_TABLE: Dict[Any, Any] = {}
+
+
+def implements(aten_ops):
+    """Register aten ops to the float8 op table"""
+
+    def decorator(func):
+        for op in aten_ops:
+            FLOAT8_OPS_TABLE[op] = func
+        return func
+
+    return decorator
+
+
+@implements(
+    [
+        aten.view.default,
+        aten._unsafe_view.default,
+        aten.t.default,
+        aten.as_strided.default,
+        aten.clone.default,
+        aten.detach.default,
+    ]
+)
+def float8_desugar_op(aten_op, args, kwargs=None):
+    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
+    return Float8Tensor(new_data, args[0]._scale, args[0]._orig_dtype, args[0]._emulate)
+
+
+@implements([aten.mm.default])
+def float8_mm(aten_op, args, kwargs=None):
+    assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
+    a = args[0]
+    b = args[1]
+    a_data = a._data
+    a_scale = a._scale
+    b_data = b._data
+
+    if not is_row_major(a_data.stride()):
+        a_data = a_data.contiguous()
+    if is_row_major(b_data.stride()):
+        b_data = b_data.t().contiguous().t()
+    b_scale = b._scale
+    output_dtype = a._orig_dtype
+    if a._emulate:
+        assert a._emulate == b._emulate
+        return torch.ops.aten.mm_float8_emulated(
+            a._data, a._scale, b._data, b._scale, output_dtype
+        )[0]
+    tensor_out, amax = mm_float8_unwrapped(
+        a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None
+    )
+    return tensor_out
+
+
+@implements([aten.is_same_size.default])
+def float8_is_same_size(aten_op, args, kwargs=None):
+    return args[0].shape == args[1].shape
+
+
+@implements([aten._to_copy.default])
+def autocast_to_copy(aten_op, args, kwargs=None):
+    """This gets called when running matmul under autocast
+    when the input is a Float8Tensor, presenting as a fp32
+    tensor.
+    """
+    assert isinstance(args[0], Float8Tensor)
+    assert (
+        len(kwargs) == 1 and "dtype" in kwargs
+    ), "Only support dtype kwarg for autocast"
+    assert (
+        kwargs["dtype"] == torch.float16
+    ), "Only support floating point conversion for autocast w/ Float8Tensor"
+    return Float8Tensor(
+        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._emulate
+    )
```
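
The `implements` table above is extensible. As a hedged sketch only (the `aten.transpose.int` target and the `float8_transpose` handler are illustrative assumptions, not part of this commit), one more op could be registered the same way `float8_desugar_op` handles view/t/clone/detach:

```python
import torch
from float8_experimental.float8_ops import implements
from float8_experimental.float8_tensor import Float8Tensor

aten = torch.ops.aten


@implements([aten.transpose.int])
def float8_transpose(aten_op, args, kwargs=None):
    # Apply the op to the raw bits and re-wrap, carrying scale/orig_dtype/emulate along.
    new_data = aten_op(args[0]._data, *args[1:], **(kwargs or {}))
    return Float8Tensor(new_data, args[0]._scale, args[0]._orig_dtype, args[0]._emulate)
```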

float8_experimental/float8_python_api.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@
 to simplify the product code.
 """
 
-import warnings
+
 from typing import Optional, Tuple
 
 import float8_experimental.float8_aten_api
```
