distributed_utils.py
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines classes to enable running tests in a distributed setting."""

# The following code is copied and adapted from the DeepSpeed repo:
# https://github.com/microsoft/DeepSpeed/blob/master/tests/unit/common.py

import inspect
import os
import socket
import time
import uuid
from abc import ABC, abstractmethod
from typing import List, Union

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from _pytest.fixtures import FixtureLookupError
from _pytest.outcomes import Skipped

from optimum.neuron.utils.cache_utils import get_num_neuron_cores
from optimum.neuron.utils.import_utils import (
    is_neuronx_distributed_available,
    is_torch_neuronx_available,
    is_torch_xla_available,
)


if is_torch_neuronx_available():
    import torch_neuronx

if is_torch_xla_available():
    import torch_xla.distributed.xla_backend as xbn

if is_neuronx_distributed_available():
    import neuronx_distributed


TEST_TIMEOUT = 600


def is_neuron_environment_available() -> bool:
    return get_num_neuron_cores() > 0

def get_xdist_worker_id():
    xdist_worker = os.environ.get("PYTEST_XDIST_WORKER", None)
    if xdist_worker is not None:
        xdist_worker_id = xdist_worker.replace("gw", "")
        return int(xdist_worker_id)
    return None


def get_master_port(base_port=29500, port_range_size=1000):
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is not None:
        # Make xdist workers use different port ranges to avoid race conditions
        base_port += port_range_size * xdist_worker_id

    # Select the first open port in the range
    port = base_port
    max_port = base_port + port_range_size
    sock = socket.socket()
    while port < max_port:
        try:
            sock.bind(("", port))
            sock.close()
            return str(port)
        except OSError:
            port += 1
    raise IOError("no free ports")
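
# Illustrative note (derived from the logic above, not part of the original file): under
# pytest-xdist, each worker scans its own 1000-port window, so concurrently running workers
# do not race for the same master port. For example, with the default arguments:
#   worker "gw0"  -> scans ports 29500-30499
#   worker "gw1"  -> scans ports 30500-31499
#   no xdist      -> scans ports 29500-30499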

class DistributedExec(ABC):
    """
    Base class for distributed execution of functions/methods. Contains common
    methods needed for DistributedTest and DistributedFixture (not included in this file).
    """

    world_size: Union[int, List[int]] = 2
    tp_size: int = 1
    pp_size: int = 1
    backend: str = "xla"
    init_distributed: bool = True
    set_dist_env: bool = True
    requires_neuron_environment: bool = True
    reuse_dist_env: bool = False
    _pool_cache = {}
    exec_timeout: int = TEST_TIMEOUT

    @abstractmethod
    def run(self): ...

    def __call__(self, request=None):
        self._fixture_kwargs = self._get_fixture_kwargs(request, self.run)
        world_size = self.world_size
        if self.requires_neuron_environment and not is_neuron_environment_available():
            pytest.skip("Only supported in a Neuron environment.")

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            self._launch_procs(procs, self.tp_size, self.pp_size)

    def _get_fixture_kwargs(self, request, func):
        if not request:
            return {}
        # Grab fixture / parametrize kwargs from the pytest request object
        fixture_kwargs = {}
        params = inspect.getfullargspec(func).args
        params.remove("self")
        for p in params:
            try:
                fixture_kwargs[p] = request.getfixturevalue(p)
            except FixtureLookupError:
                pass  # test methods can have kwargs that are not fixtures
        return fixture_kwargs
    def _launch_procs(self, num_procs, tp_size, pp_size):
        if not is_torch_neuronx_available() or not is_torch_xla_available() or not is_neuronx_distributed_available():
            raise RuntimeError(
                "The `torch_neuronx`, `torch_xla` and `neuronx_distributed` packages are required to run a distributed "
                "test."
            )

        # Verify we have enough accelerator devices to run this test
        num_cores = get_num_neuron_cores()
        if 0 < num_cores < num_procs:
            pytest.skip(
                f"Skipping test because not enough Neuron cores are available: {num_procs} required, {num_cores} "
                "available."
            )

        # Set start method to `forkserver` (or `fork`)
        mp.set_start_method("forkserver", force=True)

        # We cannot set the environment variable `TORCHELASTIC_RUN_ID` here because `torch_neuronx` will
        # configure PJRT if it is set. Instead, we store the value and set it once the other environment
        # variables that simulate a `torchrun` execution (e.g. `LOCAL_RANK`, `RANK`, `WORLD_SIZE`, ...) have been set.
        self.torchelastic_run_id = str(uuid.uuid4())

        # Create a process pool or use a cached one
        master_port = None
        if self.reuse_dist_env:
            if num_procs not in self._pool_cache:
                self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
                master_port = get_master_port()
            pool = self._pool_cache[num_procs]
        else:
            pool = mp.Pool(processes=num_procs)
            master_port = get_master_port()

        # Run the test
        args = [(local_rank, num_procs, master_port, tp_size, pp_size) for local_rank in range(num_procs)]
        skip_msgs_async = pool.starmap_async(self._dist_run, args)

        skip_msgs = ""  # Otherwise the linter complains.
        try:
            skip_msgs = skip_msgs_async.get(self.exec_timeout)
        except mp.TimeoutError:
            # Shortcut to exit pytest in the case of a hung test. This
            # usually means an environment error and the rest of the tests will
            # hang (causing super long unit test runtimes).
            pytest.exit("Test hanged, exiting", returncode=0)
        except Exception as e:
            self._close_pool(pool, num_procs, use_terminate=True)
            raise e
        finally:
            # Tear down the distributed environment and close process pools
            self._close_pool(pool, num_procs)

        # If we skipped a test, propagate that to this process
        if any(skip_msgs):
            assert len(set(skip_msgs)) == 1, "Multiple different skip messages received"
            pytest.skip(skip_msgs[0])
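    # Illustrative note (not part of the original file): for `num_procs=2`, `tp_size=2`,
    # `pp_size=1` and an example selected port of "29500", the pool receives one argument
    # tuple per rank, i.e. [(0, 2, "29500", 2, 1), (1, 2, "29500", 2, 1)], and each tuple
    # is unpacked into a `_dist_run` call in its own worker process.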
    def _dist_run(self, local_rank, num_procs, master_port, tp_size, pp_size):
        skip_msg = ""
        if not dist.is_initialized():
            """Initializes communication and executes the user function."""
            if self.set_dist_env:
                os.environ["MASTER_ADDR"] = "127.0.0.1"
                os.environ["MASTER_PORT"] = str(master_port)
                # Unit tests do not support multi-node so local_rank == global rank
                os.environ["LOCAL_RANK"] = str(local_rank)
                os.environ["RANK"] = str(local_rank)
                os.environ["LOCAL_SIZE"] = str(num_procs)
                os.environ["WORLD_SIZE"] = str(num_procs)
                os.environ["LOCAL_WORLD_SIZE"] = str(num_procs)
                # Unit tests do not support multi-node so there is only one group in our case
                os.environ["GROUP_RANK"] = "0"
                if not hasattr(self, "torchelastic_run_id"):
                    raise RuntimeError("self.torchelastic_run_id was not set, it is needed to run a distributed test.")
                os.environ["TORCHELASTIC_RUN_ID"] = self.torchelastic_run_id
                # Now that the environment has been set, we can configure the PJRT environment.
                torch_neuronx.xla.configure_pjrt_environment()

            if self.init_distributed:
                dist.init_process_group(backend=self.backend, rank=local_rank, world_size=num_procs)
                if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
                    raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")

                # Initializing NxD.
                neuronx_distributed.parallel_layers.parallel_state.initialize_model_parallel(
                    tensor_model_parallel_size=tp_size,
                    pipeline_model_parallel_size=pp_size,
                )

        try:
            self.run(**self._fixture_kwargs)
        except BaseException as e:
            if isinstance(e, Skipped):
                skip_msg = e.msg
            else:
                raise e

        return skip_msg
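
    # Illustrative note (not part of the original file): for `num_procs=2`, `_dist_run` on
    # `local_rank=1` sets an environment equivalent to what `torchrun --nproc_per_node=2`
    # would provide, e.g.:
    #   MASTER_ADDR=127.0.0.1  MASTER_PORT=<selected port>
    #   LOCAL_RANK=1  RANK=1  LOCAL_SIZE=2  WORLD_SIZE=2  LOCAL_WORLD_SIZE=2  GROUP_RANK=0
    #   TORCHELASTIC_RUN_ID=<uuid generated in _launch_procs>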
    def _dist_destroy(self):
        if (dist is not None) and dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()

    def _close_pool(self, pool, num_procs, force=False, use_terminate=False):
        if force or not self.reuse_dist_env:
            try:
                _ = pool.starmap(self._dist_destroy, [() for _ in range(num_procs)])
                if use_terminate:
                    pool.terminate()
                else:
                    pool.close()
                pool.join()
            except ValueError:
                pass

class DistributedTest(DistributedExec):
    """
    Implementation for running pytest with distributed execution.
    """

    is_dist_test = True

    def early_skip(self, fixtures_kwargs):
        """
        Override to enable early test skipping (before process creation).
        """
        pass

    # Temporary directory that is shared among test methods in a class
    @pytest.fixture(autouse=True, scope="class")
    def class_tmpdir(self, tmpdir_factory):
        fn = tmpdir_factory.mktemp(self.__class__.__name__)
        return fn

    def run(self, **fixture_kwargs):
        self._current_test(**fixture_kwargs)

    def __call__(self, request):
        self._current_test = self._get_current_test_func(request)
        self._fixture_kwargs = self._get_fixture_kwargs(request, self._current_test)

        if self.requires_neuron_environment and not is_neuron_environment_available():
            pytest.skip("Only supported in a Neuron environment.")

        self.early_skip(self._fixture_kwargs)
        world_size = tp_size = pp_size = parallel_sizes = None

        # Catch world_size, tp_size or pp_size overrides passed via a pytest mark.
        def try_to_override_via_pytest_mark(mark, name, current_value):
            if mark.name == name:
                return mark.args[0]
            return current_value

        for mark in getattr(request.function, "pytestmark", []):
            world_size = try_to_override_via_pytest_mark(mark, "world_size", world_size)
            tp_size = try_to_override_via_pytest_mark(mark, "tp_size", tp_size)
            pp_size = try_to_override_via_pytest_mark(mark, "pp_size", pp_size)
            parallel_sizes = try_to_override_via_pytest_mark(mark, "parallel_sizes", parallel_sizes)

        # Catch world_size, tp_size or pp_size overrides passed via a fixture.
        def try_to_override_via_fixture(name, current_value):
            if name in self._fixture_kwargs:
                if current_value is not None:
                    raise ValueError(f"It is not possible to override {name} both via pytest.mark and fixtures.")
                return self._fixture_kwargs[name]
            return current_value

        world_size = try_to_override_via_fixture("world_size", world_size)
        tp_size = try_to_override_via_fixture("tp_size", tp_size)
        pp_size = try_to_override_via_fixture("pp_size", pp_size)
        parallel_sizes = try_to_override_via_fixture("parallel_sizes", parallel_sizes)

        if parallel_sizes is not None:
            if not all(size is None for size in [world_size, tp_size, pp_size]):
                raise ValueError(
                    "Either specify parallel_sizes or the individual sizes (world_size, tp_size, pp_size), not both."
                )
            world_size, tp_size, pp_size = parallel_sizes

        if world_size is None:
            world_size = self.world_size
        if tp_size is None:
            tp_size = self.tp_size
        if pp_size is None:
            pp_size = self.pp_size

        sizes = [world_size, tp_size, pp_size]
        if all(isinstance(size, int) for size in sizes):
            world_size = [world_size]
            tp_size = [tp_size]
            pp_size = [pp_size]
        else:
            lengths = [len(size) for size in sizes if not isinstance(size, int)]
            if len(set(lengths)) != 1:
                raise ValueError(
                    "When providing multiple values for either world_size, tp_size or pp_size, you must provide the "
                    f"same number of values. Here: {', '.join(str(length) for length in lengths)}."
                )
            if not all(isinstance(size, (tuple, list)) for size in sizes):
                length = lengths[0]
                world_size = [world_size] * length if isinstance(world_size, int) else world_size
                tp_size = [tp_size] * length if isinstance(tp_size, int) else tp_size
                pp_size = [pp_size] * length if isinstance(pp_size, int) else pp_size

        for sizes in zip(world_size, tp_size, pp_size):
            self._launch_procs(*sizes)
            time.sleep(0.5)
    def _get_current_test_func(self, request):
        # DistributedTest subclasses may have multiple test methods
        func_name = request.function.__name__
        return getattr(self, func_name)
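
# Usage sketch (illustrative only, not part of the original module). Assuming the test suite's
# conftest dispatches classes that set `is_dist_test = True` through `DistributedTest.__call__`,
# a distributed test can be written as a plain pytest class; the names `TestDataParallelOnly`
# and the test methods below are hypothetical placeholders:
#
#     class TestDataParallelOnly(DistributedTest):
#         world_size = 2  # two Neuron cores, data-parallel only
#         tp_size = 1
#         pp_size = 1
#
#         def test_world_size(self):
#             # Runs inside each spawned worker, after `dist.init_process_group` with the XLA backend.
#             assert dist.get_world_size() == 2
#
#         @pytest.mark.parallel_sizes((8, 2, 2))
#         def test_with_model_parallelism(self):
#             # The mark overrides (world_size, tp_size, pp_size) for this method only.
#             assert dist.get_world_size() == 8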