Commit 3088f14

fix: fix CPU offloading in FSDP grad clipping and weight updates
Updates the gradient clipping implementation to correctly handle parameters offloaded to CPU, bypassing CUDA-specific optimizations when necessary to prevent runtime errors. Refactors the FSDP engine's weight broadcasting logic to properly materialize and batch DTensors in offloaded scenarios. Additionally, introduces a new test suite to verify gradient normalization and clipping behavior across different device configurations.
1 parent 29f9184 commit 3088f14
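For orientation, the gradient-clipping utility this commit changes lives outside the fsdp_engine.py diff shown below, so the following is only a rough sketch (assuming torch.nn.utils.clip_grad_norm_ and an illustrative helper name) of how an offload_params flag, wired from train_batch in the last hunk, might bypass the CUDA foreach fast path; the real helper also receives the device mesh, which this sketch omits.

# Hypothetical sketch: "clip_grad_norm_offload_aware" is an illustrative name, not
# the helper AReaL actually calls; it only shows the foreach-fallback idea.
import torch
from torch import nn


def clip_grad_norm_offload_aware(
    parameters: list[nn.Parameter],
    max_norm: float,
    offload_params: bool = False,
) -> torch.Tensor:
    """Clip gradients, skipping CUDA-only fused kernels when params live on CPU."""
    # With FSDP CPU offloading the gradients sit on CPU (or mixed devices), where the
    # foreach fast path can raise, so force the plain per-tensor implementation.
    foreach = False if offload_params else None
    return nn.utils.clip_grad_norm_(parameters, max_norm, foreach=foreach)

At the call site in the last hunk, offload_params comes from self.config.fsdp.offload_params.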

File tree

3 files changed: +598 -40 lines

areal/engine/fsdp_engine.py

Lines changed: 27 additions & 9 deletions
@@ -449,6 +449,21 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta):

            fut.result()

+    def _dtensor_to_full_tensor(self, dtensor: DTensor) -> torch.Tensor:
+        """Convert a DTensor to a full tensor, handling CPU offloaded tensors."""
+        local_tensor = dtensor.to_local()
+        if local_tensor.device.type != "cpu":
+            return dtensor.full_tensor()
+
+        device_mesh = dtensor.device_mesh
+        placements = dtensor.placements
+        temp_dtensor = DTensor.from_local(
+            local_tensor,
+            device_mesh=device_mesh,
+            placements=placements,
+        )
+        return temp_dtensor.full_tensor()
+
    @trace_perf("fsdp_engine.update_weights_from_distributed", category="comm")
    def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
        """Broadcast parameters (chunked) from rank 0 (FSDP2 compatible)."""
@@ -459,30 +474,32 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
        dist.barrier(group=self.cpu_group)

        weight_chunked_mem_size = meta.weight_chunked_mem_mb * 1024 * 1024
+        main_rank = dist.get_rank() == 0

        buffer_size = 0
-        named_tensors = []
+        named_tensors: list[tuple[str, torch.Tensor]] = []

        for name, param in self.get_model_name_parameters():
            if isinstance(param.data, DTensor):
-                tensor = param.data.full_tensor()
+                tensor = self._dtensor_to_full_tensor(param.data)
            else:
                tensor = param.data
-
-            # Ranks other than 0 only help to get the full tensor
-            if dist.get_rank() != 0:
-                continue
+            if tensor.device.type == "cpu":
+                tensor = tensor.to(current_platform.device_type)

            tensor_size = tensor.numel() * tensor.element_size()

-            if tensor_size + buffer_size > weight_chunked_mem_size:
+            if tensor_size + buffer_size > weight_chunked_mem_size and named_tensors:
                self._update_bucket_weights_from_distributed(meta, named_tensors)
+                named_tensors = []
                buffer_size = 0

-            named_tensors.append((name, tensor))
+            # Only rank 0 collects tensors for broadcasting
+            if main_rank:
+                named_tensors.append((name, tensor))
            buffer_size += tensor_size

-        # Only rank-0 CAN contain named tensors here
+        # Process remaining parameters
        if named_tensors:
            self._update_bucket_weights_from_distributed(meta, named_tensors)

@@ -808,6 +825,7 @@ def train_batch(
            list(self.model.parameters()),
            self.world_mesh,
            max_norm=self.optimizer_config.gradient_clipping,
+            offload_params=self.config.fsdp.offload_params,
        )

        if not math.isfinite(grad_norm):
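
For a sense of scale for the bucketing in _update_weights_from_distributed above: tensor_size is measured in bytes (numel x element_size), so the quick arithmetic below (with an illustrative weight_chunked_mem_mb of 512, not necessarily AReaL's default) shows roughly how many bf16 4096x4096 weights fit in one broadcast bucket; the new `and named_tensors` guard additionally ensures a bucket is only flushed once at least one tensor has been collected.

# Back-of-the-envelope bucket sizing; 512 MiB and the 4096x4096 bf16 weight are
# illustrative values, not AReaL defaults.
import torch

weight_chunked_mem_mb = 512
weight_chunked_mem_size = weight_chunked_mem_mb * 1024 * 1024  # 536_870_912 bytes

param = torch.empty(4096, 4096, dtype=torch.bfloat16)
tensor_size = param.numel() * param.element_size()  # 33_554_432 bytes (32 MiB)

print(weight_chunked_mem_size // tensor_size)  # 16 tensors per bucket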
