Skip to content

Commit 6fd4a08

Browse files
ruisearch42 authored and rasmith committed
[V1] Refactor get_executor_cls (vllm-project#11754)
1 parent d7e85cf commit 6fd4a08

File tree

5 files changed

+26
-46
lines changed

5 files changed

+26
-46
lines changed

tests/v1/engine/test_engine_core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from vllm.engine.arg_utils import EngineArgs
99
from vllm.platforms import current_platform
1010
from vllm.v1.engine import EngineCoreRequest
11-
from vllm.v1.engine.async_llm import AsyncLLM
1211
from vllm.v1.engine.core import EngineCore
12+
from vllm.v1.executor.abstract import Executor
1313

1414
if not current_platform.is_cuda():
1515
pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -43,7 +43,7 @@ def test_engine_core(monkeypatch):
4343
"""Setup the EngineCore."""
4444
engine_args = EngineArgs(model=MODEL_NAME)
4545
vllm_config = engine_args.create_engine_config()
46-
executor_class = AsyncLLM._get_executor_cls(vllm_config)
46+
executor_class = Executor.get_class(vllm_config)
4747

4848
engine_core = EngineCore(vllm_config=vllm_config,
4949
executor_class=executor_class)
@@ -149,7 +149,7 @@ def test_engine_core_advanced_sampling(monkeypatch):
149149
"""Setup the EngineCore."""
150150
engine_args = EngineArgs(model=MODEL_NAME)
151151
vllm_config = engine_args.create_engine_config()
152-
executor_class = AsyncLLM._get_executor_cls(vllm_config)
152+
executor_class = Executor.get_class(vllm_config)
153153

154154
engine_core = EngineCore(vllm_config=vllm_config,
155155
executor_class=executor_class)

tests/v1/engine/test_engine_core_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from vllm.platforms import current_platform
1212
from vllm.usage.usage_lib import UsageContext
1313
from vllm.v1.engine import EngineCoreRequest
14-
from vllm.v1.engine.async_llm import AsyncLLM
1514
from vllm.v1.engine.core_client import EngineCoreClient
15+
from vllm.v1.executor.abstract import Executor
1616

1717
if not current_platform.is_cuda():
1818
pytest.skip(reason="V1 currently only supported on CUDA.",
@@ -84,7 +84,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
8484
engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
8585
vllm_config = engine_args.create_engine_config(
8686
UsageContext.UNKNOWN_CONTEXT)
87-
executor_class = AsyncLLM._get_executor_cls(vllm_config)
87+
executor_class = Executor.get_class(vllm_config)
8888
client = EngineCoreClient.make_client(
8989
multiprocess_mode=multiprocessing_mode,
9090
asyncio_mode=False,
@@ -152,7 +152,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
152152
engine_args = EngineArgs(model=MODEL_NAME)
153153
vllm_config = engine_args.create_engine_config(
154154
usage_context=UsageContext.UNKNOWN_CONTEXT)
155-
executor_class = AsyncLLM._get_executor_cls(vllm_config)
155+
executor_class = Executor.get_class(vllm_config)
156156
client = EngineCoreClient.make_client(
157157
multiprocess_mode=True,
158158
asyncio_mode=True,

vllm/v1/engine/async_llm.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from vllm.v1.engine.detokenizer import Detokenizer
2323
from vllm.v1.engine.processor import Processor
2424
from vllm.v1.executor.abstract import Executor
25-
from vllm.v1.executor.ray_utils import initialize_ray_cluster
2625

2726
logger = init_logger(__name__)
2827

@@ -105,7 +104,7 @@ def from_engine_args(
105104
else:
106105
vllm_config = engine_config
107106

108-
executor_class = cls._get_executor_cls(vllm_config)
107+
executor_class = Executor.get_class(vllm_config)
109108

110109
# Create the AsyncLLM.
111110
return cls(
@@ -127,24 +126,6 @@ def shutdown(self):
127126
if handler := getattr(self, "output_handler", None):
128127
handler.cancel()
129128

130-
@classmethod
131-
def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
132-
executor_class: Type[Executor]
133-
distributed_executor_backend = (
134-
vllm_config.parallel_config.distributed_executor_backend)
135-
if distributed_executor_backend == "ray":
136-
initialize_ray_cluster(vllm_config.parallel_config)
137-
from vllm.v1.executor.ray_executor import RayExecutor
138-
executor_class = RayExecutor
139-
elif distributed_executor_backend == "mp":
140-
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
141-
executor_class = MultiprocExecutor
142-
else:
143-
assert (distributed_executor_backend is None)
144-
from vllm.v1.executor.uniproc_executor import UniprocExecutor
145-
executor_class = UniprocExecutor
146-
return executor_class
147-
148129
async def add_request(
149130
self,
150131
request_id: str,

vllm/v1/engine/llm_engine.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def from_engine_args(
8989

9090
# Create the engine configs.
9191
vllm_config = engine_args.create_engine_config(usage_context)
92-
executor_class = cls._get_executor_cls(vllm_config)
92+
executor_class = Executor.get_class(vllm_config)
9393

9494
if VLLM_ENABLE_V1_MULTIPROCESSING:
9595
logger.debug("Enabling multiprocessing for LLMEngine.")
@@ -103,24 +103,6 @@ def from_engine_args(
103103
stat_loggers=stat_loggers,
104104
multiprocess_mode=enable_multiprocessing)
105105

106-
@classmethod
107-
def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
108-
executor_class: Type[Executor]
109-
distributed_executor_backend = (
110-
vllm_config.parallel_config.distributed_executor_backend)
111-
if distributed_executor_backend == "ray":
112-
from vllm.v1.executor.ray_executor import RayExecutor
113-
executor_class = RayExecutor
114-
elif distributed_executor_backend == "mp":
115-
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
116-
executor_class = MultiprocExecutor
117-
else:
118-
assert (distributed_executor_backend is None)
119-
from vllm.v1.executor.uniproc_executor import UniprocExecutor
120-
executor_class = UniprocExecutor
121-
122-
return executor_class
123-
124106
def get_num_unfinished_requests(self) -> int:
125107
return self.detokenizer.get_num_unfinished_requests()
126108

vllm/v1/executor/abstract.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import Tuple
2+
from typing import Tuple, Type
33

44
from vllm.config import VllmConfig
55
from vllm.v1.outputs import ModelRunnerOutput
@@ -8,6 +8,23 @@
88
class Executor(ABC):
99
"""Abstract class for executors."""
1010

11+
@staticmethod
12+
def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
13+
executor_class: Type[Executor]
14+
distributed_executor_backend = (
15+
vllm_config.parallel_config.distributed_executor_backend)
16+
if distributed_executor_backend == "ray":
17+
from vllm.v1.executor.ray_executor import RayExecutor
18+
executor_class = RayExecutor
19+
elif distributed_executor_backend == "mp":
20+
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
21+
executor_class = MultiprocExecutor
22+
else:
23+
assert (distributed_executor_backend is None)
24+
from vllm.v1.executor.uniproc_executor import UniprocExecutor
25+
executor_class = UniprocExecutor
26+
return executor_class
27+
1128
@abstractmethod
1229
def __init__(self, vllm_config: VllmConfig) -> None:
1330
raise NotImplementedError

0 commit comments

Comments (0)