Commit df6bd8f

fix: fix CPU offloading in FSDP grad clipping and weight updates (#680)
Updates the gradient clipping implementation to correctly handle parameters offloaded to CPU, bypassing CUDA-specific optimizations when necessary to prevent runtime errors. Refactors the FSDP engine's weight broadcasting logic to properly materialize and batch DTensors in offloaded scenarios. Additionally, introduces a new test suite to verify gradient normalization and clipping behavior across different device configurations.
1 parent 26738c2
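
The clipping fallback the description alludes to can be illustrated with stock PyTorch: torch.nn.utils.clip_grad_norm_ exposes a foreach flag that toggles the fused multi-tensor kernels, which is exactly the kind of CUDA-specific optimization a CPU-offloaded setup may need to bypass. Below is a minimal sketch of that idea, not AReaL's actual helper (the clip_grads name and offloaded flag are hypothetical):

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

def clip_grads(model: nn.Module, max_norm: float, offloaded: bool) -> torch.Tensor:
    # Hypothetical illustration: when gradients live on CPU because of
    # FSDP offloading, skip the fused multi-tensor (foreach) kernels and
    # take the per-tensor fallback path instead.
    params = [p for p in model.parameters() if p.grad is not None]
    return clip_grad_norm_(params, max_norm, foreach=not offloaded)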

File tree

3 files changed: +597 −38 lines changed

areal/engine/fsdp_engine.py

Lines changed: 26 additions & 7 deletions
@@ -448,6 +448,26 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta):
 
         fut.result()
 
+    def _get_full_tensor(self, param: nn.Parameter) -> torch.Tensor:
+        """Get full tensor from a parameter, handling DTensor and CPU offloaded tensors."""
+        tensor = param.data
+        if isinstance(tensor, DTensor):
+            # For non-offloaded DTensor, directly call full_tensor()
+            if tensor.device.type != "cpu":
+                return tensor.full_tensor()
+
+            # Handle CPU offloaded DTensor: reconstruct DTensor from local tensor
+            temp_dtensor = DTensor.from_local(
+                tensor.to_local(),
+                device_mesh=tensor.device_mesh,
+                placements=tensor.placements,
+            )
+            return temp_dtensor.full_tensor()
+        else:
+            if tensor.device.type == "cpu":
+                tensor = tensor.to(current_platform.device_type)
+            return tensor
+
     @trace_perf("fsdp_engine.update_weights_from_distributed", category="comm")
     def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
         """Broadcast parameters (chunked) from rank 0 (FSDP2 compatible)."""
@@ -458,18 +478,16 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
         dist.barrier(group=self.cpu_group)
 
         weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024
+        main_rank = dist.get_rank() == 0
 
         buffer_size = 0
-        named_tensors = []
+        named_tensors: list[tuple[str, torch.Tensor]] = []
 
         for name, param in self.get_model_name_parameters():
-            if isinstance(param.data, DTensor):
-                tensor = param.data.full_tensor()
-            else:
-                tensor = param.data
+            tensor = self._get_full_tensor(param)
 
             # Ranks other than 0 only help to get the full tensor
-            if dist.get_rank() != 0:
+            if not main_rank:
                 continue
 
             tensor_size = tensor.numel() * tensor.element_size()
@@ -481,7 +499,7 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
             named_tensors.append((name, tensor))
             buffer_size += tensor_size
 
-        # Only rank-0 CAN contain named tensors here
+        # Process remaining parameters
         if named_tensors:
             self._update_bucket_weights_from_distributed(meta, named_tensors)
 
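Only the tail of this loop is visible in the hunks above: tensors accumulate into named_tensors, each bucket is presumably flushed through _update_bucket_weights_from_distributed once buffer_size reaches weight_chunked_mem_size (in lines elided between the two hunks), and the trailing if named_tensors: sends the final partial bucket. A generic sketch of that size-based bucketing pattern, with a hypothetical flush callback standing in for the real broadcast:

from typing import Callable

import torch

def bucket_tensors(
    named: list[tuple[str, torch.Tensor]],
    max_bytes: int,
    flush: Callable[[list[tuple[str, torch.Tensor]]], None],
) -> None:
    # Group tensors into buckets of at most max_bytes and flush each one.
    bucket: list[tuple[str, torch.Tensor]] = []
    size = 0
    for name, tensor in named:
        nbytes = tensor.numel() * tensor.element_size()
        if bucket and size + nbytes > max_bytes:
            flush(bucket)  # e.g. one chunked broadcast per bucket
            bucket, size = [], 0
        bucket.append((name, tensor))
        size += nbytes
    if bucket:  # final partial bucket
        flush(bucket)
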
@@ -807,6 +825,7 @@ def train_batch(
             list(self.model.parameters()),
             self.world_mesh,
             max_norm=self.optimizer_config.gradient_clipping,
+            offload_params=self.config.fsdp.offload_params,
         )
 
         if not math.isfinite(grad_norm):
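
The if not math.isfinite(grad_norm): guard closing the hunk is the usual defense against overflowed gradients. Its body is not shown here, but a common pattern, sketched below with hypothetical optimizer plumbing rather than AReaL's actual handling, is to drop the step entirely:

import math

import torch

def step_if_finite(optimizer: torch.optim.Optimizer, grad_norm: float) -> bool:
    # Skip the update when clipping reported a NaN/inf gradient norm.
    if not math.isfinite(grad_norm):
        optimizer.zero_grad()  # discard the bad gradients
        return False
    optimizer.step()
    optimizer.zero_grad()
    return True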
