Commit 277d49a
Do not initialize torch.distributed process group if one is already initialized (huggingface#16487)

* Do not initialize torch process group twice

* Apply suggestions from code review
Yard1 authored Mar 29, 2022
1 parent 2b48323 commit 277d49a
Showing 2 changed files with 15 additions and 3 deletions.

src/transformers/sagemaker/training_args_sm.py (7 additions, 1 deletion)
@@ -77,6 +77,11 @@ def __post_init__(self):
     @cached_property
     def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
+        if torch.distributed.is_initialized() and self.local_rank == -1:
+            logger.warning(
+                "torch.distributed process group is initialized, but local_rank == -1. "
+                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
+            )
         if self.no_cuda:
             device = torch.device("cpu")
             self._n_gpu = 0
@@ -105,7 +110,8 @@ def _setup_devices(self) -> "torch.device":
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
-            torch.distributed.init_process_group(backend="nccl")
+            if not torch.distributed.is_initialized():
+                torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
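
Both hunks above apply the same fix: make process-group creation idempotent by guarding init_process_group behind torch.distributed.is_initialized(). The following is a minimal standalone sketch of that guard (not part of the commit); the gloo backend and the single-process environment defaults are assumptions so it can run without a GPU or an external launcher.

# Minimal sketch of the guard used above (not from the commit): only create
# the default process group if no launcher or framework has created one yet.
import os

import torch.distributed as dist


def ensure_process_group(backend: str = "nccl") -> None:
    """Initialize torch.distributed once; repeated calls become no-ops."""
    if dist.is_available() and not dist.is_initialized():
        # With the default init_method ("env://"), init_process_group reads
        # MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment.
        dist.init_process_group(backend=backend)


if __name__ == "__main__":
    # Assumed single-process rendezvous values so the sketch runs standalone.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    ensure_process_group(backend="gloo")  # gloo works on CPU-only machines
    ensure_process_group(backend="gloo")  # second call is skipped by the guard
    print("initialized:", dist.is_initialized())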

src/transformers/training_args.py (8 additions, 2 deletions)
@@ -1022,10 +1022,15 @@ def eval_batch_size(self) -> int:
     @torch_required
     def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
+        if torch.distributed.is_initialized() and self.local_rank == -1:
+            logger.warning(
+                "torch.distributed process group is initialized, but local_rank == -1. "
+                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
+            )
         if self.no_cuda:
             device = torch.device("cpu")
             self._n_gpu = 0
-            if self.local_rank != -1:
+            if self.local_rank != -1 and not torch.distributed.is_initialized():
                 # Initializes distributed backend for cpu
                 if self.xpu_backend not in ("mpi", "ccl"):
                     raise ValueError(
@@ -1076,7 +1081,8 @@ def _setup_devices(self) -> "torch.device":
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
-            torch.distributed.init_process_group(backend="nccl")
+            if not torch.distributed.is_initialized():
+                torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
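
training_args.py gets the same guard, plus a warning for the case where a process group already exists (for example, created by an external launcher) but local_rank was left at -1, meaning DDP will not actually be used. Below is a hedged usage sketch, assuming a transformers version that includes this patch; the gloo backend, the rendezvous environment values, and the output_dir name are illustrative assumptions.

# Hedged usage sketch (assumes a transformers version containing this patch):
# a process group initialized outside of Trainer, with local_rank left at -1,
# now triggers the new warning instead of double-initializing the backend.
import logging
import os

import torch.distributed as dist
from transformers import TrainingArguments

logging.basicConfig(level=logging.INFO)

# Assumed single-process rendezvous values so the sketch runs on one machine.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)  # external init

args = TrainingArguments(output_dir="tmp_out", no_cuda=True)  # local_rank == -1
print(args.device)  # cpu; _setup_devices logs the new local_rank warning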
