
Commit e2cb357

cbalioglu and Kiuk Chung authored
(torch.distributed) Add torch.distributed.is_torchelastic_launched() util method + make init_method=tcp:// compatible with torchelastic (pytorch#63910) (pytorch#64826)
Summary:
Pull Request resolved: pytorch#63910

Addresses the current issue that `init_method=tcp://` is not compatible with `torch.distributed.run` and `torch.distributed.launch`. When running a training script that initializes the process group with `init_method=tcp://localhost:$port`, as such:

```
$ python -u -m torch.distributed.run --max_restarts 0 --nproc_per_node 1 --nnodes 1 --master_addr $(hostname) --master_port 6000 ~/tmp/test.py
```

an `Address in use` error is raised, since the training script tries to create a TCPStore on port 6000, which is already taken because the elastic agent is already running a TCPStore on that port. For details see: pytorch#63874.

This change does a couple of things:

1. Adds an `is_torchelastic_launched()` check function that users can call in their training scripts to see whether the script was launched via torchelastic.
2. Updates the `torch.distributed` docs page to include the new `is_torchelastic_launched()` function.
3. Makes `init_method=tcp://` torchelastic-compatible by modifying `_tcp_rendezvous_handler` in `torch.distributed.rendezvous` (this is NOT the elastic rendezvous; it is the old rendezvous module, which is slated for deprecation in future releases) to check `is_torchelastic_launched()` AND `torchelastic_use_agent_store()`, and if both are true, only create TCPStore clients (no daemons, not even for rank 0).
4. Adds a set of unit tests to cover the different code paths.

NOTE: the issue mentions that we should fail fast with an assertion on `init_method != env://` when `is_torchelastic_launched()` is `True`. There are three registered init methods in pytorch: `env://`, `tcp://`, and `file://`. Since this diff makes `tcp://` compatible with torchelastic, and I've validated that `file://` is compatible with torchelastic, there is no need to add assertions. I did update the docs to point out that `env://` is the RECOMMENDED init method. We should probably deprecate the other init methods in the future, but that is out of scope for this issue.

Test Plan: Unit tests.

Reviewed By: cbalioglu

Differential Revision: D30529984

fbshipit-source-id: 267aea6d4dad73eb14a2680ac921f210ff547cc5

Co-authored-by: Kiuk Chung <kiuk@fb.com>
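For illustration only (not part of this commit): a minimal sketch of how a training script might use the new check to pick an init method that works both under torchelastic and standalone. The fallback endpoint, the `gloo` backend, and the env-var defaults are placeholder choices for this sketch.

```python
import os

import torch.distributed as dist


def init_distributed():
    if dist.is_torchelastic_launched():
        # Launched via torch.distributed.run / torch.distributed.launch:
        # the elastic agent already exports MASTER_ADDR, MASTER_PORT,
        # RANK and WORLD_SIZE, so env:// (the recommended init method) just works.
        dist.init_process_group(backend="gloo", init_method="env://")
    else:
        # Standalone run: fall back to an explicit tcp:// endpoint
        # (host/port below are placeholders, not values from this commit).
        dist.init_process_group(
            backend="gloo",
            init_method="tcp://localhost:29500",
            rank=int(os.getenv("RANK", "0")),
            world_size=int(os.getenv("WORLD_SIZE", "1")),
        )
```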
1 parent eccdfff commit e2cb357

File tree: 8 files changed, +320 −35 lines changed


docs/source/distributed.rst

Lines changed: 2 additions & 0 deletions
```diff
@@ -180,6 +180,8 @@ joined.
 
 .. autofunction:: is_nccl_available
 
+.. autofunction:: is_torchelastic_launched
+
 --------------------------------------------------------------------------------
 
 Currently three initialization methods are supported:
```
test/distributed/launcher/bin/test_script_init_method.py

Lines changed: 76 additions & 0 deletions

```python
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F


def parse_args():
    parser = argparse.ArgumentParser(description="test script")

    parser.add_argument(
        "--init_method",
        type=str,
        required=True,
        help="init_method to pass to `dist.init_process_group()` (e.g. env://)",
    )
    parser.add_argument(
        "--world_size",
        type=int,
        default=os.getenv("WORLD_SIZE", -1),
        help="world_size to pass to `dist.init_process_group()`",
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=os.getenv("RANK", -1),
        help="rank to pass to `dist.init_process_group()`",
    )

    return parser.parse_args()


def main():
    args = parse_args()

    dist.init_process_group(
        backend="gloo",
        init_method=args.init_method,
        world_size=args.world_size,
        rank=args.rank,
    )

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # one hot (by rank) tensor of size world_size
    # example:
    # rank 0, world_size 4 => [1, 0, 0, 0]
    # rank 1, world_size 4 => [0, 1, 0, 0]
    # ...
    t = F.one_hot(torch.tensor(rank), num_classes=world_size)

    # after all_reduce t = tensor.ones(size=world_size)
    dist.all_reduce(t)

    # adding all elements in t should equal world_size
    derived_world_size = torch.sum(t).item()
    if derived_world_size != world_size:
        raise RuntimeError(
            f"Wrong world size derived. Expected: {world_size}, Got: {derived_world_size}"
        )

    print("Done")


if __name__ == "__main__":
    main()
```
test/distributed/launcher/bin/test_script_is_torchelastic_launched.py

Lines changed: 42 additions & 0 deletions

```python
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This is a test script that launches as part of the test cases in
run_test.py, to validate the correctness of
the method ``torch.distributed.is_torchelastic_launched()``. To do so,
we run this script with and without torchelastic and validate that the
boolean value written to the out_file is indeed what we expect (e.g.
should be False when not launched with torchelastic, True when launched with)
The script itself is not a test case hence no assertions are made in this script.

see: - test/distributed/launcher/run_test.py#test_is_torchelastic_launched()
     - test/distributed/launcher/run_test.py#test_is_not_torchelastic_launched()
"""
import argparse

import torch.distributed as dist


def parse_args():
    parser = argparse.ArgumentParser(description="test script")
    parser.add_argument(
        "--out_file",
        help="file to write indicating whether this script was launched with torchelastic",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    with open(args.out_file, "w") as out:
        out.write(f"{dist.is_torchelastic_launched()}")


if __name__ == "__main__":
    main()
```

test/distributed/launcher/run_test.py

Lines changed: 117 additions & 0 deletions
```diff
@@ -7,8 +7,10 @@
 # LICENSE file in the root directory of this source tree.
 import multiprocessing as mp
 import os
+import runpy
 import shutil
 import subprocess
+import sys
 import tempfile
 import unittest
 import uuid
@@ -21,6 +23,7 @@
 from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
 from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
 from torch.distributed.elastic.utils import get_socket_with_port
+from torch.distributed.elastic.utils.distributed import get_free_port
 from torch.testing._internal.common_utils import (
     TEST_WITH_ASAN,
     TEST_WITH_TSAN,
@@ -476,3 +479,117 @@ def test_launch_shutdown(self, agent_mock_cls):
         param_mock.return_value = rdzv_handler_mock
         launch.main(args)
         rdzv_handler_mock.shutdown.assert_called_once()
+
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
+    def test_is_torchelastic_launched(self):
+        # launch test script with torchelastic and validate that
+        # torch.distributed.is_torchelastic_launched() returns True
+
+        out_file = f"{os.path.join(self.test_dir, 'out')}"
+
+        launch.main(
+            [
+                "--run_path",
+                "--nnodes=1",
+                "--nproc_per_node=1",
+                "--monitor_interval=1",
+                path("bin/test_script_is_torchelastic_launched.py"),
+                f"--out_file={out_file}",
+            ]
+        )
+
+        with open(out_file, "r") as fp:
+            is_torchelastic_launched = fp.readline()
+            self.assertEqual("True", is_torchelastic_launched)
+
+    def test_is_not_torchelastic_launched(self):
+        # launch test script without torchelastic and validate that
+        # torch.distributed.is_torchelastic_launched() returns False
+
+        out_file = f"{os.path.join(self.test_dir, 'out')}"
+
+        # need to run the script with runpy in the same interpreter
+        # as the test because otherwise (depending on the environment)
+        # it will not find torch as a dependency
+        with patch.object(
+            sys,
+            "argv",
+            [
+                path("bin/test_script_is_torchelastic_launched.py"),
+                f"--out_file={out_file}",
+            ],
+        ):
+            runpy.run_path(sys.argv[0], run_name="__main__")
+            with open(out_file, "r") as fp:
+                is_torchelastic_launched = fp.readline()
+                self.assertEqual("False", is_torchelastic_launched)
+
+    def test_init_method_tcp(self):
+        port = get_free_port()
+        with patch.object(
+            sys,
+            "argv",
+            [
+                path("bin/test_script_init_method.py"),
+                f"--init_method=tcp://localhost:{port}",
+                "--rank=0",
+                "--world_size=1",
+            ],
+        ):
+            runpy.run_path(sys.argv[0], run_name="__main__")
+        # nothing to validate, just make sure it runs
+
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
+    def test_init_method_tcp_with_torchelastic(self):
+        port = get_free_port()
+        launch.main(
+            [
+                "--run_path",
+                "--nnodes=1",
+                "--nproc_per_node=4",
+                "--master_addr=localhost",
+                f"--master_port={port}",
+                "--monitor_interval=1",
+                path("bin/test_script_init_method.py"),
+                f"--init_method=tcp://localhost:{port}",
+            ]
+        )
+        # nothing to validate, just make sure it runs
+
+    def test_init_method_env(self):
+        port = get_free_port()
+        with patch.dict(
+            os.environ,
+            {
+                "RANK": "0",
+                "WORLD_SIZE": "1",
+                "MASTER_ADDR": "localhost",
+                "MASTER_PORT": str(port),
+            },
+        ), patch.object(
+            sys,
+            "argv",
+            [
+                path("bin/test_script_init_method.py"),
+                "--init_method=env://",
+            ],
+        ):
+            runpy.run_path(sys.argv[0], run_name="__main__")
+        # nothing to validate, just make sure it runs
+
+    @sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
+    def test_init_method_env_with_torchelastic(self):
+        port = get_free_port()
+        launch.main(
+            [
+                "--run_path",
+                "--nnodes=1",
+                "--nproc_per_node=4",
+                "--master_addr=localhost",
+                f"--master_port={port}",
+                "--monitor_interval=1",
+                path("bin/test_script_init_method.py"),
+                "--init_method=env://",
+            ]
+        )
+        # nothing to validate, just make sure it runs
```

torch/distributed/distributed_c10d.py

Lines changed: 16 additions & 2 deletions
```diff
@@ -1,5 +1,6 @@
 import contextlib
 import logging
+import os
 import pickle
 import io
 import torch
@@ -15,21 +16,22 @@
 from .constants import default_pg_timeout
 from .rendezvous import rendezvous, register_rendezvous_handler  # noqa: F401
 from torch._C._distributed_c10d import (
-    AllreduceOptions,
     AllreduceCoalescedOptions,
+    AllreduceOptions,
     AllToAllOptions,
     BarrierOptions,
     BroadcastOptions,
     GatherOptions,
     PrefixStore,
     ProcessGroup,
-    ReduceOptions,
     ReduceOp,
+    ReduceOptions,
     ReduceScatterOptions,
     ScatterOptions,
     Store,
 )
 
+
 _MPI_AVAILABLE = True
 _NCCL_AVAILABLE = True
 _GLOO_AVAILABLE = True
@@ -350,6 +352,18 @@ def is_initialized():
     return GroupMember.WORLD is not None
 
 
+def is_torchelastic_launched():
+    """
+    Checks whether this process was launched with ``torch.distributed.elastic``
+    (aka torchelastic). The existence of ``TORCHELASTIC_RUN_ID`` environment
+    variable is used as a proxy to determine whether the current process
+    was launched with torchelastic. This is a reasonable proxy since
+    ``TORCHELASTIC_RUN_ID`` maps to the rendezvous id which is always a
+    non-null value indicating the job id for peer discovery purposes.
+    """
+    return os.getenv("TORCHELASTIC_RUN_ID") is not None
+
+
 def _get_default_group():
     """
     Getting the default process group created by init_process_group
```
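Since the new check is just an environment-variable lookup, its behavior can be sketched directly; the run id value below is an arbitrary placeholder, not something torchelastic would actually set:

```python
import os

import torch.distributed as dist

# Without the elastic agent the variable is unset, so the check returns False.
os.environ.pop("TORCHELASTIC_RUN_ID", None)
assert dist.is_torchelastic_launched() is False

# torch.distributed.run exports TORCHELASTIC_RUN_ID (the rendezvous/job id);
# simulating that here with a placeholder value flips the check to True.
os.environ["TORCHELASTIC_RUN_ID"] = "placeholder_run_id"
assert dist.is_torchelastic_launched() is True
```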

torch/distributed/launch.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -97,9 +97,9 @@
     >>> # your code to run
 
 3. In your training program, you are supposed to call the following function
-   at the beginning to start the distributed backend. You need to make sure that
-   the init_method uses ``env://``, which is the only supported ``init_method``
-   by this module.
+   at the beginning to start the distributed backend. It is strongly recommended
+   that ``init_method=env://``. Other init methods (e.g. ``tcp://``) may work,
+   but ``env://`` is the one that is officially supported by this module.
 
 ::
 
@@ -147,6 +147,7 @@
 
 from torch.distributed.run import get_args_parser, run
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -181,7 +182,8 @@ def main(args=None):
         "If your script expects `--local_rank` argument to be set, please\n"
         "change it to read from `os.environ['LOCAL_RANK']` instead. See \n"
         "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n"
-        "further instructions\n", FutureWarning
+        "further instructions\n",
+        FutureWarning,
     )
     args = parse_args(args)
     launch(args)
```
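To make the updated guidance concrete, here is a minimal sketch (script name, backend, and the toy all-reduce are illustrative, not part of this commit) of a training program that relies on `env://` and reads `LOCAL_RANK` from the environment rather than a `--local_rank` argument:

```python
# my_trainer.py (illustrative name); launched with, for example:
#   python -m torch.distributed.run --nproc_per_node=2 my_trainer.py
import os

import torch
import torch.distributed as dist


def main():
    # env:// picks up MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE
    # from the environment variables exported by the launcher.
    dist.init_process_group(backend="gloo", init_method="env://")

    # Read the local rank from the environment instead of --local_rank.
    local_rank = int(os.environ["LOCAL_RANK"])

    # Toy collective: each rank contributes its rank; the all-reduce sums them.
    t = torch.zeros(1) + dist.get_rank()
    dist.all_reduce(t)
    print(f"local_rank={local_rank} rank={dist.get_rank()} sum_of_ranks={t.item()}")


if __name__ == "__main__":
    main()
```

Each worker gets its rendezvous information from the launcher's environment, which is exactly the setup the module docstring above recommends.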
