Use mm in subclass #128
@@ -18,7 +18,7 @@
)

from float8_experimental.float8_python_api import mm_float8
from float8_experimental.float8_tensor import Float8Tensor, to_float8
from float8_experimental.float8_tensor import Float8Tensor

from float8_experimental.float8_utils import (
    amax_history_to_scale,
@@ -44,10 +44,12 @@ def forward(
        fp8_scale_dL_dY,
        scale_fn_name,
        is_amax_initialized,
        emulate: bool,
    ):
        ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
        ctx.scale_fn_name = scale_fn_name
        ctx.is_amax_initialized = is_amax_initialized
        ctx.emulate = emulate
        return tensor

    @staticmethod
@@ -69,99 +71,11 @@ def backward(ctx, go):
        fp8_amax_dL_dY.fill_(tensor_to_amax(go))
        go_scaled = go * fp8_scale_dL_dY
        bits_fp8 = to_fp8_saturated(go_scaled, torch.float8_e5m2)
        empty_grads = None, None, None, None, None
        res = Float8Tensor(bits_fp8, fp8_scale_dL_dY, go.dtype)
        empty_grads = None, None, None, None, None, None
        res = Float8Tensor(bits_fp8, fp8_scale_dL_dY, go.dtype, emulate=ctx.emulate)
        return res, *empty_grads


class float8_linear(torch.autograd.Function):
    """
    Like F.linear, but with X and W in float8
    """

    @staticmethod
    def forward(
        ctx,
        x_fp8,
        w_fp8,
        is_amax_initialized,
        scale_fn_name,
        emulate: bool,
    ):
        ctx.save_for_backward(x_fp8, w_fp8)
        ctx.scale_fn_name = scale_fn_name
        ctx.emulate = emulate
        orig_shape = x_fp8._data.shape
        x_fp8_reshaped = Float8Tensor(
            x_fp8._data.reshape(-1, orig_shape[-1]), x_fp8._scale, x_fp8._orig_dtype
        )
        ctx.is_amax_initialized = is_amax_initialized

        w_fp8_t = Float8Tensor(w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype)

        res_bits, _output_amax = mm_float8(
            x_fp8_reshaped, w_fp8_t, output_dtype=x_fp8._orig_dtype, emulate=emulate
        )
        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
        return res_bits

    @staticmethod
    def backward(ctx, go_fp8):
        x_fp8, w_fp8 = ctx.saved_tensors
        scale_fn_name = ctx.scale_fn_name
        emulate = ctx.emulate
        is_amax_initialized = ctx.is_amax_initialized

        go_fp8_orig_shape = go_fp8._data.shape
        go_fp8_reshaped = Float8Tensor(
            go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
            go_fp8._scale,
            go_fp8._orig_dtype,
        )

        w_fp8_t_c_t = Float8Tensor(
            w_fp8._data.t().contiguous().t(), w_fp8._scale, w_fp8._orig_dtype
        )

        #
        # calculate dL/dX
        #
        dL_dX, _dL_dX_amax = mm_float8(
            go_fp8_reshaped,
            w_fp8_t_c_t,
            output_dtype=x_fp8._orig_dtype,
            emulate=emulate,
        )
        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])

        x_fp8_orig_shape = x_fp8._data.shape
        x_fp8_reshaped_t_c = Float8Tensor(
            x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
            x_fp8._scale,
            x_fp8._orig_dtype,
        )

        go_fp8_reshaped_t_c_t = Float8Tensor(
            go_fp8_reshaped._data.t().contiguous().t(),
            go_fp8_reshaped._scale,
            go_fp8_reshaped._orig_dtype,
        )

        #
        # calculate dL/dW
        #
        dL_dW, _dL_dW_amax = mm_float8(
            x_fp8_reshaped_t_c,
            go_fp8_reshaped_t_c_t,
            output_dtype=x_fp8._orig_dtype,
            emulate=emulate,
        )
        dL_dW = dL_dW.t()

        empty_grads = None, None, None, None, None, None, None, None, None
        return dL_dX, dL_dW, *empty_grads


@dataclasses.dataclass
class DelayedScalingRecipe:
    # Controls the history length of amax buffers

@@ -221,13 +135,17 @@ def __init__(self, *args, **kwargs):
        # will access the scale when it has ensured that it is on GPU.
        self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)

    def cast_x_to_float8(self, x, is_amax_initialized):
    def cast_x_to_float8(
        self, x: torch.Tensor, is_amax_initialized: bool
    ) -> torch.Tensor:
        # Duplicate the autocast logic for F.linear, so that the output
        # of our module has the right original precision
        if torch.is_autocast_enabled():
            # For now, hardcode to GPU's autocast dtype
            # if we need CPU support in the future, we can add it
            x = x.to(torch.get_autocast_gpu_dtype())
            autocast_dtype = torch.get_autocast_gpu_dtype()
            x = x.to(autocast_dtype)
            self.bias_dtype = autocast_dtype

        scale_fn_name = self.recipe.scale_fn_name
        _maybe_initialize_amaxes_scales_for_float8_cast(

@@ -239,10 +157,14 @@ def cast_x_to_float8(self, x, is_amax_initialized):
            torch.float8_e4m3fn,
            is_amax_initialized,
        )
        x_fp8 = to_float8(x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x)
        x_fp8 = Float8Tensor.to_float8(
            x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x, self.emulate
        )
        return x_fp8

    def cast_w_to_float8(self, w, is_amax_initialized):
    def cast_w_to_float8(
        self, w: torch.Tensor, is_amax_initialized: bool
    ) -> torch.Tensor:
        scale_fn_name = self.recipe.scale_fn_name
        _maybe_initialize_amaxes_scales_for_float8_cast(
            w,

@@ -253,10 +175,14 @@ def cast_w_to_float8(self, w, is_amax_initialized):
            torch.float8_e4m3fn,
            is_amax_initialized,
        )
        w_fp8 = to_float8(w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w)
        w_fp8 = Float8Tensor.to_float8(
            w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w, self.emulate
        )
        return w_fp8

    def cast_y_to_float8_in_bw(self, y):
    def cast_y_to_float8_in_bw(
        self, y: torch.Tensor, emulate: bool = False
    ) -> torch.Tensor:
        scale_fn_name = self.recipe.scale_fn_name
        y = NoopFwToFloat8E5M2Bw.apply(
            y,

@@ -265,13 +191,7 @@ def cast_y_to_float8_in_bw(self, y):
            self.fp8_scale_dL_dY,
            scale_fn_name,
            self.is_amax_initialized,
        )
        return y

    def float8_mm(self, x_fp8, w_fp8, is_amax_initialized):
        scale_fn_name = self.recipe.scale_fn_name
        y = float8_linear.apply(
            x_fp8, w_fp8, is_amax_initialized, scale_fn_name, self.emulate
            emulate,
        )
        return y

@@ -292,6 +212,11 @@ def float8_post_forward(self):
        self.is_amax_initialized = True
        self.amax_and_scale_synced = False

    def add_weight_tag(self):
        # We add a tag to the weight nn.Parameter in order to signal
        # To FSDP that this param is a weight
        self.weight._is_fp8_weight = True
Review comment: I'm reviewing the subclass changes, but I'm probably not the right person to review this one.

Reply: This was added in a previous PR and was just moved to the Mixin so that it can also be added to the TP stuff.
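As context for the comment above: the tag added by `add_weight_tag` is just a Python attribute on the weight `nn.Parameter`, so a wrapper such as an FSDP integration can detect float8 weights with a `getattr` check. A minimal sketch of how that could be consumed; the `is_fp8_weight` helper is hypothetical and not part of this PR:

```python
import torch


def is_fp8_weight(param: torch.nn.Parameter) -> bool:
    # The tag is a plain attribute, so default to False when it is absent.
    return getattr(param, "_is_fp8_weight", False)


linear = torch.nn.Linear(16, 16)
linear.weight._is_fp8_weight = True  # what add_weight_tag() does on the module's weight
assert is_fp8_weight(linear.weight)
assert not is_fp8_weight(linear.bias)
```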


class Float8Linear(Float8LinearMixin, torch.nn.Linear):
    """

@@ -311,11 +236,14 @@ def forward(self, x):
            w_fp8 = self._w_fp8
        else:
            w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
        y = self.float8_mm(x_fp8, w_fp8, self.is_amax_initialized)
        y = self.cast_y_to_float8_in_bw(y)

        y = torch.matmul(x_fp8, w_fp8.t())

        # Cast gradY to float8_e5m2 during backward
        y = self.cast_y_to_float8_in_bw(y, self.emulate)
Review comment: Mentioned offline, but food for thought that I'll mention here: it would be interesting to think about what it would take to have more of this handled directly by the tensor subclass. My understanding was: (1) this is a pain mostly because the extra buffers for float8 live directly on the module; (2) doing this would provide benefit if we want to start increasing the number of ops that directly handle float8, but if all we care about is linear, then this generality is probably not very useful.
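To make the mechanism under discussion concrete, here is a minimal, self-contained sketch (not the PR's actual `Float8Tensor`) of a wrapper tensor subclass whose `__torch_dispatch__` consults an ops table, mirroring the `FLOAT8_OPS_TABLE`/`implements` pattern in the new ops file added by this PR, so that a plain `torch.mm`/`torch.matmul` call routes to a custom handler. All names here (`WrapperTensor`, `OPS_TABLE`, `wrapped_mm`) are invented for illustration:

```python
import torch

OPS_TABLE = {}


def implements(aten_ops):
    """Register handlers for aten ops, same shape as the PR's decorator."""

    def decorator(func):
        for op in aten_ops:
            OPS_TABLE[op] = func
        return func

    return decorator


class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data):
        # Wrapper subclass with no storage of its own; the payload lives in _data.
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Every aten op hit with a WrapperTensor argument lands here and is
        # looked up in the table, analogous to FLOAT8_OPS_TABLE.
        if func in OPS_TABLE:
            return OPS_TABLE[func](func, args, kwargs or {})
        raise NotImplementedError(f"no handler for {func}")


@implements([torch.ops.aten.mm.default])
def wrapped_mm(aten_op, args, kwargs):
    # Unwrap, run the real matmul, return a plain tensor.
    return torch.mm(args[0]._data, args[1]._data)


a = WrapperTensor(torch.randn(4, 8))
b = WrapperTensor(torch.randn(8, 2))
print(torch.mm(a, b).shape)  # torch.Size([4, 2])
```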

        if self.bias is not None:
            y = y + self.bias.to(x_fp8._orig_dtype)
            y = y + self.bias.to(self.bias_dtype)

        self.float8_post_forward()
        return y

@@ -336,16 +264,13 @@ def from_float(cls, mod, emulate: bool = False):
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        new_mod.emulate = emulate
        if mod.bias is not None:
            new_mod.bias_dtype = mod.bias.dtype
        # I think its okay to send all params and buffers to device
        new_mod.to(mod.weight.device)
        new_mod.add_weight_tag()
        return new_mod

    def add_weight_tag(self):
        # We add a tag to the weight nn.Parameter in order to signal
        # To FSDP that this param is a weight
        self.weight._is_fp8_weight = True


def swap_linear_with_float8_linear(
    model,
@@ -0,0 +1,84 @@
from typing import Any, Dict

import torch
from float8_experimental.float8_python_api import mm_float8_unwrapped
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import is_row_major

aten = torch.ops.aten
FLOAT8_OPS_TABLE: Dict[Any, Any] = {}


def implements(aten_ops):
    """Register aten ops to the float8 op table"""

    def decorator(func):
        for op in aten_ops:
            FLOAT8_OPS_TABLE[op] = func
        return func

    return decorator


@implements(
    [
        aten.view.default,
        aten._unsafe_view.default,
        aten.t.default,
        aten.as_strided.default,
        aten.clone.default,
        aten.detach.default,
    ]
)
def float8_desugar_op(aten_op, args, kwargs=None):
    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
    return Float8Tensor(new_data, args[0]._scale, args[0]._orig_dtype, args[0]._emulate)


@implements([aten.mm.default])
def float8_mm(aten_op, args, kwargs=None):
    assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
    a = args[0]
    b = args[1]
    a_data = a._data
    a_scale = a._scale
    b_data = b._data

    if not is_row_major(a_data.stride()):
        a_data = a_data.contiguous()
    if is_row_major(b_data.stride()):
        b_data = b_data.t().contiguous().t()
    b_scale = b._scale
    output_dtype = a._orig_dtype
    if a._emulate:
        assert a._emulate == b._emulate
        return torch.ops.aten.mm_float8_emulated(
Review comment: Oh, also just thinking: should emulate just be a global config somewhere, instead of a flag that you have to plumb around?

Reply: Talked about this with Brian offline. This is probably right, but I am going to do this in a followup. I also want to see if we can get plain torch.nn.functional.linear in the LinearFloat8 and will do some matmul changes.
            a._data, a._scale, b._data, b._scale, output_dtype
        )[0]
    tensor_out, amax = mm_float8_unwrapped(
        a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None
    )
    return tensor_out
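The stride juggling in `float8_mm` above exists because the underlying float8 gemm is layout-sensitive, so each operand is nudged into the expected memory layout before the call. `is_row_major` is imported from `float8_utils` and its body is not shown in this diff; a plausible stride-based sketch, stated as an assumption rather than the actual implementation, would be:

```python
from typing import Tuple

import torch


def is_row_major(stride: Tuple[int, ...]) -> bool:
    # A 2D tensor is row-major if its last dimension is contiguous and the
    # leading stride is the larger one.
    assert len(stride) == 2, "is_row_major only supports 2D tensors"
    return stride[0] > stride[1] and stride[1] == 1


a = torch.randn(4, 8)
assert is_row_major(a.stride())          # stride (8, 1) -> row-major
assert not is_row_major(a.t().stride())  # stride (1, 8) -> column-major
```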


@implements([aten.is_same_size.default])
def float8_is_same_size(aten_op, args, kwargs=None):
    return args[0].shape == args[1].shape


@implements([aten._to_copy.default])
def autocast_to_copy(aten_op, args, kwargs=None):
    """This gets called when running matmul under autocast
    when the input is a Float8Tensor, presenting as a fp32
    tensor.
    """
    assert isinstance(args[0], Float8Tensor)
    assert (
        len(kwargs) == 1 and "dtype" in kwargs
    ), "Only support dtype kwarg for autocast"
    assert (
        kwargs["dtype"] == torch.float16
    ), "Only support floating point conversion for autocast w/ Float8Tensor"
    return Float8Tensor(
        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._emulate
    )
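Following up on the review thread in `float8_mm` above about making emulate a global config instead of a per-tensor flag that has to be plumbed around: the PR keeps the per-tensor flag and defers this to a follow-up, but a rough, purely hypothetical sketch of that direction (none of these names exist in the codebase) could look like:

```python
from dataclasses import dataclass


@dataclass
class Float8Config:
    # When True, matmuls go through the emulated float8 kernel instead of the
    # real float8 gemm.
    emulate: bool = False


config = Float8Config()


def use_emulated_mm() -> bool:
    # Call sites (e.g. float8_mm) would consult the global config instead of
    # reading a._emulate off each tensor.
    return config.emulate
```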