[Trainer] Add optional communication backends for torch.distributed when using GPU (huggingface#22247)

Update training_args.py
heya5 authored and raghavanone committed Apr 5, 2023
1 parent 177fd50 commit 1c588fe
Showing 1 changed file with 4 additions and 1 deletion.
src/transformers/training_args.py (5 changes: 4 additions & 1 deletion)
@@ -1641,7 +1641,10 @@ def _setup_devices(self) -> "torch.device":
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
             if not torch.distributed.is_initialized():
-                torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
+                if self.xpu_backend and self.xpu_backend in ("mpi", "gloo"):
+                    torch.distributed.init_process_group(backend=self.xpu_backend, timeout=self.ddp_timeout_delta)
+                else:
+                    torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
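
A hedged usage sketch (not part of the commit): with this change, a multi-GPU job launched with torchrun could select "gloo" or "mpi" instead of the default "nccl" by setting the existing `xpu_backend` training argument, which the diff above now consults before falling back to NCCL. The argument values below are illustrative assumptions, not taken from the commit.

    # Illustrative only: run with e.g. `torchrun --nproc_per_node=2 train.py`
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="out",        # hypothetical output directory
        xpu_backend="gloo",      # request the gloo backend instead of the default nccl
    )
    # Accessing `args.device` lazily calls `_setup_devices`, which (per the diff above)
    # now initializes torch.distributed with the requested backend even when GPUs are used.
    print(args.device)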

