This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit d03d16b

do less in the backwards
1 parent 20da1c0 commit d03d16b

File tree

2 files changed (+26, -12 lines)


float8_experimental/float8_ops.py

Lines changed: 6 additions & 12 deletions
@@ -8,7 +8,7 @@
 import torch
 
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import Float8Tensor, re_construct_float8_weight
 from float8_experimental.float8_utils import is_row_major
 from torch.utils._pytree import tree_map
 
@@ -191,9 +191,7 @@ def forward(
         if recompute_float8_weight:
             # This should be set to True when using traditional fsdp to avoid
             # saving the unsharded weight for backwards
-            ctx.save_for_backward(
-                x_fp8, original_weight, weight_scale, weight_amax_buffer
-            )
+            ctx.save_for_backward(x_fp8, original_weight, weight_scale)
         else:
             # Does this interact properly with activation checkpointing?
             ctx.save_for_backward(x_fp8, w_fp8)
@@ -211,19 +209,15 @@ def forward(
     @staticmethod
     def backward(ctx, go_fp8: torch.Tensor):
         if ctx.recompute_float8_weight:
-            x_fp8, original_weight, weight_scale, weight_amax_buffer = ctx.saved_tensors
-            w_fp8 = Float8Tensor.to_float8(
-                original_weight,
-                weight_scale,
-                torch.float8_e4m3fn,
-                weight_amax_buffer,
-                emulate=ctx.emulate,
+            x_fp8, original_weight, weight_scale = ctx.saved_tensors
+            w_fp8 = re_construct_float8_weight(
+                original_weight, weight_scale, torch.float8_e4m3fn, emulate=ctx.emulate
             )
         else:
             x_fp8, w_fp8 = ctx.saved_tensors
 
         # calculate dL/dX
-        go_fp8_reshaped = go_fp8.view(-1, go_fp8.size(-1))
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1))
         w_fp8_t_c_t = w_fp8.t().contiguous().t()
         dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t)
         dL_dX = dL_dX.view(*go_fp8.shape[:-1], dL_dX.size(-1))
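For readers skimming the hunks above, here is a minimal, self-contained sketch of the pattern the backward moves to: the forward saves the original-precision weight plus the scale it computed, and the backward re-casts with that saved scale instead of re-running the full cast that also writes the amax buffer. Every name in the sketch (RecomputeWeightLinear, cast_with_scale, the scale heuristic) is illustrative rather than float8_experimental's actual code; it only assumes a PyTorch build that ships torch.float8_e4m3fn.

import torch

E4M3_MAX = 448.0  # largest magnitude representable in torch.float8_e4m3fn


def cast_with_scale(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Scale, saturate, and cast to fp8; no amax bookkeeping happens here.
    return (t * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)


class RecomputeWeightLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        amax = weight.abs().max()              # amax is observed in the forward only
        scale = E4M3_MAX / amax.clamp(min=1e-12)
        w_fp8 = cast_with_scale(weight, scale)
        # Save the original weight and the scale, not the fp8 copy and not an
        # amax buffer -- the backward only needs to re-cast.
        ctx.save_for_backward(x, weight, scale)
        return x @ w_fp8.to(x.dtype).t() / scale

    @staticmethod
    def backward(ctx, grad_out):
        x, weight, scale = ctx.saved_tensors
        w_fp8 = cast_with_scale(weight, scale)  # re-cast with the saved scale
        grad_x = grad_out @ (w_fp8.to(grad_out.dtype) / scale)
        grad_w = grad_out.t() @ x
        return grad_x, grad_w


# Usage: y = RecomputeWeightLinear.apply(x, w), with x of shape (N, in) and w of shape (out, in).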

float8_experimental/float8_tensor.py

Lines changed: 20 additions & 0 deletions
@@ -42,6 +42,26 @@ def backward(ctx, g):
         return g, None, None, None, None
 
 
+@torch._dynamo.allow_in_graph
+def re_construct_float8_weight(
+    tensor: torch.Tensor, scale: torch.Tensor, float8_dtype, emulate: bool = False
+):
+    """In the backward of float8_linear we don't need to fill the amax buffer
+    for the weight tensor, since that was done during the forward; we just need to
+    recast the original-precision tensor using the scale from the forward.
+
+    Args:
+        tensor: the tensor to convert
+        scale: the scale to use to convert the tensor, from the forward
+        float8_dtype: the float8 dtype to use
+        emulate: if true, use fp32 emulation for the matmuls; helpful
+            if you don't have access to H100 hardware.
+    """
+    tensor_scaled = tensor * scale
+    bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+    return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)
+
+
 @torch._dynamo.allow_in_graph
 class FromFloat8ConstrFunc(torch.autograd.Function):
     """

0 commit comments
