Skip to content

Commit

Permalink
Make env init_method support both env and args for rank and size (pytorch#14494)
Browse files Browse the repository at this point in the history

Summary:
Fixing: pytorch#14446

This was a supported behavior in old torch.distributed. We want to support it in the new release.

Tests should cover all combinations of scenarios where we have either the env or the arg set up for rank, or size, or both.
Pull Request resolved: pytorch#14494

Differential Revision: D13253433

Pulled By: teng-li

fbshipit-source-id: c05974d84f1bdf969f74ec45763e11a841fe4848
  • Loading branch information
teng-li authored and facebook-github-bot committed Nov 30, 2018
1 parent 1a9602d commit 0d3cb91
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 15 deletions.
49 changes: 49 additions & 0 deletions test/test_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ def test_unknown_handler(self):

class RendezvousEnvTest(TestCase):
def test_common_errors(self):
# TODO remove this hack
if not hasattr(c10d, "ProcessGroupNCCL"):
raise unittest.SkipTest("C10D is not built with NCCL process group,"
" skipping test")
vars = {
"WORLD_SIZE": "2",
"RANK": "0",
Expand All @@ -247,23 +251,68 @@ def without(d, key):
d.pop(key)
return d

def withouts(d, keys):
    """Return a shallow copy of *d* with every key in *keys* removed.

    Raises KeyError if any key in *keys* is absent from *d* (same
    behavior as the original ``dict.pop`` without a default).
    """
    stripped = dict(d)
    for k in keys:
        del stripped[k]
    return stripped

with Env(without(vars, 'WORLD_SIZE')):
with self.assertRaisesRegex(ValueError, 'WORLD_SIZE expected'):
gen = c10d.rendezvous('env://')
next(gen)
c10d.init_process_group(backend='nccl', world_size=2)
self.assertEqual(c10d.get_rank(), 0)
self.assertEqual(c10d.get_world_size(), 2)
c10d.destroy_process_group()

with Env(without(vars, 'RANK')):
with self.assertRaisesRegex(ValueError, 'RANK expected'):
gen = c10d.rendezvous('env://')
next(gen)
c10d.init_process_group(backend='nccl', rank=0)
self.assertEqual(c10d.get_rank(), 0)
self.assertEqual(c10d.get_world_size(), 2)
c10d.destroy_process_group()

with Env(withouts(vars, ['RANK', 'WORLD_SIZE'])):
c10d.init_process_group(backend='nccl', rank=0, world_size=2)
self.assertEqual(c10d.get_rank(), 0)
self.assertEqual(c10d.get_world_size(), 2)
c10d.destroy_process_group()

with Env(vars):
c10d.init_process_group(backend='nccl')
self.assertEqual(c10d.get_rank(), 0)
self.assertEqual(c10d.get_world_size(), 2)
c10d.destroy_process_group()

with Env(without(vars, 'MASTER_ADDR')):
with self.assertRaisesRegex(ValueError, 'MASTER_ADDR expected'):
gen = c10d.rendezvous('env://')
next(gen)

with Env(without(vars, 'MASTER_PORT')):
with self.assertRaisesRegex(ValueError, 'MASTER_PORT expected'):
gen = c10d.rendezvous('env://')
next(gen)

with Env(without(vars, 'WORLD_SIZE')):
gen = c10d.rendezvous('env://?world_size={}'.format(2))
_, _, size = next(gen)
self.assertEqual(size, 2)

with Env(without(vars, 'RANK')):
gen = c10d.rendezvous('env://?rank={}'.format(0))
_, rank, _ = next(gen)
self.assertEqual(rank, 0)

with Env(withouts(vars, ['RANK', 'WORLD_SIZE'])):
gen = c10d.rendezvous('env://?rank={}&world_size={}'.format(0, 2))
_, rank, size = next(gen)
self.assertEqual(rank, 0)
self.assertEqual(size, 2)

@retry_on_address_already_in_use_error
def test_nominal(self):
os.environ['WORLD_SIZE'] = '2'
Expand Down
17 changes: 9 additions & 8 deletions torch/distributed/distributed_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,14 +343,15 @@ def init_process_group(backend,
_pg_names[_default_pg] = group_name
else:
# backward compatible API
if init_method != "env://" and world_size != -1 and rank != -1:
url = "{}?rank={}&world_size={}".format(init_method,
rank,
world_size)
store, _, _ = next(rendezvous(url))
else:
store, rank, world_size = next(rendezvous(init_method))

url = init_method
if world_size != -1 and rank != -1:
url += "?rank={}&world_size={}".format(rank, world_size)
elif rank != -1:
url += "?rank={}".format(rank)
elif world_size != -1:
url += "?world_size={}".format(world_size)

store, rank, world_size = next(rendezvous(url))
if backend == Backend.GLOO:
_default_pg = ProcessGroupGloo(
store,
Expand Down
26 changes: 19 additions & 7 deletions torch/distributed/rendezvous.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,29 @@ def _error(msg):
def _env_error(var):
return _error("environment variable %s expected, but not set" % var)

if url != "env://":
if not url.startswith("env://"):
raise _error("url must be equal to `env://`")
world_size = os.environ.get("WORLD_SIZE", None)
if world_size is None:
raise _env_error("WORLD_SIZE")
rank = os.environ.get("RANK", None)
if rank is None:
raise _env_error("RANK")
result = urlparse(url)
query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))

if "rank" in query:
rank = int(query["rank"])
else:
rank = os.environ.get("RANK", None)
if rank is None:
raise _env_error("RANK")

if "world_size" in query:
world_size = int(query["world_size"])
else:
world_size = os.environ.get("WORLD_SIZE", None)
if world_size is None:
raise _env_error("WORLD_SIZE")

master_addr = os.environ.get("MASTER_ADDR", None)
if master_addr is None:
raise _env_error("MASTER_ADDR")

master_port = os.environ.get("MASTER_PORT", None)
if master_port is None:
raise _env_error("MASTER_PORT")
Expand Down

0 comments on commit 0d3cb91

Please sign in to comment.