From 277d49a590b6745ec82460eea3f33a825a89051c Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Tue, 29 Mar 2022 16:07:31 -0700
Subject: [PATCH] Do not initialize `torch.distributed` process group if one is already initialized (#16487)

* Do not initialize torch process group twice

* Apply suggestions from code review
---
 src/transformers/sagemaker/training_args_sm.py |  8 +++++++-
 src/transformers/training_args.py              | 10 ++++++++--
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/src/transformers/sagemaker/training_args_sm.py b/src/transformers/sagemaker/training_args_sm.py
index 992f3d4fce3014..f6c57d8f8577d2 100644
--- a/src/transformers/sagemaker/training_args_sm.py
+++ b/src/transformers/sagemaker/training_args_sm.py
@@ -77,6 +77,11 @@ def __post_init__(self):
     @cached_property
     def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
+        if torch.distributed.is_initialized() and self.local_rank == -1:
+            logger.warning(
+                "torch.distributed process group is initialized, but local_rank == -1. "
+                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch`"
+            )
         if self.no_cuda:
             device = torch.device("cpu")
             self._n_gpu = 0
@@ -105,7 +110,8 @@ def _setup_devices(self) -> "torch.device":
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
-            torch.distributed.init_process_group(backend="nccl")
+            if not torch.distributed.is_initialized():
+                torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
 
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 2087fbb7ace0d2..b0e6fbc6e85e04 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1022,10 +1022,15 @@ def eval_batch_size(self) -> int:
     @torch_required
     def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
+        if torch.distributed.is_initialized() and self.local_rank == -1:
+            logger.warning(
+                "torch.distributed process group is initialized, but local_rank == -1. "
+                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch`"
+            )
         if self.no_cuda:
             device = torch.device("cpu")
             self._n_gpu = 0
-            if self.local_rank != -1:
+            if self.local_rank != -1 and not torch.distributed.is_initialized():
                 # Initializes distributed backend for cpu
                 if self.xpu_backend not in ("mpi", "ccl"):
                     raise ValueError(
@@ -1076,7 +1081,8 @@ def _setup_devices(self) -> "torch.device":
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
-            torch.distributed.init_process_group(backend="nccl")
+            if not torch.distributed.is_initialized():
+                torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
 
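
Context (not part of the patch): the change above applies the standard "check before init" guard around `torch.distributed.init_process_group`. The sketch below is a minimal, self-contained illustration of that pattern, not code from `transformers`; it assumes a throwaway single-process "gloo" group with placeholder MASTER_ADDR/MASTER_PORT values chosen purely for demonstration.

    import os
    import torch.distributed as dist

    # Placeholder rendezvous settings for a throwaway single-process group.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    # An external launcher or framework may already have created the default
    # process group before TrainingArguments ever touches the device setup.
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo", rank=0, world_size=1)

    # Without the guard, a second init_process_group() call raises a
    # RuntimeError about initializing the default process group twice; with
    # the guard, this call is simply skipped, mirroring the patched
    # _setup_devices() behavior.
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo", rank=0, world_size=1)

    dist.destroy_process_group()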