This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 429a313

drisspg authored and facebook-github-bot committed
Fix graph breaks in tensor subclass (#131)
Summary:
For a more detailed understanding of the status, see #106; this change removes all graph breaks on the main work branch.

Pull Request resolved: #131
Reviewed By: albanD
Differential Revision: D50758815
Pulled By: drisspg
fbshipit-source-id: 1502601099988b1eba666306e327eb724eb14989
1 parent 436102b commit 429a313
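
At its core, the commit moves the float8 constructor off the Float8Tensor subclass and exposes it as a module-level to_float8 function in float8_experimental/float8_tensor.py, apparently so tracing no longer steps through a classmethod on a torch.Tensor subclass; the rest of the diff migrates call sites. A minimal before/after sketch of that call-site change (the input tensor and the unit scale are illustrative placeholders, not taken from the commit):

import torch
from float8_experimental.float8_tensor import Float8Tensor, to_float8

x = torch.randn(16, 16, dtype=torch.bfloat16)
scale = torch.tensor(1.0)  # placeholder; real callers derive the scale from an amax

# Before this commit: classmethod on the tensor subclass
# x_fp8 = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn)

# After this commit: free function in float8_tensor.py
x_fp8 = to_float8(x, scale, torch.float8_e4m3fn)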

File tree

4 files changed, +29 -29 lines changed:

float8_experimental/__init__.py
float8_experimental/float8_linear.py
float8_experimental/float8_tensor.py
test/test_base.py

float8_experimental/__init__.py

Lines changed: 4 additions & 3 deletions

@@ -1,4 +1,5 @@
-# Lets wait to define the top level interface
-# from float8_experimental.float8_tensor import Float8Tensor
+# Lets define a few top level things here
+from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_linear import Float8Linear
 
-# __all__ = ["Float8Tensor"]
+__all__ = ["Float8Tensor", "Float8Linear"]
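
Since __init__.py now imports both symbols and lists them in __all__, they can be pulled straight from the package root. A small usage sketch (assuming the package is importable as float8_experimental):

from float8_experimental import Float8Linear, Float8Tensor

# Both names now resolve from the package root instead of
# float8_experimental.float8_tensor / float8_experimental.float8_linear.
print(Float8Tensor, Float8Linear)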

float8_experimental/float8_linear.py

Lines changed: 6 additions & 17 deletions

@@ -18,7 +18,7 @@
 )
 
 from float8_experimental.float8_python_api import mm_float8
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, to_float8
 
 from float8_experimental.float8_utils import (
     amax_history_to_scale,
@@ -174,9 +174,7 @@ class DelayedScalingRecipe:
 
 class Float8LinearMixin(object):
     def __init__(self, *args, **kwargs):
-        delayed_scaling_recipe = kwargs.pop(
-            "delayed_scaling_recipe", DelayedScalingRecipe()
-        )
+        delayed_scaling_recipe = kwargs.pop("delayed_scaling_recipe", DelayedScalingRecipe())
         super().__init__(*args, **kwargs)
 
         # TODO(future): have a unique recipe per buffer instead of one per
@@ -239,10 +237,7 @@ def cast_x_to_float8(self, x, is_amax_initialized):
             torch.float8_e4m3fn,
             is_amax_initialized,
         )
-        x_fp8 = Float8Tensor.to_float8(
-            x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x
-        )
-
+        x_fp8 = to_float8(x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x)
         return x_fp8
 
     def cast_w_to_float8(self, w, is_amax_initialized):
@@ -256,9 +251,7 @@ def cast_w_to_float8(self, w, is_amax_initialized):
             torch.float8_e4m3fn,
             is_amax_initialized,
         )
-        w_fp8 = Float8Tensor.to_float8(
-            w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w
-        )
+        w_fp8 = to_float8(w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w)
         return w_fp8
 
     def cast_y_to_float8_in_bw(self, y):
@@ -275,9 +268,7 @@ def cast_y_to_float8_in_bw(self, y):
 
     def float8_mm(self, x_fp8, w_fp8, is_amax_initialized):
         scale_fn_name = self.recipe.scale_fn_name
-        y = float8_linear.apply(
-            x_fp8, w_fp8, is_amax_initialized, scale_fn_name, self.emulate
-        )
+        y = float8_linear.apply(x_fp8, w_fp8, is_amax_initialized, scale_fn_name, self.emulate)
         return y
 
     def float8_pre_forward(self, x):
@@ -416,9 +407,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module) -> None:
         #
         _update_history_with_new_amax(child.fp8_amax_x, child.fp8_amax_history_x)
        _update_history_with_new_amax(child.fp8_amax_w, child.fp8_amax_history_w)
-        _update_history_with_new_amax(
-            child.fp8_amax_dL_dY, child.fp8_amax_history_dL_dY
-        )
+        _update_history_with_new_amax(child.fp8_amax_dL_dY, child.fp8_amax_history_dL_dY)
 
         #
         # 3. calculate the scales
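
The last hunk touches sync_float8_amax_and_scale_history, which walks the model's float8 submodules, rolls each freshly recorded amax into its history buffer, and then recalculates scales. A hypothetical training-step sketch; placing the call between backward and optimizer.step() is an assumption for illustration, not something this diff specifies:

import torch
from float8_experimental.float8_linear import sync_float8_amax_and_scale_history

def train_step(model: torch.nn.Module, optimizer, loss_fn, batch, target):
    loss = loss_fn(model(batch), target)  # Float8Linear modules record new amax values here
    loss.backward()
    # Roll the recorded amax values into each module's history and recompute
    # scales before the next forward pass (assumed placement).
    sync_float8_amax_and_scale_history(model)
    optimizer.step()
    optimizer.zero_grad()
    return loss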

float8_experimental/float8_tensor.py

Lines changed: 15 additions & 5 deletions

@@ -75,6 +75,20 @@ def backward(ctx, g):
         return g, None, None, None
 
 
+def to_float8(tensor: torch.Tensor, scale: torch.Tensor, float8_dtype: torch.dtype, amax_buffer: torch.Tensor = None) -> "Float8Tensor":
+    """ Converts a higher precision tensor to float8 in a differentiable way.
+
+    Args:
+        tensor: the tensor to convert
+        scale: the scale to use to convert the tensor
+        float8_dtype: the float8 dtype to use
+        amax_buffer: a buffer to store the amax value in prior to conversion
+
+    Returns:
+        Float8Tensor: a float8 tensor
+    """
+    return ToFloat8ConstrFunc.apply(tensor, scale, float8_dtype, amax_buffer)
+
 class FromFloat8ConstrFunc(torch.autograd.Function):
     """
     A differentiable conversion from fp8
@@ -86,7 +100,7 @@ def forward(ctx, tensor):
 
     @staticmethod
     def backward(ctx, g):
-        return Float8Tensor.to_float8(g), None, None
+        return to_float8(g), None, None
 
 
 class Float8Tensor(torch.Tensor):
@@ -154,10 +168,6 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata):
     def to_original_precision(self):
         return FromFloat8ConstrFunc.apply(self)
 
-    @classmethod
-    def to_float8(cls, tensor, scale, float8_dtype, amax_buffer=None):
-        return ToFloat8ConstrFunc.apply(tensor, scale, float8_dtype, amax_buffer)
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         # 1. tracing through __torch_function__ logic is not supported yet in
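
With to_float8 now a plain function and to_original_precision still a method, the round trip can be written without the removed classmethod. A minimal sketch under torch.compile; the unit scale, shapes, and fullgraph=True are illustrative assumptions, and whether it compiles without graph breaks depends on the branch state tracked in #106:

import torch
from float8_experimental.float8_tensor import to_float8

def roundtrip(x: torch.Tensor) -> torch.Tensor:
    scale = torch.tensor(1.0, device=x.device)  # placeholder scale for illustration
    x_fp8 = to_float8(x, scale, torch.float8_e4m3fn)
    return x_fp8.to_original_precision()

# fullgraph=True makes dynamo raise instead of silently falling back on a graph break.
compiled = torch.compile(roundtrip, fullgraph=True)
out = compiled(torch.randn(8, 8, dtype=torch.bfloat16))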

test/test_base.py

Lines changed: 4 additions & 4 deletions

@@ -15,7 +15,7 @@
 )
 from float8_experimental.float8_linear_nots import Float8LinearNoTensorSubclass
 from float8_experimental.float8_python_api import mm_float8
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, to_float8
 
 from float8_experimental.float8_utils import (
     amax_to_scale,
@@ -39,7 +39,7 @@ def test_preserves_dtype(self):
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.randn(4, 4, dtype=hp_dtype)
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
-            x2_lp = Float8Tensor.to_float8(x1_hp, x1_s, lp_dtype)
+            x2_lp = to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
             self.assertTrue(x3_hp.dtype == hp_dtype)
 
@@ -248,8 +248,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()
 
-        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
-        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
+        a_fp8 = to_float8(a, a_scale, input_dtype)
+        b_fp8 = to_float8(b, b_scale, input_dtype)
 
         out_scaled_mm, output_amax_scaled = mm_float8(
             a_fp8, b_fp8, output_dtype=output_dtype, emulate=False
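
For completeness, the migrated test_preserves_dtype flow can be reproduced standalone; tensor_to_scale is assumed to come from float8_experimental.float8_utils, per the truncated import block above:

import torch
from float8_experimental.float8_tensor import to_float8
from float8_experimental.float8_utils import tensor_to_scale  # assumed import path

x1_hp = torch.randn(4, 4, dtype=torch.bfloat16)
x1_s = tensor_to_scale(x1_hp, torch.float8_e4m3fn)   # derive a scale for the target float8 dtype
x2_lp = to_float8(x1_hp, x1_s, torch.float8_e4m3fn)  # high precision -> float8
x3_hp = x2_lp.to_original_precision()                # float8 -> original precision
assert x3_hp.dtype == x1_hp.dtype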
