
Commit de93990

vkuzo authored and facebook-github-bot committed
support delayed scaling of weight in float8 all-gather (#312)
Summary:
Pull Request resolved: #312

Adds support for delayed scaling in FSDP2 float8 all-gather. In detail:

1. add `WeightWithDelayedFloat8CastTensor`; note that we don't reuse code with the dynamic version because I'd rather not deal with plumbing optional tensors through dynamo. We can try that in a separate PR later.
2. wire `Float8Linear` to use (1)
3. add weight amax syncing back, since we need it for float8 all-gather
4. add test coverage for eager-mode numerics

Next up (in separate PRs): training-run validation of numerics, and a look at performance.

Reviewed By: awgu

Differential Revision: D59685258

fbshipit-source-id: 9ff18d7649cc6e0e3c9e2a64a30a5ff8bc4108be
1 parent 3fe7c4a commit de93990
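To make the intended workflow concrete, here is a minimal sketch (not part of this commit) of how the pieces could be wired together for delayed weight scaling with FSDP2 float8 all-gather. The `float8_experimental.config` import path, the exact `Float8Linear.from_float` keyword arguments, and the `fully_shard` usage are assumptions inferred from the diffs below; distributed initialization and device setup are elided.

# Hedged sketch: delayed weight scaling + FSDP2 float8 all-gather.
# Assumed: `config.enable_fsdp_fp8_all_gather` is a module-level flag and
# `from_float` accepts `scaling_type_w`; see the diffs below for the real code.
import torch
import torch.nn as nn
import float8_experimental.config as config  # assumed import path
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    linear_requires_sync,
    sync_float8_amax_and_scale_history,
)
from torch.distributed._composable.fsdp import fully_shard  # FSDP2

config.enable_fsdp_fp8_all_gather = True

# convert each nn.Linear to Float8Linear with delayed scaling for the weight
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()
for i, child in enumerate(model):
    model[i] = Float8Linear.from_float(
        child, scaling_type_w=TensorScalingType.DELAYED
    )

fully_shard(model)  # weights all-gather as float8 data + scale, not bf16
optim = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(10):
    loss = model(torch.randn(16, 1024, device="cuda")).sum()
    loss.backward()
    # delayed scaling requires syncing amaxes/scales every iteration
    if linear_requires_sync(scaling_type_w=TensorScalingType.DELAYED):
        sync_float8_amax_and_scale_history(model)
    optim.step()
    optim.zero_grad()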

File tree

5 files changed, +320 -52 lines changed

float8_experimental/float8_linear.py

Lines changed: 54 additions & 31 deletions
@@ -34,7 +34,10 @@
     tensor_to_amax,
 )
 
-from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
+from float8_experimental.fsdp_utils import (
+    WeightWithDelayedFloat8CastTensor,
+    WeightWithDynamicFloat8CastTensor,
+)
 
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -316,28 +319,30 @@ def cast_w_to_float8(
         self, w: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
         if self.scaling_type_w is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
-            _maybe_initialize_amaxes_scales_for_float8_cast(
-                w,
-                self.fp8_amax_w,
-                self.fp8_amax_history_w,
-                self.fp8_scale_w,
-                scale_fn_name,
-                e4m3_dtype,
-                is_amax_initialized,
-                reduce_amax=False,
-            )
-
-            w_fp8 = Float8Tensor.to_float8(
-                w,
-                self.fp8_scale_w,
-                e4m3_dtype,
-                self.fp8_amax_w,
-                self.forward_config,
-            )
+            if isinstance(self.weight, Float8Tensor):  # cast by FSDP
+                w_fp8 = self.weight
+            else:
+                scale_fn_name = self.recipe.scale_fn_name
+                _maybe_initialize_amaxes_scales_for_float8_cast(
+                    w,
+                    self.fp8_amax_w,
+                    self.fp8_amax_history_w,
+                    self.fp8_scale_w,
+                    scale_fn_name,
+                    e4m3_dtype,
+                    is_amax_initialized,
+                    reduce_amax=False,
+                )
+
+                w_fp8 = Float8Tensor.to_float8(
+                    w,
+                    self.fp8_scale_w,
+                    e4m3_dtype,
+                    self.fp8_amax_w,
+                    self.forward_config,
+                )
         else:
             assert self.scaling_type_w is TensorScalingType.DYNAMIC
-            # TODO(future): also support FSDP integration in delayed scaling path
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
@@ -436,18 +441,36 @@ def from_float(
             scaling_type_dL_dY=scaling_type_dL_dY,
             emulate=emulate,
         )
-        if (
-            scaling_type_w == TensorScalingType.DYNAMIC
-            and config.enable_fsdp_fp8_all_gather
-        ):
-            new_mod.weight = torch.nn.Parameter(
-                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
-            )
-        else:
-            assert not config.enable_fsdp_fp8_all_gather, "unsupported"
-            new_mod.weight = mod.weight
+        new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         # need to create buffers again when moving from meta device to
         # real device
         new_mod.create_buffers()
+
+        # If FSDP float8 all-gather is on, wrap the weight in a float8-aware
+        # tensor subclass. This must happen last because:
+        # 1. weight needs to be on the correct device to create the buffers
+        # 2. buffers need to be already created for the delayed scaling version
+        #    of the weight wrapper to be initialized
+        if config.enable_fsdp_fp8_all_gather:
+            if scaling_type_w is TensorScalingType.DYNAMIC:
+                new_mod.weight = torch.nn.Parameter(
+                    WeightWithDynamicFloat8CastTensor(
+                        new_mod.weight,
+                        new_mod.forward_config,
+                    )
+                )
+            else:
+                assert scaling_type_w is TensorScalingType.DELAYED
+                new_mod.weight = torch.nn.Parameter(
+                    WeightWithDelayedFloat8CastTensor(
+                        new_mod.weight,
+                        new_mod.fp8_amax_w,
+                        new_mod.fp8_amax_history_w,
+                        new_mod.fp8_scale_w,
+                        new_mod.forward_config,
+                        new_mod.is_amax_initialized,
+                    )
+                )
+
         return new_mod
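As a quick usage illustration of the reordered `from_float` above: with the all-gather flag on and delayed weight scaling selected, the returned module's weight comes back wrapped in the delayed-cast subclass, holding references to the buffers that `create_buffers()` just made. The `float8_experimental.config` import path and the minimal `from_float` call are assumptions; the remaining `from_float` arguments are assumed to have defaults.

import torch.nn as nn
import float8_experimental.config as config  # assumed import path for the flag
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.fsdp_utils import WeightWithDelayedFloat8CastTensor

config.enable_fsdp_fp8_all_gather = True  # must be set before from_float runs

m = Float8Linear.from_float(
    nn.Linear(64, 64), scaling_type_w=TensorScalingType.DELAYED
)
# the wrapper is applied after create_buffers(), so it already references
# m.fp8_amax_w / m.fp8_amax_history_w / m.fp8_scale_w
print(isinstance(m.weight, WeightWithDelayedFloat8CastTensor))  # expected: True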

float8_experimental/float8_linear_utils.py

Lines changed: 7 additions & 6 deletions
@@ -289,11 +289,10 @@ def inner_func():
         ), "Mismatched lengths of amax tensors."
 
         if dist.is_initialized():
-            # Combine all the amax tensors into one tensor and reduce it
-            # Note: do not reduce the weight values, because FSDP already ensures
-            # the weight values on all ranks are the same after all-gather.
             all_amax_tensors = torch.cat(
-                fp8_amax_x_tensor_list + fp8_amax_dL_dY_tensor_list
+                fp8_amax_x_tensor_list
+                + fp8_amax_w_tensor_list
+                + fp8_amax_dL_dY_tensor_list
             )
             all_reduced_amax_tensor = all_reduce(
                 all_amax_tensors, "MAX", list(range(dist.get_world_size()))
@@ -302,12 +301,14 @@ def inner_func():
             all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
 
             (
-                reduced_fp8_amax_tensor,
+                reduced_fp8_amax_x_tensor,
+                reduced_fp8_amax_w_tensor,
                 reduced_fp8_amax_dL_dY_tensor,
             ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
 
             for idx, child in enumerate(fp8_layers):
-                child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
+                child.fp8_amax_x.copy_(reduced_fp8_amax_x_tensor[idx])
+                child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
                 child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
 
         # We create two stacked tensor groups, one for the amax history and one for the current scales
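The change above folds the weight amaxes into the single collective that already covered x and dL_dY: concatenate everything, do one MAX all-reduce, then split the result back per group. A simplified standalone sketch of that pattern, using plain `torch.distributed` rather than the functional `all_reduce(...).wait()` used in the file, and assuming each of the three groups has one entry per layer:

import torch
import torch.distributed as dist

def reduce_amaxes_with_one_collective(
    amax_x: torch.Tensor, amax_w: torch.Tensor, amax_dl_dy: torch.Tensor
):
    # One concatenated MAX all-reduce instead of three separate collectives.
    stacked = torch.cat([amax_x, amax_w, amax_dl_dy])
    dist.all_reduce(stacked, op=dist.ReduceOp.MAX)
    # Split back into per-group tensors, one entry per Float8Linear layer.
    return torch.split(stacked, amax_x.numel())

The weight amaxes need to be reduced again presumably because, with float8 all-gather plus delayed scaling, the weight's amax is observed on the local shard inside `fsdp_pre_all_gather`, so individual ranks no longer all observe the identical unsharded weight.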

float8_experimental/fsdp_utils.py

Lines changed: 180 additions & 1 deletion
@@ -18,7 +18,7 @@
     ScaledMMConfig,
 )
 
-from float8_experimental.float8_utils import EPS
+from float8_experimental.float8_utils import e4m3_dtype, EPS
 from torch._prims_common import suggest_memory_format
 
 
@@ -189,3 +189,182 @@ def fsdp_post_all_gather(
             out._scale = scale
             return
         return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
+
+
+class WeightWithDelayedFloat8CastTensor(torch.Tensor):
+    @staticmethod
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        amax_buffer: torch.Tensor,
+        amax_history_buffer: torch.Tensor,
+        scale_buffer: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        is_amax_initialized: bool,
+    ):
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            tensor.size(),
+            strides=tensor.stride(),
+            storage_offset=tensor.storage_offset(),
+            memory_format=suggest_memory_format(tensor),
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            device=tensor.device,
+            pin_memory=tensor.is_pinned(),
+            requires_grad=tensor.requires_grad,
+        )
+
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        amax_buffer: torch.Tensor,
+        amax_history_buffer: torch.Tensor,
+        scale_buffer: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        is_amax_initialized: bool,
+    ):
+        self._tensor = tensor
+        self._amax_buffer = amax_buffer
+        self._amax_history_buffer = amax_history_buffer
+        self._scale_buffer = scale_buffer
+        self._mm_config = mm_config
+
+        # Note: is_amax_initialized is not a buffer to avoid data dependent
+        # control flow visible to dynamo
+        # TODO(future PR): add serialization for this flag
+        self.is_amax_initialized = is_amax_initialized
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        if func == torch.ops.aten.detach.default:
+            return WeightWithDelayedFloat8CastTensor(
+                args[0]._tensor,
+                args[0]._amax_buffer,
+                args[0]._amax_history_buffer,
+                args[0]._scale_buffer,
+                args[0]._mm_config,
+                args[0].is_amax_initialized,
+            )
+        mm_config: Optional[ScaledMMConfig] = None
+        amax_buffer: Optional[torch.Tensor] = None
+        amax_history_buffer: Optional[torch.Tensor] = None
+        scale_buffer: Optional[torch.Tensor] = None
+        is_amax_initialized: Optional[bool] = None
+
+        def unwrap(t):
+            nonlocal mm_config
+            if mm_config is None:
+                mm_config = t._mm_config
+            else:
+                mm_config = merge_mm_configs(mm_config, t._mm_config)
+            nonlocal amax_buffer
+            if amax_buffer is None:
+                amax_buffer = t._amax_buffer
+            nonlocal amax_history_buffer
+            if amax_history_buffer is None:
+                amax_history_buffer = t._amax_history_buffer
+            nonlocal scale_buffer
+            if scale_buffer is None:
+                scale_buffer = t._scale_buffer
+            nonlocal is_amax_initialized
+            if is_amax_initialized is None:
+                is_amax_initialized = t.is_amax_initialized
+            return t._tensor
+
+        args, kwargs = pytree.tree_map_only(
+            WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {})
+        )
+        out = func(*args, **kwargs)
+        if func not in _ops_to_preserve_subclass:
+            return out
+        return pytree.tree_map_only(
+            torch.Tensor,
+            lambda x: WeightWithDelayedFloat8CastTensor(
+                x,
+                amax_buffer,
+                amax_history_buffer,
+                scale_buffer,
+                mm_config,
+                is_amax_initialized,
+            ),
+            out,
+        )
+
+    def __tensor_flatten__(self):
+        return (
+            [
+                "_tensor",
+                "_amax_buffer",
+                "_amax_history_buffer",
+                "_scale_buffer",
+            ],
+            {
+                "mm_config": self._mm_config,
+                "is_amax_initialized": is_amax_initialized,
+            },
+        )
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
+        return WeightWithDelayedFloat8CastTensor(
+            inner_tensors["_tensor"],
+            inner_tensors["_amax_buffer"],
+            inner_tensors["_amax_history_buffer"],
+            inner_tensors["_scale_buffer"],
+            metadata["mm_config"],
+            metadata["is_amax_initialized"],
+        )
+
+    def __repr__(self):
+        return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})"
+
+    def fsdp_pre_all_gather(self, mesh):
+        # initialize if needed
+        # TODO(before land): ensure settings are consistent between Float8Linear and here
+        if not self.is_amax_initialized:
+            from float8_experimental.float8_linear import (
+                _maybe_initialize_amaxes_scales_for_float8_cast,
+            )
+
+            _maybe_initialize_amaxes_scales_for_float8_cast(
+                self._tensor,
+                self._amax_buffer,
+                self._amax_history_buffer,
+                self._scale_buffer,
+                "max",  # TODO(before land): read this from parent
+                e4m3_dtype,
+                self.is_amax_initialized,
+                reduce_amax=True,
+            )
+            self.is_amax_initialized = True
+
+        # this will:
+        # 1. cast the tensor to float8 using `_scale_buffer`
+        # 2. populate `_amax_buffer` inplace
+        # TODO(future PR): clean up all the casting functions and clearly
+        # separate dynamic vs delayed, tech debt has accumulated
+        float8_tensor = Float8Tensor.to_float8(
+            self._tensor,
+            self._scale_buffer,
+            e4m3_dtype,
+            self._amax_buffer,
+            self._mm_config,
+        )
+        return (float8_tensor._data,), (float8_tensor._scale,)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ):
+        (data,) = all_gather_outputs
+        (scale,) = metadata
+        if out is not None:
+            assert isinstance(out, Float8Tensor), f"{type(out)}"
+            out._scale = scale
+            return
+        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
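To see what the two hooks do end to end, here is a rough single-process sketch that calls them directly instead of letting FSDP2 drive them (on one rank the all-gather is effectively an identity). The buffer shapes, the amax-history length, passing `mesh=None`, constructing `ScaledMMConfig()` with defaults, and the `float8_tensor` import path are all assumptions for illustration only.

import torch
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig  # assumed path
from float8_experimental.fsdp_utils import WeightWithDelayedFloat8CastTensor

weight = torch.randn(128, 256)
amax = torch.zeros(1)            # assumed buffer shape
amax_history = torch.zeros(16)   # assumed history length
scale = torch.ones(1)

w = WeightWithDelayedFloat8CastTensor(
    weight,
    amax,
    amax_history,
    scale,
    ScaledMMConfig(),  # assuming ScaledMMConfig can be built with defaults
    True,              # is_amax_initialized=True skips the init + amax-reduce branch
)

# FSDP2 calls these hooks around its all-gather; chaining them directly shows
# the float8 data + scale round trip that replaces a bf16 weight all-gather.
all_gather_inputs, metadata = w.fsdp_pre_all_gather(mesh=None)  # mesh is unused above
gathered, _ = w.fsdp_post_all_gather(all_gather_inputs, metadata, torch.float32)
print(type(gathered))  # expected: Float8Tensor carrying the shared scale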

test/test_fsdp2/test_fsdp2_common.py

Lines changed: 16 additions & 2 deletions
@@ -6,6 +6,11 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear_utils import (
+    linear_requires_sync,
+    sync_float8_amax_and_scale_history,
+)
 from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 
 
@@ -17,6 +22,7 @@ def check_parity_no_mp(
     fsdp_optim: torch.optim.Optimizer,
     local_inp: torch.Tensor,
     precompute: bool = False,
+    scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
 ):
     for iter_idx in range(10):
         losses: List[torch.Tensor] = []
@@ -28,10 +34,18 @@
             for param in model.parameters():
                 dist.all_reduce(param.grad)
                 param.grad.div_(dist.get_world_size())
-            # TODO(future): add amax syncing once delayed scaling is supported
+
+            if linear_requires_sync(scaling_type_w=scaling_type_w):
+                sync_float8_amax_and_scale_history(model)
+
             optim.step()
-            if model is fsdp_model and precompute:
+            if (
+                model is fsdp_model
+                and precompute
+                and scaling_type_w is TensorScalingType.DYNAMIC
+            ):
                 precompute_float8_dynamic_scale_for_fsdp(model)
+
         test_cls.assertEqual(losses[0], losses[1])
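A hypothetical call site for the updated helper, assuming the leading positional arguments match the existing dynamic-scaling tests (they are not shown in this diff):

check_parity_no_mp(
    self,        # the unittest TestCase
    ref_model,
    ref_optim,
    fsdp_model,
    fsdp_optim,
    local_inp,
    scaling_type_w=TensorScalingType.DELAYED,  # new in this commit
)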
