This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 682c2e8

add to delayed linear
1 parent 27a6a7f commit 682c2e8

File tree: 3 files changed, +71 −6 lines changed


float8_experimental/float8_dynamic_linear.py

Lines changed: 13 additions & 3 deletions
@@ -53,8 +53,13 @@ def forward(self, x):
 
         # y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
         weight_scale = tensor_to_scale(self.weight, torch.float8_e4m3fn)
-        y = float8_linear.apply(
-            x_fp8, self.weight, weight_scale, None, self.emulate, False
+        y = float8_linear(
+            x_fp8,
+            self.weight,
+            weight_scale,
+            None,
+            self.emulate,
+            self.recompute_weight_cast,
         )
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_to_float8e5m2_bw(y)
@@ -72,17 +77,22 @@ def cast_to_float8e5m2_bw(self, gradY):
         return NoopFwToFloat8E5M2Bw.apply(gradY, self.emulate)
 
     @classmethod
-    def from_float(cls, mod, emulate: bool = False):
+    def from_float(
+        cls, mod, emulate: bool = False, recompute_weight_cast: bool = False
+    ):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             emulate (bool): whether to emulate fp8 matmul logic in float32
+            recompute_weight_cast (bool): whether to recompute the weight cast on every
+                backwards pass
         """
         with torch.device("meta"):
             new_mod = cls(mod.in_features, mod.out_features, bias=False)
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
+        new_mod.recompute_weight_cast = recompute_weight_cast
         return new_mod
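
For context, a minimal usage sketch of the new from_float signature. The Float8DynamicLinear class name is assumed from the file name and is not shown in this diff; emulate=True is used so the sketch does not depend on fp8-capable hardware, and the overall setup is illustrative rather than part of the commit.

# Hypothetical usage sketch (not part of the commit): convert a plain nn.Linear
# to the dynamic fp8 linear defined in float8_dynamic_linear.py and opt into
# re-casting the weight during the backward pass.
import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear  # assumed class name

mod = nn.Linear(64, 32, bias=False)

# recompute_weight_cast=True saves only the original weight for backward and
# re-casts it to float8 in the backward pass (less memory, extra cast work).
fp8_mod = Float8DynamicLinear.from_float(
    mod, emulate=True, recompute_weight_cast=True
)

x = torch.randn(16, 64)
y = fp8_mod(x)
y.sum().backward()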

float8_experimental/float8_linear.py

Lines changed: 33 additions & 2 deletions
@@ -19,6 +19,7 @@
 import float8_experimental.config as config
 
 import torch
+from float8_experimental.float8_ops import float8_linear
 
 from float8_experimental.float8_tensor import Float8Tensor
 
@@ -172,6 +173,13 @@ def __init__(self, *args, **kwargs):
         # and torch.compile, this option can disable them
         self.enable_pre_and_post_forward = config.enable_pre_and_post_forward
 
+        # This flag controls what gets saved for backward. The default, False, saves the
+        # casted weight for backward; this typically increases memory usage because both
+        # the weight parameter and the casted weight are kept on device. If True, only the
+        # weight parameter is saved and it is re-cast to fp8 during the backward pass. For
+        # traditional FSDP this should be True so the un-sharded weight is not saved.
+        self.recompute_weight_cast = False
+
     def register_always_float32_buffer(
         self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
     ) -> None:
@@ -214,6 +222,20 @@ def cast_x_to_float8(
         )
         return x_fp8
 
+    def _maybe_init_amaxes_scales_weight(
+        self, w: torch.Tensor, is_amax_initialized: bool
+    ):
+        scale_fn_name = self.recipe.scale_fn_name
+        _maybe_initialize_amaxes_scales_for_float8_cast(
+            w,
+            self.fp8_amax_w,
+            self.fp8_amax_history_w,
+            self.fp8_scale_w,
+            scale_fn_name,
+            torch.float8_e4m3fn,
+            is_amax_initialized,
+        )
+
     def cast_w_to_float8(
         self, w: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
@@ -284,9 +306,18 @@ def forward(self, x):
         self.float8_pre_forward(x)
 
         x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
-        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        # w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        self._maybe_init_amaxes_scales_weight(self.weight, self.is_amax_initialized)
 
-        y = torch.matmul(x_fp8, w_fp8.t())
+        y = float8_linear(
+            x_fp8,
+            self.weight,
+            self.fp8_scale_w,
+            self.fp8_amax_w,
+            self.emulate,
+            self.recompute_weight_cast,
+        )
+        # y = torch.matmul(x_fp8, w_fp8.t())
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_y_to_float8_in_bw(y, self.emulate)
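
The save-vs-recompute trade-off described in the new comment can be illustrated with a small, library-independent autograd.Function. The names and the stand-in cast below are illustrative and are not the library's implementation; the cast is treated as straight-through for the weight gradient, as fp8 training recipes typically do.

# Minimal sketch of the trade-off: either save the casted weight for backward
# (faster backward, more memory) or save the original weight and re-cast it
# during backward (less memory, extra cast work).
import torch


def cast_like_fp8(w: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fp8 cast; a real implementation would scale and convert
    # to torch.float8_e4m3fn.
    return w.half().float()


class RecomputeCastLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, recompute_cast: bool):
        w_cast = cast_like_fp8(weight)
        ctx.recompute_cast = recompute_cast
        if recompute_cast:
            # Save only the original weight; re-cast it in backward.
            ctx.save_for_backward(x, weight)
        else:
            # Save the casted weight, trading memory for backward speed.
            ctx.save_for_backward(x, w_cast)
        return x @ w_cast.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        w_cast = cast_like_fp8(w) if ctx.recompute_cast else w
        grad_x = grad_out @ w_cast
        # Straight-through estimate: gradient w.r.t. the casted weight is used
        # as the gradient w.r.t. the original weight.
        grad_w = grad_out.t() @ x
        return grad_x, grad_w, None


x = torch.randn(8, 4, requires_grad=True)
w = torch.randn(3, 4, requires_grad=True)
y = RecomputeCastLinear.apply(x, w, True)
y.sum().backward()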

float8_experimental/float8_ops.py

Lines changed: 25 additions & 1 deletion
@@ -164,7 +164,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
     )
 
 
-class float8_linear(torch.autograd.Function):
+class _float8_linear(torch.autograd.Function):
     """Custom autograd function for computing torch.nn.Linear on Float8Tensor.
 
     This is needed for a couple reasons, we want to have fine grained control over the
@@ -238,3 +238,27 @@ def backward(ctx, go_fp8: torch.Tensor):
 
         empty_grads = None, None, None, None, None, None, None, None, None
         return dL_dX, dL_dW, *empty_grads
+
+
+# Need to allow_in_graph because:
+# (1) the forward returns a plain tensor
+# (2) the backward accepts a Float8Tensor subclass
+# dynamo has no good way to be told what the type of
+# the grad_out is today, so it (incorrectly) assumes it is also a plain tensor.
+@torch._dynamo.allow_in_graph
+def float8_linear(
+    x_fp8: torch.Tensor,
+    original_weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_amax_buffer: Optional[torch.Tensor],
+    emulate: bool,
+    recompute_float8_weight: bool,
+):
+    return _float8_linear.apply(
+        x_fp8,
+        original_weight,
+        weight_scale,
+        weight_amax_buffer,
+        emulate,
+        recompute_float8_weight,
+    )
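
The wrapper pattern added here is general: a plain function decorated with torch._dynamo.allow_in_graph keeps the autograd.Function call as a single opaque node during tracing, so dynamo does not second-guess the type of grad_out in the custom backward. A self-contained sketch of the same pattern, with illustrative names (not from this commit):

import torch


class _ScaledMatmul(torch.autograd.Function):
    # Toy stand-in for _float8_linear: scale the input, then matmul.
    @staticmethod
    def forward(ctx, x, w, scale):
        ctx.save_for_backward(x, w, scale)
        return (x * scale) @ w.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, w, scale = ctx.saved_tensors
        grad_x = (grad_out @ w) * scale
        grad_w = grad_out.t() @ (x * scale)
        return grad_x, grad_w, None


# The plain-function wrapper is what callers use; allow_in_graph tells dynamo
# to treat the call as one traceable unit instead of tracing into .apply().
@torch._dynamo.allow_in_graph
def scaled_matmul(x, w, scale):
    return _ScaledMatmul.apply(x, w, scale)


@torch.compile(backend="eager")  # "eager" backend keeps the sketch toolchain-free
def f(x, w, scale):
    return scaled_matmul(x, w, scale).sum()


x = torch.randn(8, 4, requires_grad=True)
w = torch.randn(3, 4, requires_grad=True)
scale = torch.tensor(2.0)
f(x, w, scale).backward()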
