This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 436102b

Andrew Gu authored and facebook-github-bot committed
Added changes for FSDP fp8 all-gather (#130)
Summary: This adds two changes: 1. `Float8LinearMixin` saves a constructor for FSDP to use to construct the `Float8Tensor` for `w_fp8`. This is needed for FSDP to manage the unsharded gradient and since FSDP prefers to own the underlying data/storage for all-gather. 2. `Float8Linear.forward()` checks if `self._w_fp8` has been set (by FSDP) and skips the cast itself if so. I have tested this with P865757339 (not cleaned up), but I do not think we need to land yet. (This does mean there is a chance for changes to this repo to break FSDP fp8 all-gather, but I think it is fine for now.) Pull Request resolved: #130 Reviewed By: awgu Differential Revision: D50754666 Pulled By: drisspg fbshipit-source-id: 9f7a9bfc9f2b3cb7455cd8b8642f0c4e4a55ee64
1 parent fcb9011 commit 436102b

File tree

1 file changed: +11 -1 lines

float8_experimental/float8_linear.py

Lines changed: 11 additions & 1 deletion
@@ -214,6 +214,13 @@ def __init__(self, *args, **kwargs):
         # Note: this is not used in non-TP code.
         self.use_sequence_parallel = False
 
+        # Save the Float8Tensor constructor for FSDP.
+        # N.B. Do not partially apply the scale into the constructor because
+        # buffer Python IDs are not preserved by `nn.Module.to()` and the
+        # module could be moved to GPU after this constructor. Instead, FSDP
+        # will access the scale when it has ensured that it is on GPU.
+        self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)
+
     def cast_x_to_float8(self, x, is_amax_initialized):
         # Duplicate the autocast logic for F.linear, so that the output
         # of our module has the right original precision
@@ -305,7 +312,10 @@ def forward(self, x):
         self.float8_pre_forward(x)
 
         x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
-        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
+            w_fp8 = self._w_fp8
+        else:
+            w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
         y = self.float8_mm(x_fp8, w_fp8, self.is_amax_initialized)
         y = self.cast_y_to_float8_in_bw(y)
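For context, below is a minimal sketch of how an FSDP hook might consume these two changes: it builds the unsharded `Float8Tensor` with the constructor saved as `_float8_tensor_ctor` and then sets `_w_fp8` on the module so that `forward()` skips its own cast. The function name `fsdp_pre_forward_cast`, the `all_gathered_fp8_data` argument, the `fp8_scale_w` buffer name, and the `Float8Tensor` argument order are illustrative assumptions and not part of this commit; the only interfaces the commit provides are `module._float8_tensor_ctor` and the `_w_fp8` attribute checked in `forward()`.

import torch

def fsdp_pre_forward_cast(module, all_gathered_fp8_data: torch.Tensor) -> None:
    # Hypothetical FSDP-side hook: FSDP owns the all-gathered fp8 storage, so
    # it constructs the unsharded Float8Tensor itself via the constructor saved
    # by Float8LinearMixin, rather than letting the module re-cast the weight.
    w_fp8 = module._float8_tensor_ctor(
        all_gathered_fp8_data,  # fp8 data produced by the all-gather
        module.fp8_scale_w,     # weight scale, read once it is on GPU (buffer name assumed)
        module.weight.dtype,    # original precision metadata (argument order assumed)
    )
    # Float8Linear.forward() sees that _w_fp8 is set and skips cast_w_to_float8.
    module._w_fp8 = w_fp8

After freeing the unsharded weight, FSDP would be expected to reset `module._w_fp8 = None` so that non-FSDP invocations fall back to the module's own `cast_w_to_float8` path.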