# Summary: guards the init_process_group wrapper in deepspeed/comm/torch.py so
# that torch.distributed.init_process_group is called only when
# torch.distributed.is_initialized() returns False. The wrapper also stops
# returning the result of that call.
# NOTE(review): when a process group is already initialized, the backend,
# timeout and init_method arguments are not used at all by the new code.
# NOTE(review): line breaks and indentation below were reconstructed from a
# whitespace-collapsed copy using the hunk counts (-35,9 +35,10); confirm the
# exact whitespace against the upstream commit before applying.
diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py
index 17802c2e03ef..f7731536b696 100644
--- a/deepspeed/comm/torch.py
+++ b/deepspeed/comm/torch.py
@@ -35,9 +35,10 @@ def __init__(self, backend, timeout, init_method, name='torch'):
         self.init_process_group(backend, timeout, init_method)
 
     def init_process_group(self, backend, timeout, init_method):
-        return torch.distributed.init_process_group(backend,
-                                                    timeout=timeout,
-                                                    init_method=init_method)
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group(backend,
+                                                 timeout=timeout,
+                                                 init_method=init_method)
 
     def all_reduce(self,
                    tensor,