Skip to content

Commit 0666250

Browse files
youkaichao and lulmer
authored and committed
[bugfix] respect distributed_executor_backend in world_size=1 (vllm-project#12934)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent 61525b4 commit 0666250

File tree

4 files changed

+53
-32
lines changed

4 files changed

+53
-32
lines changed

tests/engine/test_custom_executor.py renamed to tests/engine/test_executor.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def test_custom_executor(model, tmp_path):
5555
engine_args = EngineArgs(
5656
model=model,
5757
distributed_executor_backend=CustomUniExecutor,
58+
enforce_eager=True, # reduce test time
5859
)
5960
engine = LLMEngine.from_engine_args(engine_args)
6061
sampling_params = SamplingParams(max_tokens=1)
@@ -75,7 +76,10 @@ def test_custom_executor_async(model, tmp_path):
7576
assert not os.path.exists(".marker")
7677

7778
engine_args = AsyncEngineArgs(
78-
model=model, distributed_executor_backend=CustomUniExecutorAsync)
79+
model=model,
80+
distributed_executor_backend=CustomUniExecutorAsync,
81+
enforce_eager=True, # reduce test time
82+
)
7983
engine = AsyncLLMEngine.from_engine_args(engine_args)
8084
sampling_params = SamplingParams(max_tokens=1)
8185

@@ -89,3 +93,18 @@ async def t():
8993
assert os.path.exists(".marker")
9094
finally:
9195
os.chdir(cwd)
96+
97+
98+
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
99+
def test_respect_ray(model):
100+
# even for TP=1 and PP=1,
101+
# if users specify ray, we should use ray.
102+
# users might do this if they want to manage the
103+
# resources using ray.
104+
engine_args = EngineArgs(
105+
model=model,
106+
distributed_executor_backend="ray",
107+
enforce_eager=True, # reduce test time
108+
)
109+
engine = LLMEngine.from_engine_args(engine_args)
110+
assert engine.model_executor.uses_ray

vllm/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,6 +1401,9 @@ def __post_init__(self) -> None:
14011401
logger.info("Defaulting to use %s for distributed inference",
14021402
backend)
14031403

1404+
if self.distributed_executor_backend is None and self.world_size == 1:
1405+
self.distributed_executor_backend = "uni"
1406+
14041407
self._verify_args()
14051408

14061409
@property

vllm/engine/llm_engine.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ def _initialize_kv_caches(self) -> None:
434434
@classmethod
435435
def _get_executor_cls(cls,
436436
engine_config: VllmConfig) -> Type[ExecutorBase]:
437+
# distributed_executor_backend must be set in VllmConfig.__post_init__
437438
distributed_executor_backend = (
438439
engine_config.parallel_config.distributed_executor_backend)
439440
# Initialize the cluster and specify the executor class.
@@ -443,30 +444,29 @@ def _get_executor_cls(cls,
443444
"distributed_executor_backend must be a subclass of "
444445
f"ExecutorBase. Got {distributed_executor_backend}.")
445446
executor_class = distributed_executor_backend
446-
elif engine_config.parallel_config.world_size > 1:
447-
if distributed_executor_backend == "ray":
448-
from vllm.executor.ray_distributed_executor import (
449-
RayDistributedExecutor)
450-
executor_class = RayDistributedExecutor
451-
elif distributed_executor_backend == "mp":
452-
from vllm.executor.mp_distributed_executor import (
453-
MultiprocessingDistributedExecutor)
454-
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
455-
"multiprocessing distributed executor backend does not "
456-
"support VLLM_USE_RAY_SPMD_WORKER=1")
457-
executor_class = MultiprocessingDistributedExecutor
458-
elif distributed_executor_backend == "uni":
459-
# JAX-style, single-process, multi-device executor.
460-
from vllm.executor.uniproc_executor import UniProcExecutor
461-
executor_class = UniProcExecutor
462-
elif distributed_executor_backend == "external_launcher":
463-
# executor with external launcher
464-
from vllm.executor.uniproc_executor import ( # noqa
465-
ExecutorWithExternalLauncher)
466-
executor_class = ExecutorWithExternalLauncher
467-
else:
447+
elif distributed_executor_backend == "ray":
448+
from vllm.executor.ray_distributed_executor import (
449+
RayDistributedExecutor)
450+
executor_class = RayDistributedExecutor
451+
elif distributed_executor_backend == "mp":
452+
from vllm.executor.mp_distributed_executor import (
453+
MultiprocessingDistributedExecutor)
454+
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
455+
"multiprocessing distributed executor backend does not "
456+
"support VLLM_USE_RAY_SPMD_WORKER=1")
457+
executor_class = MultiprocessingDistributedExecutor
458+
elif distributed_executor_backend == "uni":
459+
# JAX-style, single-process, multi-device executor.
468460
from vllm.executor.uniproc_executor import UniProcExecutor
469461
executor_class = UniProcExecutor
462+
elif distributed_executor_backend == "external_launcher":
463+
# executor with external launcher
464+
from vllm.executor.uniproc_executor import ( # noqa
465+
ExecutorWithExternalLauncher)
466+
executor_class = ExecutorWithExternalLauncher
467+
else:
468+
raise ValueError("unrecognized distributed_executor_backend: "
469+
f"{distributed_executor_backend}")
470470
return executor_class
471471

472472
@classmethod

vllm/v1/executor/abstract.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,14 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
2525
parallel_config = vllm_config.parallel_config
2626
distributed_executor_backend = (
2727
parallel_config.distributed_executor_backend)
28-
if distributed_executor_backend is None:
29-
# If the user does not specify the distributed executor backend,
30-
# we will choose the backend based on the world size.
31-
if parallel_config.world_size > 1:
32-
distributed_executor_backend = "mp"
33-
else:
34-
distributed_executor_backend = "uni"
35-
36-
if distributed_executor_backend == "ray":
28+
# distributed_executor_backend must be set in VllmConfig.__post_init__
29+
if isinstance(distributed_executor_backend, type):
30+
if not issubclass(distributed_executor_backend, ExecutorBase):
31+
raise TypeError(
32+
"distributed_executor_backend must be a subclass of "
33+
f"ExecutorBase. Got {distributed_executor_backend}.")
34+
executor_class = distributed_executor_backend
35+
elif distributed_executor_backend == "ray":
3736
executor_class = RayDistributedExecutor
3837
elif distributed_executor_backend == "mp":
3938
from vllm.v1.executor.multiproc_executor import MultiprocExecutor

0 commit comments

Comments (0)