Skip to content

[Misc] refactor: simplify EngineCoreClient.make_async_mp_client in AsyncLLM #18817

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient,
RayDPClient)
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
Expand Down Expand Up @@ -120,15 +119,8 @@ def __init__(
log_stats=self.log_stats)

# EngineCore (starts the engine in background process).
core_client_class: type[AsyncMPClient]
if vllm_config.parallel_config.data_parallel_size == 1:
core_client_class = AsyncMPClient
elif vllm_config.parallel_config.data_parallel_backend == "ray":
core_client_class = RayDPClient
else:
core_client_class = DPAsyncMPClient

self.engine_core = core_client_class(

self.engine_core = EngineCoreClient.make_async_mp_client(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use the make_async_mp_client method instead of importing Client, which can reduce code maintenance costs

vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
Expand Down
25 changes: 19 additions & 6 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,31 @@ def make_client(
"is not currently supported.")

if multiprocess_mode and asyncio_mode:
if vllm_config.parallel_config.data_parallel_size > 1:
if vllm_config.parallel_config.data_parallel_backend == "ray":
return RayDPClient(vllm_config, executor_class, log_stats)
return DPAsyncMPClient(vllm_config, executor_class, log_stats)

return AsyncMPClient(vllm_config, executor_class, log_stats)
return EngineCoreClient.make_async_mp_client(
vllm_config, executor_class, log_stats)

if multiprocess_mode and not asyncio_mode:
return SyncMPClient(vllm_config, executor_class, log_stats)

return InprocClient(vllm_config, executor_class, log_stats)

@staticmethod
def make_async_mp_client(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> "MPClient":
if vllm_config.parallel_config.data_parallel_size > 1:
if vllm_config.parallel_config.data_parallel_backend == "ray":
return RayDPClient(vllm_config, executor_class, log_stats,
client_addresses, client_index)
return DPAsyncMPClient(vllm_config, executor_class, log_stats,
client_addresses, client_index)
return AsyncMPClient(vllm_config, executor_class, log_stats,
client_addresses, client_index)

@abstractmethod
def shutdown(self):
...
Expand Down