ddp fixes for training (#22874)
DDP fixes for StableLM training
winglian authored Apr 21, 2023
1 parent eddf9ee commit d00997e
Showing 1 changed file with 8 additions and 6 deletions.
src/transformers/trainer.py (8 additions, 6 deletions)

@@ -1565,12 +1565,13 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
                 kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
             if is_torch_neuroncore_available():
                 return model
-            model = nn.parallel.DistributedDataParallel(
-                model,
-                device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
-                output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
-                **kwargs,
-            )
+            if any(p.requires_grad for p in model.parameters()):
+                model = nn.parallel.DistributedDataParallel(
+                    model,
+                    device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
+                    output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
+                    **kwargs,
+                )
 
         # torch.compile() needs to be called after wrapping the model with FSDP or DDP
         # to ensure that it accounts for the graph breaks required by those wrappers
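The first hunk wraps the model in DistributedDataParallel only when at least one parameter requires gradients. PyTorch's DDP refuses to wrap a module that has no trainable parameters, so a fully frozen model (for example, when only components outside this module are being trained) would previously fail at this point. A minimal sketch of the same guard outside the Trainer, using a hypothetical frozen toy model, looks roughly like this:

    import torch.nn as nn

    def maybe_wrap_ddp(model: nn.Module, local_rank: int) -> nn.Module:
        # Hypothetical helper mirroring the guard in this commit: wrap with DDP
        # only when some parameter will actually receive gradients.
        if any(p.requires_grad for p in model.parameters()):
            return nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_rank],
                output_device=local_rank,
            )
        # A fully frozen model is returned unwrapped; DDP would otherwise refuse
        # to construct because there are no gradients to synchronize.
        return model

    # Example: a model with every parameter frozen is left unwrapped.
    frozen = nn.Linear(16, 16)
    for p in frozen.parameters():
        p.requires_grad_(False)
    # maybe_wrap_ddp(frozen, local_rank=0)  # returns `frozen` unchanged
    # (wrapping an unfrozen model additionally requires torch.distributed
    # to be initialized, which is omitted here)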
@@ -1920,6 +1921,7 @@ def _inner_training_loop(
                     (total_batched_samples % args.gradient_accumulation_steps != 0)
                     and args.parallel_mode == ParallelMode.DISTRIBUTED
                     and args._no_sync_in_gradient_accumulation
+                    and hasattr(model, "no_sync")
                 ):
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
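The second hunk accounts for the fact that, after the change above, the model may reach the training loop unwrapped: no_sync() is a method of DistributedDataParallel (and of some other distributed wrappers), not of a plain nn.Module, so calling it unconditionally would raise an AttributeError. The added hasattr check falls back to a normal step instead. An equivalent pattern, shown here as an illustrative sketch rather than the Trainer's actual code, uses contextlib.nullcontext as the fallback:

    import contextlib

    import torch.nn as nn

    def no_sync_if_available(model: nn.Module):
        # Illustrative helper (not part of the commit): suppress gradient
        # synchronization during accumulation steps when the wrapper supports
        # it, and do nothing for a plain, unwrapped module.
        if hasattr(model, "no_sync"):
            return model.no_sync()
        return contextlib.nullcontext()

    # Sketch of use inside a gradient-accumulation loop; `compute_loss` and
    # `batch` are placeholders, not Trainer APIs.
    # with no_sync_if_available(model):
    #     loss = compute_loss(model, batch)
    #     loss.backward()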
