
Commit 3d8f784

youkaichao authored and mzusman committed
Support torchrun and SPMD-style offline inference (vllm-project#12071)
Signed-off-by: youkaichao <youkaichao@gmail.com>
1 parent 20240d6 commit 3d8f784

File tree

14 files changed (+248, -30 lines)

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions

@@ -463,6 +463,7 @@ steps:
   - vllm/worker/worker.py
   - vllm/worker/model_runner.py
   commands:
+  - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'

examples/offline_inference/torchrun_example.py

Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
+"""
+experimental support for tensor-parallel inference with torchrun,
+see https://github.com/vllm-project/vllm/issues/11400 for
+the motivation and use case for this example.
+run the script with `torchrun --nproc-per-node=2 torchrun_example.py`,
+the argument 2 should match the `tensor_parallel_size` below.
+see `tests/distributed/test_torchrun_example.py` for the unit test.
+"""
+
+from vllm import LLM, SamplingParams
+
+# Create prompts, the same across all ranks
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+# Create sampling parameters, the same across all ranks
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+# Use `distributed_executor_backend="external_launcher"` so that
+# this llm engine/instance only creates one worker.
+llm = LLM(
+    model="facebook/opt-125m",
+    tensor_parallel_size=2,
+    distributed_executor_backend="external_launcher",
+)
+
+outputs = llm.generate(prompts, sampling_params)
+
+# all ranks will have the same outputs
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, "
+          f"Generated text: {generated_text!r}")
+"""
+Further tips:
+
+1. to communicate control messages across all ranks, use the cpu group,
+a PyTorch ProcessGroup with GLOO backend.
+
+```python
+from vllm.distributed.parallel_state import get_world_group
+cpu_group = get_world_group().cpu_group
+torch_rank = dist.get_rank(group=cpu_group)
+if torch_rank == 0:
+    # do something for rank 0, e.g. saving the results to disk.
+```
+
+2. to communicate data across all ranks, use the model's device group,
+a PyTorch ProcessGroup with NCCL backend.
+```python
+from vllm.distributed.parallel_state import get_world_group
+device_group = get_world_group().device_group
+```
+
+3. to access the model directly in every rank, use the following code:
+```python
+llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
+```
+"""

tests/distributed/test_torchrun_example.py

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+# unit test for `examples/offline_inference/torchrun_example.py`
+
+import random
+
+import torch.distributed as dist
+
+from vllm import LLM, SamplingParams
+from vllm.distributed.parallel_state import get_world_group
+
+# Create prompts
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+# set different `gpu_memory_utilization` and `swap_space` for different ranks,
+# to test if all ranks agree on the same kv cache configuration.
+llm = LLM(model="facebook/opt-125m",
+          tensor_parallel_size=2,
+          distributed_executor_backend="external_launcher",
+          gpu_memory_utilization=random.uniform(0.7, 0.9),
+          swap_space=random.randint(1, 4))
+
+outputs = llm.generate(prompts, sampling_params)
+
+cpu_group = get_world_group().cpu_group
+
+torch_rank = dist.get_rank(group=cpu_group)
+
+
+def test_consistent_across_ranks(obj):
+    if torch_rank == 0:
+        dist.broadcast_object_list([obj], src=0, group=cpu_group)
+    else:
+        container = [None]
+        dist.broadcast_object_list(container, src=0, group=cpu_group)
+        assert container[0] == obj
+
+
+test_consistent_across_ranks(
+    llm.llm_engine.vllm_config.cache_config.num_cpu_blocks)
+test_consistent_across_ranks(
+    llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
+
+# all ranks should have the same outputs
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    test_consistent_across_ranks(prompt)
+    test_consistent_across_ranks(generated_text)
+    print(f"Rank {torch_rank}, Prompt: {prompt!r}, "
+          f"Generated text: {generated_text!r}")

tests/engine/test_multiproc_workers.py

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
             # simulate error case
             raise worker_input

-        return self.rank, input
+        return self.rpc_rank, input


 def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:

vllm/config.py

Lines changed: 4 additions & 3 deletions

@@ -1338,14 +1338,15 @@ def _verify_args(self) -> None:
         from vllm.executor.executor_base import ExecutorBase
         from vllm.platforms import current_platform
         if self.distributed_executor_backend not in (
-                "ray", "mp", "uni", None) and not (isinstance(
+                "ray", "mp", "uni",
+                "external_launcher", None) and not (isinstance(
                     self.distributed_executor_backend, type) and issubclass(
                         self.distributed_executor_backend, ExecutorBase)):
             raise ValueError(
                 "Unrecognized distributed executor backend "
                 f"{self.distributed_executor_backend}. Supported "
-                "values are 'ray', 'mp' 'uni', or custom ExecutorBase"
-                " subclass.")
+                "values are 'ray', 'mp' 'uni', 'external_launcher' or"
+                " custom ExecutorBase subclass.")
         if self.use_ray:
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()

vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion

@@ -388,7 +388,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         # Parallel arguments
         parser.add_argument(
             '--distributed-executor-backend',
-            choices=['ray', 'mp'],
+            choices=['ray', 'mp', 'uni', 'external_launcher'],
             default=EngineArgs.distributed_executor_backend,
             help='Backend to use for distributed model '
             'workers, either "ray" or "mp" (multiprocessing). If the product '
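
With the extra choices in place, the backend can also be selected programmatically. A minimal sketch of assumed usage, mirroring the offline example above rather than anything in this diff:

```python
# Assumed usage sketch: choosing the new backend through EngineArgs.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",
    tensor_parallel_size=2,
    distributed_executor_backend="external_launcher",
)
```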

vllm/engine/llm_engine.py

Lines changed: 5 additions & 0 deletions

@@ -457,6 +457,11 @@ def _get_executor_cls(cls,
             # JAX-style, single-process, multi-device executor.
             from vllm.executor.uniproc_executor import UniProcExecutor
             executor_class = UniProcExecutor
+        elif distributed_executor_backend == "external_launcher":
+            # executor with external launcher
+            from vllm.executor.uniproc_executor import (  # noqa
+                ExecutorWithExternalLauncher)
+            executor_class = ExecutorWithExternalLauncher
         else:
             from vllm.executor.uniproc_executor import UniProcExecutor
             executor_class = UniProcExecutor

vllm/executor/ray_distributed_executor.py

Lines changed: 3 additions & 3 deletions

@@ -172,7 +172,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                     scheduling_strategy=scheduling_strategy,
                     **ray_remote_kwargs,
                 )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                           rank=rank)
+                                           rpc_rank=rank)
             else:
                 worker = ray.remote(
                     num_cpus=0,
@@ -181,7 +181,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                     scheduling_strategy=scheduling_strategy,
                     **ray_remote_kwargs,
                 )(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
-                                           rank=rank)
+                                           rpc_rank=rank)
             worker_metadata.append(
                 RayWorkerMetaData(worker=worker, created_rank=rank))
             rank += 1
@@ -204,7 +204,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             # as the resource holder for the driver process.
             self.driver_dummy_worker = worker
             self.driver_worker = RayWorkerWrapper(
-                vllm_config=self.vllm_config, rank=0)
+                vllm_config=self.vllm_config, rpc_rank=0)
             worker_metadata.pop(i)
             break
vllm/executor/uniproc_executor.py

Lines changed: 80 additions & 1 deletion

@@ -1,5 +1,10 @@
+import os
 from typing import Any, Dict, List, Optional, Tuple

+import torch
+import torch.distributed as dist
+
+import vllm.envs as envs
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
@@ -16,7 +21,7 @@ def _init_executor(self) -> None:
         """Initialize the worker and load the model.
         """
         self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
-                                               rank=0)
+                                               rpc_rank=0)
         distributed_init_method = get_distributed_init_method(
             get_ip(), get_open_port())
         local_rank = 0
@@ -55,3 +60,77 @@ def check_health(self) -> None:


 UniProcExecutorAsync = UniProcExecutor
+
+
+class ExecutorWithExternalLauncher(UniProcExecutor):
+    """An executor that uses external launchers to launch engines,
+    specially designed for torchrun-compatible launchers, for
+    offline inference with tensor parallelism.
+
+    see https://github.com/vllm-project/vllm/issues/11400 for
+    the motivation, and examples/offline_inference/torchrun_example.py
+    for the usage example.
+
+    The key idea: although it is tensor-parallel inference, we only
+    create one worker per executor, users will launch multiple
+    engines with torchrun-compatible launchers, and all these engines
+    work together to process the same prompts. When scheduling is
+    deterministic, all the engines will generate the same outputs,
+    and they don't need to synchronize the states with each other.
+    """
+    uses_ray: bool = False
+
+    def _init_executor(self) -> None:
+        """Initialize the worker and load the model.
+        """
+        assert self.vllm_config.parallel_config.pipeline_parallel_size == 1, \
+            ("ExecutorWithExternalLauncher does not "
+             "support pipeline parallelism.")
+        assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
+            ("ExecutorWithExternalLauncher needs deterministic "
+             "execution, so it"
+             "does not support delay_factor in scheduling")
+        assert not envs.VLLM_USE_V1, \
+            ("V1 architecture cannot guarantee deterministic execution, "
+             "so it is not supported in ExecutorWithExternalLauncher.")
+        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
+                                               rpc_rank=0)
+        # engines are launched in torchrun-compatible launchers
+        # so we can use the env:// method.
+        # required env vars:
+        # - RANK
+        # - MASTER_ADDR
+        # - MASTER_PORT
+        distributed_init_method = "env://"
+        rank = int(os.environ["RANK"])
+        local_rank = rank
+        is_driver_worker = True
+        kwargs = dict(
+            vllm_config=self.vllm_config,
+            local_rank=local_rank,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            is_driver_worker=is_driver_worker,
+        )
+        self.collective_rpc("init_worker", args=([kwargs], ))
+        self.collective_rpc("init_device")
+        self.collective_rpc("load_model")
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """
+        Determine the number of available KV blocks.
+        Add an additional all_reduce to get the min across all ranks.
+        Note that even if we have the same `gpu_memory_utilization` and
+        `swap_space`, the available memory in every rank might still
+        differ because NCCL can take different amounts of memory in
+        different ranks. Therefore, it is necessary to test if all ranks
+        agree on the same KV cache configuration.
+        """
+        a, b = super().determine_num_available_blocks()
+        from vllm.distributed.parallel_state import get_world_group
+        cpu_group = get_world_group().cpu_group
+        a_tensor = torch.tensor([a], device="cpu", dtype=torch.int64)
+        b_tensor = torch.tensor([b], device="cpu", dtype=torch.int64)
+        dist.all_reduce(a_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        dist.all_reduce(b_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        return a_tensor.item(), b_tensor.item()
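
ExecutorWithExternalLauncher leans on torchrun's standard rendezvous: torchrun exports RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT, so the worker can initialize with "env://" and no explicit address. The following standalone sketch, not vLLM code, illustrates that mechanism plus the min-reduce idea from determine_num_available_blocks; run it with `torchrun --nproc-per-node=2 sketch.py`, where the filename is arbitrary.

```python
# Standalone illustration of the "env://" rendezvous and the min-reduce
# pattern used above (not vLLM code).
import os

import torch
import torch.distributed as dist

if __name__ == "__main__":
    # torchrun sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT, so "env://"
    # needs no explicit address here.
    dist.init_process_group(backend="gloo", init_method="env://")
    rank = int(os.environ["RANK"])
    # Every rank contributes a value; after the MIN all-reduce, all ranks
    # agree on the smallest one, just like the KV-block negotiation above.
    value = torch.tensor([rank + 10], dtype=torch.int64)
    dist.all_reduce(value, op=dist.ReduceOp.MIN)
    print(f"rank {rank}: agreed min = {value.item()}")
    dist.destroy_process_group()
```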

vllm/lora/layers.py

Lines changed: 2 additions & 2 deletions

@@ -940,8 +940,8 @@ def soft_cap(self):
         return self.base_layer.soft_cap

     @property
-    def use_gather(self):
-        return self.base_layer.use_gather
+    def use_all_gather(self):
+        return self.base_layer.use_all_gather

     @property
     def org_vocab_size(self):

vllm/model_executor/layers/logits_processor.py

Lines changed: 10 additions & 6 deletions

@@ -6,6 +6,7 @@
 import torch.nn as nn

 import vllm.envs as envs
+from vllm.config import get_current_vllm_config
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_gather)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -44,8 +45,10 @@ def __init__(self,
         self.soft_cap = soft_cap
         # Whether to use gather or all-gather to gather the logits.

-        self.use_gather = not current_platform.is_tpu(
-        ) and not envs.VLLM_USE_V1
+        parallel_config = get_current_vllm_config().parallel_config
+        self.use_all_gather = current_platform.is_tpu() \
+            or envs.VLLM_USE_V1 \
+            or parallel_config.distributed_executor_backend == "external_launcher"  # noqa

     def forward(
         self,
@@ -88,16 +91,17 @@ def _get_logits(
         logits = lm_head.linear_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
-        if self.use_gather:
-            # None may be returned for rank > 0
-            logits = tensor_model_parallel_gather(logits)
-        else:
+
+        if self.use_all_gather:
             # Gather is not supported for some devices such as TPUs.
             # Use all-gather instead.
             # NOTE(woosuk): Here, the outputs of every device should not be None
             # because XLA requires strict SPMD among all devices. Every device
             # should execute the same operations after gathering the logits.
             logits = tensor_model_parallel_all_gather(logits)
+        else:
+            # None may be returned for rank > 0
+            logits = tensor_model_parallel_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
             logits = logits[..., :self.org_vocab_size]
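
The switch from `use_gather` to `use_all_gather` matters for SPMD-style execution: with a plain gather only the destination rank ends up with the full logits, while all-gather leaves a complete copy on every rank, which each torchrun-launched engine needs in order to sample its own (identical) tokens. The sketch below contrasts the two collectives with raw `torch.distributed` calls rather than vLLM's tensor-parallel helpers, assuming an initialized process group and equally shaped per-rank shards.

```python
# Standalone contrast of gather vs. all-gather (illustrative only; vLLM uses
# tensor_model_parallel_gather / tensor_model_parallel_all_gather instead).
import torch
import torch.distributed as dist


def collect_logits(shard: torch.Tensor, use_all_gather: bool):
    world_size = dist.get_world_size()
    if use_all_gather:
        # Every rank receives every shard and can reassemble the full logits.
        parts = [torch.empty_like(shard) for _ in range(world_size)]
        dist.all_gather(parts, shard)
        return torch.cat(parts, dim=-1)
    # Plain gather: only the destination rank gets the shards; others get None.
    parts = ([torch.empty_like(shard) for _ in range(world_size)]
             if dist.get_rank() == 0 else None)
    dist.gather(shard, gather_list=parts, dst=0)
    return torch.cat(parts, dim=-1) if parts is not None else None
```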

vllm/v1/executor/multiproc_executor.py

Lines changed: 1 addition & 1 deletion

@@ -246,7 +246,7 @@ def __init__(
         ready_path: str,
     ):
         self.rank = rank
-        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rank=rank)
+        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
         # TODO: move `init_worker` to executor level as a collective rpc call
         all_kwargs: List[Dict] = [
             {} for _ in range(vllm_config.parallel_config.world_size)

vllm/worker/worker.py

Lines changed: 0 additions & 3 deletions

@@ -55,9 +55,6 @@ def __init__(
         self.rank = rank
         self.distributed_init_method = distributed_init_method
         self.is_driver_worker = is_driver_worker
-        if is_driver_worker:
-            assert rank % self.parallel_config.tensor_parallel_size == 0, \
-                "Driver worker should be rank 0 of tensor parallel group."
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
             from vllm.utils import init_cached_hf_modules

0 commit comments
