[V1] DP scale-out (2/N): Decouple engine process management and comms #15977


Merged: 24 commits, May 13, 2025
Commits
8802521
[V1] DP scale-out (2/N): Decouple engine process management and comms
njhill Apr 2, 2025
e869380
Headless mode
njhill Apr 3, 2025
1ca3d15
Wire data_parallel_address arg
njhill Apr 4, 2025
a551183
Some code cleanup
njhill Apr 4, 2025
a662169
Fix offline DP compatibility
njhill Apr 4, 2025
b29dcf4
Merge remote-tracking branch 'refs/remotes/origin/main' into decouple…
njhill Apr 7, 2025
8126f72
Address some review comments
njhill Apr 7, 2025
8fdc6f5
Address other minor review comments
njhill Apr 7, 2025
9c90ad4
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill Apr 17, 2025
80f9c98
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill Apr 17, 2025
efa8ad8
Fix merge error, address @russellb's ipv6 review comment
njhill Apr 17, 2025
30ab14b
Handle ipv6 URIs in all places
njhill Apr 18, 2025
acc5af3
Fix head node with no engines, don't require dp size on other nodes
njhill Apr 19, 2025
1649d7d
Merge remote-tracking branch 'refs/remotes/origin/main' into decouple…
njhill Apr 23, 2025
4fbf90e
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill Apr 23, 2025
86a0453
Merge remote-tracking branch 'refs/remotes/origin/main' into decouple…
njhill Apr 26, 2025
e70545c
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill Apr 27, 2025
24b2e1e
Merge remote-tracking branch 'refs/remotes/origin/main' into decouple…
njhill May 1, 2025
f7a909e
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill May 11, 2025
42c30bf
Fix test_startup_failure
njhill May 12, 2025
3904d10
Fix mock config related test failure
njhill May 12, 2025
cece58a
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill May 12, 2025
02f7263
Merge remote-tracking branch 'origin/main' into decouple-engines
njhill May 12, 2025
e1400f7
Merge remote-tracking branch 'refs/remotes/origin/main' into decouple…
njhill May 13, 2025
2 changes: 1 addition & 1 deletion tests/async_engine/test_async_llm_engine.py
@@ -41,7 +41,7 @@ def __init__(self):
         self.abort_request_calls = 0
         self.request_id = None
         # Ugly, remove dependency when possible
-        self.parallel_config = ParallelConfig(1, 1, False)
+        self.parallel_config = ParallelConfig()
         self.model_config = MockModelConfig()

     async def step_async(self, virtual_engine):
15 changes: 8 additions & 7 deletions tests/v1/engine/test_engine_core_client.py
@@ -18,9 +18,10 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
-from vllm.v1.engine.core_client import (AsyncMPClient, CoreEngine,
-                                        EngineCoreClient, SyncMPClient)
+from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
+                                        SyncMPClient)
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.utils import CoreEngineProcManager

 from ...distributed.conftest import MockSubscriber
 from ...utils import create_new_process_for_each_test
@@ -348,13 +349,13 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):

     # Monkey-patch to extract core process pid while it's starting.
     core_proc_pid = [None]
-    ce_ctor = CoreEngine.__init__
+    cepm_ctor = CoreEngineProcManager.__init__

-    def patched_ce_ctor(self, *args, **kwargs):
-        ce_ctor(self, *args, **kwargs)
-        core_proc_pid[0] = self.proc_handle.proc.pid
+    def patched_cepm_ctor(self: CoreEngineProcManager, *args, **kwargs):
+        cepm_ctor(self, *args, **kwargs)
+        core_proc_pid[0] = self.processes[0].pid

-    m.setattr(CoreEngine, "__init__", patched_ce_ctor)
+    m.setattr(CoreEngineProcManager, "__init__", patched_cepm_ctor)

     t = time.time()
     engine_args = EngineArgs(model=MODEL_NAME)
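The updated test wraps CoreEngineProcManager.__init__ so the test body can record the first child process's pid during startup. For readers unfamiliar with the trick, here is a minimal self-contained sketch of the same constructor-wrapping pattern; the Manager class is a toy stand-in, not vLLM's real CoreEngineProcManager:

class Manager:
    # Toy stand-in: the real class spawns engine processes and
    # stores the handles in self.processes.
    def __init__(self) -> None:
        class Proc:
            pid = 1234
        self.processes = [Proc()]


def test_capture_pid(monkeypatch) -> None:
    captured = [None]
    orig_ctor = Manager.__init__

    def patched_ctor(self, *args, **kwargs):
        orig_ctor(self, *args, **kwargs)
        # Grab state the moment the object finishes constructing.
        captured[0] = self.processes[0].pid

    monkeypatch.setattr(Manager, "__init__", patched_ctor)
    Manager()
    assert captured[0] == 1234

pytest's monkeypatch fixture restores the original constructor when the test ends, so the patch cannot leak into other tests.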
41 changes: 19 additions & 22 deletions vllm/config.py
@@ -1668,25 +1668,17 @@ class ParallelConfig:
     data_parallel_size: int = 1
     """Number of data parallel groups. MoE layers will be sharded according to
     the product of the tensor parallel size and data parallel size."""
+    data_parallel_size_local: int = 1
+    """Number of local data parallel groups."""
     data_parallel_rank: int = 0
     """Rank of the data parallel group."""
-    _data_parallel_rank_local: Optional[int] = field(default=None, init=False)
-    """Private field to store the local rank of the data parallel group."""
-
-    @property
-    def data_parallel_rank_local(self) -> int:
-        """Local rank of the data parallel group, defaults to global rank."""
-        if self._data_parallel_rank_local is None:
-            return self.data_parallel_rank
-        return self._data_parallel_rank_local
-
-    @data_parallel_rank_local.setter
-    def data_parallel_rank_local(self, value: int) -> None:
-        """Set the local rank of the data parallel group."""
-        self._data_parallel_rank_local = value
-
+    data_parallel_rank_local: Optional[int] = None
+    """Local rank of the data parallel group,
+    set only in SPMD mode."""
     data_parallel_master_ip: str = "127.0.0.1"
     """IP of the data parallel master."""
+    data_parallel_rpc_port: int = 29550
+    """Port for data parallel messaging."""
     data_parallel_master_port: int = 29500
     """Port of the data parallel master."""
     enable_expert_parallel: bool = False
@@ -1734,13 +1726,16 @@ class is dynamically inherited by the worker class. This is used to inject

     world_size: int = field(init=False)
     """world_size is TPxPP, it affects the number of workers we create."""
-    world_size_across_dp: int = field(init=False)
-    """world_size_across_dp is TPxPPxDP, it is the size of the world
-    including data parallelism."""

     rank: int = 0
     """Global rank in distributed setup."""

+    @property
+    def world_size_across_dp(self) -> int:
+        """world_size_across_dp is TPxPPxDP, it is the size of the world
+        including data parallelism."""
+        return self.world_size * self.data_parallel_size
+
     def get_next_dp_init_port(self) -> int:
         """
         We might need to initialize process groups in multiple
@@ -1800,10 +1795,14 @@ def __post_init__(self) -> None:
         self.world_size = self.pipeline_parallel_size * \
             self.tensor_parallel_size

-        if self.data_parallel_size > 1:
+        if self.data_parallel_size_local > self.data_parallel_size:
+            raise ValueError(
+                f"data_parallel_size_local ({self.data_parallel_size_local}) "
+                f"must be <= data_parallel_size ({self.data_parallel_size})")
+
+        if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
             # Data parallel was specified in the engine args.
             self.data_parallel_master_port = get_open_port()
-            # TODO multi-node
         else:
             # Otherwise fall back to env vars (e.g. for offline SPMD case).
             self.data_parallel_size = envs.VLLM_DP_SIZE
@@ -1812,8 +1811,6 @@ def __post_init__(self) -> None:
             self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
             self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

-        self.world_size_across_dp = self.world_size * self.data_parallel_size
-
        if self.distributed_executor_backend == "external_launcher":
             import os
             os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
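Two shapes change here: data_parallel_rank_local becomes a plain Optional field instead of a property/setter pair, and world_size_across_dp becomes a derived property rather than a field assigned in __post_init__, so it cannot go stale when data_parallel_size is later overridden from env vars. A minimal sketch of the derived-property design, using a toy dataclass (ToyParallelConfig is illustrative, not the real ParallelConfig, which computes world_size in __post_init__):

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyParallelConfig:
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    data_parallel_size: int = 1
    data_parallel_size_local: int = 1
    data_parallel_rank_local: Optional[int] = None  # set only in SPMD mode

    @property
    def world_size(self) -> int:
        # TP x PP: the number of workers each engine creates.
        return self.pipeline_parallel_size * self.tensor_parallel_size

    @property
    def world_size_across_dp(self) -> int:
        # TP x PP x DP: derived on demand, so it stays consistent
        # even after data_parallel_size is reassigned.
        return self.world_size * self.data_parallel_size


cfg = ToyParallelConfig(tensor_parallel_size=2, data_parallel_size=4)
assert cfg.world_size == 2 and cfg.world_size_across_dp == 8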
3 changes: 2 additions & 1 deletion vllm/distributed/utils.py
@@ -22,6 +22,7 @@

 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import get_tcp_uri

 logger = init_logger(__name__)

@@ -303,7 +304,7 @@ def stateless_init_torch_distributed_process_group(
     always formed with process 1, 2, ..., 8, and the additional communication
     channel is formed with process 9 and 10.
     """
-    init_method = f"tcp://{host}:{port}"
+    init_method = get_tcp_uri(host, port)
     backend = Backend(backend)  # it is basically string
     timeout = _get_default_timeout(backend)
38 changes: 38 additions & 0 deletions vllm/engine/arg_utils.py
@@ -283,6 +283,9 @@ class EngineArgs:
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
+    data_parallel_size_local: Optional[int] = None
+    data_parallel_address: Optional[str] = None
+    data_parallel_rpc_port: Optional[int] = None
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     max_parallel_loading_workers: Optional[
         int] = ParallelConfig.max_parallel_loading_workers
@@ -596,6 +599,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                     **parallel_kwargs["tensor_parallel_size"])
         parallel_group.add_argument("--data-parallel-size", "-dp",
                                     **parallel_kwargs["data_parallel_size"])
+        parallel_group.add_argument('--data-parallel-size-local',
+                                    '-dpl',
+                                    type=int,
+                                    help='Number of data parallel replicas '
+                                    'to run on this node.')
+        parallel_group.add_argument('--data-parallel-address',
+                                    '-dpa',
+                                    type=str,
+                                    help='Address of data parallel cluster '
+                                    'head-node.')
+        parallel_group.add_argument('--data-parallel-rpc-port',
+                                    '-dpp',
+                                    type=int,
+                                    help='Port for data parallel RPC '
+                                    'communication.')
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])
@@ -1019,10 +1037,30 @@ def create_engine_config(
             # but we should not do this here.
             placement_group = ray.util.get_current_placement_group()

+        # Local DP size defaults to global DP size if not set.
+        data_parallel_size_local = self.data_parallel_size if (
+            self.data_parallel_size_local
+            is None) else self.data_parallel_size_local
+
+        # DP address, used in multi-node case for torch distributed group
+        # and ZMQ sockets.
+        data_parallel_address = self.data_parallel_address if (
+            self.data_parallel_address
+            is not None) else ParallelConfig.data_parallel_master_ip
+
+        # This port is only used when there are remote data parallel engines,
+        # otherwise the local IPC transport is used.
+        data_parallel_rpc_port = self.data_parallel_rpc_port if (
+            self.data_parallel_rpc_port
+            is not None) else ParallelConfig.data_parallel_rpc_port
+
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
             data_parallel_size=self.data_parallel_size,
+            data_parallel_size_local=data_parallel_size_local,
+            data_parallel_master_ip=data_parallel_address,
+            data_parallel_rpc_port=data_parallel_rpc_port,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
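The three new engine args deliberately default to None rather than to the ParallelConfig values, so create_engine_config can tell "unset" apart from an explicit value before falling back. A standalone sketch of that resolution logic; resolve_dp_args is a hypothetical helper, and the constants mirror the ParallelConfig defaults shown earlier in this diff:

from typing import Optional

# Defaults taken from the ParallelConfig fields in this PR.
DEFAULT_DP_MASTER_IP = "127.0.0.1"
DEFAULT_DP_RPC_PORT = 29550


def resolve_dp_args(
    data_parallel_size: int,
    data_parallel_size_local: Optional[int],
    data_parallel_address: Optional[str],
    data_parallel_rpc_port: Optional[int],
) -> tuple[int, str, int]:
    # Local DP size defaults to the global DP size (single-node case).
    size_local = (data_parallel_size if data_parallel_size_local is None
                  else data_parallel_size_local)
    address = (DEFAULT_DP_MASTER_IP if data_parallel_address is None
               else data_parallel_address)
    rpc_port = (DEFAULT_DP_RPC_PORT if data_parallel_rpc_port is None
                else data_parallel_rpc_port)
    return size_local, address, rpc_port


# Single node: -dp 4 with nothing else set.
assert resolve_dp_args(4, None, None, None) == (4, "127.0.0.1", 29550)
# Head node of a two-node deployment: -dp 4 -dpl 2 -dpa 10.0.0.1 (hypothetical values).
assert resolve_dp_args(4, 2, "10.0.0.1", None) == (2, "10.0.0.1", 29550)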
81 changes: 79 additions & 2 deletions vllm/entrypoints/cli/serve.py
@@ -1,14 +1,24 @@
 # SPDX-License-Identifier: Apache-2.0

 import argparse
+import signal

 import uvloop

+import vllm.envs as envs
+from vllm import AsyncEngineArgs
 from vllm.entrypoints.cli.types import CLISubcommand
 from vllm.entrypoints.openai.api_server import run_server
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
-from vllm.utils import FlexibleArgumentParser
+from vllm.logger import init_logger
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import FlexibleArgumentParser, get_tcp_uri
+from vllm.v1.engine.core import EngineCoreProc
+from vllm.v1.engine.core_client import CoreEngineProcManager
+from vllm.v1.executor.abstract import Executor
+
+logger = init_logger(__name__)


 class ServeSubcommand(CLISubcommand):
@@ -24,7 +34,10 @@ def cmd(args: argparse.Namespace) -> None:
         if hasattr(args, 'model_tag') and args.model_tag is not None:
             args.model = args.model_tag

-        uvloop.run(run_server(args))
+        if args.headless:
+            run_headless(args)
+        else:
+            uvloop.run(run_server(args))

     def validate(self, args: argparse.Namespace) -> None:
         validate_parsed_serve_args(args)
@@ -42,6 +55,18 @@ def subparser_init(
                                   nargs='?',
                                   help="The model tag to serve "
                                   "(optional if specified in config)")
+        serve_parser.add_argument(
+            "--headless",
+            action='store_true',
+            default=False,
+            help="Run in headless mode. See multi-node data parallel "
+            "documentation for more details.")
+        serve_parser.add_argument(
+            '--data-parallel-start-rank',
+            '-dpr',
+            type=int,
+            default=0,
+            help='Starting data parallel rank for secondary nodes.')
         serve_parser.add_argument(
             "--config",
             type=str,
@@ -57,3 +82,55 @@

 def cmd_init() -> list[CLISubcommand]:
     return [ServeSubcommand()]
+
+
+def run_headless(args: argparse.Namespace):
+
+    # Create the EngineConfig.
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    usage_context = UsageContext.OPENAI_API_SERVER
+    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
+
+    if not envs.VLLM_USE_V1:
+        raise RuntimeError("Headless mode is only supported for V1")
+
+    parallel_config = vllm_config.parallel_config
+    local_engine_count = parallel_config.data_parallel_size_local
+    host = parallel_config.data_parallel_master_ip
+    port = engine_args.data_parallel_rpc_port  # add to config too
+    input_address = get_tcp_uri(host, port)
+
+    if local_engine_count <= 0:
+        raise RuntimeError("data_parallel_size_local must be > 0 in "
+                           "headless mode")
+
+    # Catch SIGTERM and SIGINT to allow graceful shutdown.
+    def signal_handler(signum, frame):
+        logger.debug("Received %d signal.", signum)
+        raise SystemExit
+
+    signal.signal(signal.SIGTERM, signal_handler)
+    signal.signal(signal.SIGINT, signal_handler)
+
+    logger.info(
+        "Launching %d data parallel engine(s) in headless mode, "
+        "with head node address %s.", local_engine_count, input_address)
+
+    # Create the engines.
+    engine_manager = CoreEngineProcManager(
+        target_fn=EngineCoreProc.run_engine_core,
+        local_engine_count=local_engine_count,
+        start_index=args.data_parallel_start_rank,
+        local_start_index=0,
+        vllm_config=vllm_config,
+        on_head_node=False,
+        input_address=input_address,
+        executor_class=Executor.get_class(vllm_config),
+        log_stats=not engine_args.disable_log_stats,
+    )
+
+    try:
+        engine_manager.join_first()
+    finally:
+        logger.info("Shutting down.")
+        engine_manager.close()
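One detail worth noting in run_headless: the signal handlers do not clean up directly, they raise SystemExit so that the try/finally around join_first() performs shutdown on the main thread. A minimal runnable sketch of the same pattern (Unix-only; signal.pause() is a stand-in I am using in place of join_first()):

import signal


def install_shutdown_handlers() -> None:
    def handler(signum, frame):
        # Defer cleanup to the finally block on the main thread.
        raise SystemExit

    signal.signal(signal.SIGTERM, handler)
    signal.signal(signal.SIGINT, handler)


def main() -> None:
    install_shutdown_handlers()
    try:
        # Blocks until a signal arrives; SIGTERM/SIGINT then unwinds
        # through the finally block below instead of killing us mid-cleanup.
        signal.pause()
    finally:
        print("Shutting down.")


if __name__ == "__main__":
    main()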
4 changes: 4 additions & 0 deletions vllm/utils.py
@@ -613,6 +613,10 @@ def is_valid_ipv6_address(address: str) -> bool:


 def get_distributed_init_method(ip: str, port: int) -> str:
+    return get_tcp_uri(ip, port)
+
+
+def get_tcp_uri(ip: str, port: int) -> str:
     # Brackets are not permitted in ipv4 addresses,
     # see https://github.com/python/cpython/issues/103848
     return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"