Fix regression w. dist_init_required (#2225)
jeffra authored Aug 17, 2022
1 parent 9b418c1 commit 7d8ad45
Showing 4 changed files with 106 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -39,7 +39,7 @@ repos:
       name: check-torchdist
       entry: ./scripts/check-torchdist.py
       language: script
-      exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py)
+      exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py|tests/unit/comm/test_dist.py)
       # Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm
 
 - repo: https://github.com/codespell-project/codespell
5 changes: 5 additions & 0 deletions deepspeed/comm/comm.py
@@ -600,6 +600,11 @@ def init_distributed(dist_backend="nccl",
     if dist_init_required is None:
         dist_init_required = cdb is None or not cdb.is_initialized()
 
+    if cdb is None and torch.distributed.is_initialized():
+        # The user initialized torch.dist themselves, create cdb and short-circuit
+        cdb = TorchBackend(dist_backend, timeout, init_method)
+        return
+
     if dist_init_required is False:
         assert (
             cdb is not None and cdb.is_initialized() is True
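In short, init_distributed now wraps a process group that the caller created directly through torch.distributed instead of tripping the dist_init_required assertion. A minimal sketch of that flow, assuming a single-rank run and an arbitrary free port (the TCP address and port are illustrative, not taken from this commit):

```python
# Sketch of the path this fix restores: the caller bootstraps torch.distributed
# first, then hands off to DeepSpeed, which should detect the existing group,
# build its TorchBackend wrapper, and return early instead of re-initializing.
# The address/port and single-rank setup are assumptions for illustration.
import torch
import deepspeed

torch.distributed.init_process_group(backend='nccl',
                                     init_method='tcp://127.0.0.1:29500',
                                     world_size=1,
                                     rank=0)
assert torch.distributed.is_initialized()

# With the early return above, any value of dist_init_required is safe here.
deepspeed.init_distributed('nccl', dist_init_required=None)
```

The tests below exercise exactly this pattern, both with and without the usual torch.distributed environment variables.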
84 changes: 83 additions & 1 deletion tests/unit/comm/test_dist.py
@@ -1,7 +1,9 @@
 import torch
 import deepspeed.comm as dist
+import deepspeed
 
-from tests.unit.common import DistributedTest
+from tests.unit.common import DistributedTest, get_master_port
+from tests.unit.simple_model import SimpleModel
 
 import pytest
 
@@ -71,3 +73,83 @@ def test(self):
         result = torch.ones(1, 3).cuda() * sum_of_ranks
         dist.all_reduce(x)
         assert torch.all(x == result)
+
+
+@pytest.mark.parametrize("dist_init_required", [True, False, None])
+class TestDistInit(DistributedTest):
+    init_distributed = False
+
+    def test_already_init(self, dist_init_required):
+        torch.distributed.init_process_group('nccl')
+        deepspeed.init_distributed('nccl', dist_init_required=dist_init_required)
+
+    def test_no_init(self, dist_init_required):
+        if dist_init_required or dist_init_required is None:
+            deepspeed.init_distributed('nccl', dist_init_required=dist_init_required)
+        else:
+            # torch.dist is not done and for some reason the user says they don't want it done
+            with pytest.raises(Exception):
+                deepspeed.init_distributed('nccl', dist_init_required=dist_init_required)
+
+
+class TestDistInitNoEnv(DistributedTest):
+    world_size = 1
+    init_distributed = False
+    set_dist_env = False
+
+    def test(self):
+        torch.distributed.init_process_group(
+            backend='nccl',
+            init_method=f"tcp://127.0.0.1:{get_master_port()}",
+            world_size=1,
+            rank=0)
+        assert torch.distributed.is_initialized()
+        deepspeed.init_distributed('nccl', auto_mpi_discovery=True)
+
+
+@pytest.mark.parametrize("dist_init_required", [True, False])
+class TestDistInitWithModel(DistributedTest):
+    init_distributed = False
+
+    def test_already_init(self, dist_init_required):
+        torch.distributed.init_process_group('nccl')
+        model = SimpleModel(4)
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {}
+            }
+        }
+        engine, *_ = deepspeed.initialize(
+            model=model,
+            config=config_dict,
+            model_parameters=model.parameters(),
+            dist_init_required=dist_init_required
+        )
+
+    def test_no_init(self, dist_init_required):
+        model = SimpleModel(4)
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {}
+            }
+        }
+        if dist_init_required:
+            engine, *_ = deepspeed.initialize(
+                model=model,
+                config=config_dict,
+                model_parameters=model.parameters(),
+                dist_init_required=dist_init_required
+            )
+        else:
+            # torch.dist is not done and for some reason the user says they don't want it done
+            with pytest.raises(Exception):
+                engine, *_ = deepspeed.initialize(
+                    model=model,
+                    config=config_dict,
+                    model_parameters=model.parameters(),
+                    dist_init_required=dist_init_required
+                )
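Taken together, the tests above pin down the dist_init_required contract. A rough summary and a minimal caller-side sketch, inferred from these tests rather than from any authoritative spec:

```python
# Behavior exercised by the tests above (inferred from this diff):
#   dist_init_required=None  -> initialize torch.distributed only if it is not already up
#   dist_init_required=True  -> initialize, or reuse an already-created process group
#   dist_init_required=False -> require an existing process group; otherwise an error is raised
import torch
import deepspeed

if torch.distributed.is_initialized():
    # Safe with any value; False merely asserts that the group already exists.
    deepspeed.init_distributed('nccl', dist_init_required=False)
else:
    # Let DeepSpeed create the process group from the usual launcher env variables.
    deepspeed.init_distributed('nccl', dist_init_required=None)
```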
29 changes: 17 additions & 12 deletions tests/unit/common.py
@@ -67,6 +67,8 @@ class DistributedTest(ABC):
     is_dist_test = True
     world_size = 2
     backend = "nccl"
+    init_distributed = True
+    set_dist_env = True
 
     # Temporary directory that is shared among test methods in a class
     @pytest.fixture(autouse=True, scope="class")
@@ -151,20 +153,22 @@ def _launch_procs(self, num_procs):
 
     def _dist_init(self, local_rank, num_procs, skip_msg):
         """Initialize deepspeed.comm and execute the user function. """
-        os.environ['MASTER_ADDR'] = '127.0.0.1'
-        os.environ['MASTER_PORT'] = get_master_port()
-        os.environ['LOCAL_RANK'] = str(local_rank)
-        # NOTE: unit tests don't support multi-node so local_rank == global rank
-        os.environ['RANK'] = str(local_rank)
-        os.environ['WORLD_SIZE'] = str(num_procs)
+        if self.set_dist_env:
+            os.environ['MASTER_ADDR'] = '127.0.0.1'
+            os.environ['MASTER_PORT'] = get_master_port()
+            os.environ['LOCAL_RANK'] = str(local_rank)
+            # NOTE: unit tests don't support multi-node so local_rank == global rank
+            os.environ['RANK'] = str(local_rank)
+            os.environ['WORLD_SIZE'] = str(num_procs)
 
         # turn off NCCL logging if set
         os.environ.pop('NCCL_DEBUG', None)
 
         set_cuda_visibile()
 
-        deepspeed.init_distributed(dist_backend=self.backend)
-        dist.barrier()
+        if self.init_distributed:
+            deepspeed.init_distributed(dist_backend=self.backend)
+            dist.barrier()
 
         if torch.cuda.is_available():
             torch.cuda.set_device(local_rank)
@@ -177,10 +181,11 @@ def _dist_init(self, local_rank, num_procs, skip_msg):
             else:
                 raise e
 
-        # make sure all ranks finish at the same time
-        dist.barrier()
-        # tear down after test completes
-        dist.destroy_process_group()
+        if self.init_distributed or dist.is_initialized():
+            # make sure all ranks finish at the same time
+            dist.barrier()
+            # tear down after test completes
+            dist.destroy_process_group()
 
 
 def distributed_test(world_size=2, backend='nccl'):
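The two new class attributes let a test class opt out of the harness's automatic setup, while teardown still runs whenever a process group exists. A minimal sketch of a hypothetical subclass using both knobs (the class name and test body are illustrative, modeled on TestDistInitNoEnv above):

```python
# Hypothetical DistributedTest subclass using the two new knobs:
# init_distributed=False stops the harness from calling deepspeed.init_distributed,
# and set_dist_env=False stops it from exporting MASTER_ADDR/MASTER_PORT/RANK/
# LOCAL_RANK/WORLD_SIZE, so the test body owns process-group setup end to end.
import torch
import deepspeed
from tests.unit.common import DistributedTest, get_master_port


class TestManualProcessGroup(DistributedTest):
    world_size = 1
    init_distributed = False
    set_dist_env = False

    def test(self):
        # The test bootstraps torch.distributed itself via an explicit TCP rendezvous.
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=f"tcp://127.0.0.1:{get_master_port()}",
            world_size=1,
            rank=0)
        deepspeed.init_distributed('nccl')
        # Teardown still happens in the harness because it now also checks
        # dist.is_initialized() before the final barrier/destroy.
```

Guarding teardown on dist.is_initialized() also means such opt-out classes still clean up any process group they create themselves, rather than leaking it into later tests.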
