This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Fix graph breaks in tensor subclass #131

Closed
wants to merge 2 commits
7 changes: 4 additions & 3 deletions float8_experimental/__init__.py
@@ -1,4 +1,5 @@
-# Lets wait to define the top level interface
-# from float8_experimental.float8_tensor import Float8Tensor
+# Lets define a few top level things here
+from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_linear import Float8Linear
 
-# __all__ = ["Float8Tensor"]
+__all__ = ["Float8Tensor", "Float8Linear"]
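A brief usage note (illustrative, not part of the diff): with the updated __init__.py, both symbols are exposed at the package root, so an import such as the following is expected to work.

    from float8_experimental import Float8Linear, Float8Tensor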
23 changes: 6 additions & 17 deletions float8_experimental/float8_linear.py
@@ -18,7 +18,7 @@
 )
 
 from float8_experimental.float8_python_api import mm_float8
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, to_float8
 
 from float8_experimental.float8_utils import (
     amax_history_to_scale,
@@ -174,9 +174,7 @@ class DelayedScalingRecipe:
 
 class Float8LinearMixin(object):
     def __init__(self, *args, **kwargs):
-        delayed_scaling_recipe = kwargs.pop(
-            "delayed_scaling_recipe", DelayedScalingRecipe()
-        )
+        delayed_scaling_recipe = kwargs.pop("delayed_scaling_recipe", DelayedScalingRecipe())
         super().__init__(*args, **kwargs)
 
         # TODO(future): have a unique recipe per buffer instead of one per
@@ -239,10 +237,7 @@ def cast_x_to_float8(self, x, is_amax_initialized):
             torch.float8_e4m3fn,
             is_amax_initialized,
         )
-        x_fp8 = Float8Tensor.to_float8(
-            x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x
-        )
-
+        x_fp8 = to_float8(x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x)
         return x_fp8
 
     def cast_w_to_float8(self, w, is_amax_initialized):
@@ -256,9 +251,7 @@ def cast_w_to_float8(self, w, is_amax_initialized):
             torch.float8_e4m3fn,
             is_amax_initialized,
         )
-        w_fp8 = Float8Tensor.to_float8(
-            w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w
-        )
+        w_fp8 = to_float8(w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w)
         return w_fp8
 
     def cast_y_to_float8_in_bw(self, y):
@@ -275,9 +268,7 @@ def cast_y_to_float8_in_bw(self, y):
 
     def float8_mm(self, x_fp8, w_fp8, is_amax_initialized):
         scale_fn_name = self.recipe.scale_fn_name
-        y = float8_linear.apply(
-            x_fp8, w_fp8, is_amax_initialized, scale_fn_name, self.emulate
-        )
+        y = float8_linear.apply(x_fp8, w_fp8, is_amax_initialized, scale_fn_name, self.emulate)
         return y
 
     def float8_pre_forward(self, x):
@@ -416,9 +407,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module) -> None:
         #
         _update_history_with_new_amax(child.fp8_amax_x, child.fp8_amax_history_x)
         _update_history_with_new_amax(child.fp8_amax_w, child.fp8_amax_history_w)
-        _update_history_with_new_amax(
-            child.fp8_amax_dL_dY, child.fp8_amax_history_dL_dY
-        )
+        _update_history_with_new_amax(child.fp8_amax_dL_dY, child.fp8_amax_history_dL_dY)
 
         #
         # 3. calculate the scales
20 changes: 15 additions & 5 deletions float8_experimental/float8_tensor.py
@@ -75,6 +75,20 @@ def backward(ctx, g):
         return g, None, None, None
 
 
+def to_float8(tensor: torch.Tensor, scale: torch.Tensor, float8_dtype: torch.dtype, amax_buffer:torch.Tensor =None) -> "Float8Tensor":
+    """ Converts a higher precision tensor to float8 in a differentiable way.
+
+    Args:
+        tensor: the tensor to convert
+        scale: the scale to use to convert the tensor
+        float8_dtype: the float8 dtype to use
+        amax_buffer: a buffer to store the amax value in prior to conversion
+
+    Returns:
+        Float8Tensor: a float8 tensor
+    """
+    return ToFloat8ConstrFunc.apply(tensor, scale, float8_dtype, amax_buffer)
+
 class FromFloat8ConstrFunc(torch.autograd.Function):
     """
     A differentiable conversion from fp8
@@ -86,7 +100,7 @@ def forward(ctx, tensor):
 
     @staticmethod
    def backward(ctx, g):
-        return Float8Tensor.to_float8(g), None, None
+        return to_float8(g), None, None
 
 
 class Float8Tensor(torch.Tensor):
@@ -154,10 +168,6 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata):
     def to_original_precision(self):
         return FromFloat8ConstrFunc.apply(self)
 
-    @classmethod
-    def to_float8(cls, tensor, scale, float8_dtype, amax_buffer=None):
-        return ToFloat8ConstrFunc.apply(tensor, scale, float8_dtype, amax_buffer)
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         # 1. tracing through __torch_function__ logic is not supported yet in
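For context, a minimal usage sketch of the new module-level to_float8 helper (not part of the diff; it mirrors test_preserves_dtype below and assumes tensor_to_scale is imported from float8_experimental.float8_utils, as in the test file):

    import torch

    from float8_experimental.float8_tensor import to_float8
    from float8_experimental.float8_utils import tensor_to_scale

    x_hp = torch.randn(4, 4, dtype=torch.float32)
    # per-tensor scale for the target float8 format
    x_scale = tensor_to_scale(x_hp, torch.float8_e4m3fn)
    # differentiable cast to float8 (replaces the removed Float8Tensor.to_float8 classmethod)
    x_fp8 = to_float8(x_hp, x_scale, torch.float8_e4m3fn)
    # cast back to the original precision
    x_hp_roundtrip = x_fp8.to_original_precision()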
8 changes: 4 additions & 4 deletions test/test_base.py
@@ -15,7 +15,7 @@
 )
 from float8_experimental.float8_linear_nots import Float8LinearNoTensorSubclass
 from float8_experimental.float8_python_api import mm_float8
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, to_float8
 
 from float8_experimental.float8_utils import (
     amax_to_scale,
@@ -39,7 +39,7 @@ def test_preserves_dtype(self):
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.randn(4, 4, dtype=hp_dtype)
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
-            x2_lp = Float8Tensor.to_float8(x1_hp, x1_s, lp_dtype)
+            x2_lp = to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
             self.assertTrue(x3_hp.dtype == hp_dtype)
 
@@ -248,8 +248,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()
 
-        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
-        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
+        a_fp8 = to_float8(a, a_scale, input_dtype)
+        b_fp8 = to_float8(b, b_scale, input_dtype)
 
         out_scaled_mm, output_amax_scaled = mm_float8(
             a_fp8, b_fp8, output_dtype=output_dtype, emulate=False
test/test_everything.sh: mode changed 100644 → 100755 (no content changes)
test/test_fsdp.sh: mode changed 100644 → 100755 (no content changes)
test/test_tp.sh: mode changed 100644 → 100755 (no content changes)