Skip to content

ddp fixes for training #22874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,12 +1565,13 @@ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
if is_torch_neuroncore_available():
return model
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
**kwargs,
)
if any(p.requires_grad for p in model.parameters()):
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None,
output_device=self.args.local_rank if self.args._n_gpu != 0 else None,
**kwargs,
)

# torch.compile() needs to be called after wrapping the model with FSDP or DDP
# to ensure that it accounts for the graph breaks required by those wrappers
Expand Down Expand Up @@ -1920,6 +1921,7 @@ def _inner_training_loop(
(total_batched_samples % args.gradient_accumulation_steps != 0)
and args.parallel_mode == ParallelMode.DISTRIBUTED
and args._no_sync_in_gradient_accumulation
and hasattr(model, "no_sync")
):
# Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
with model.no_sync():
Expand Down