Allow AsyncLLMEngine.generate to target a specific DP rank #19102

Merged: 3 commits, Jun 4, 2025
58 changes: 58 additions & 0 deletions examples/online_serving/multi_instance_data_parallel.py
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import Optional

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

"""
To run this example, run the following commands simultaneously with
different CUDA_VISIBLE_DEVICES:
python examples/online_serving/multi_instance_data_parallel.py

vllm serve ibm-research/PowerMoE-3b -dp 2 -dpr 1 \
--data-parallel-address 127.0.0.1 --data-parallel-rpc-port 62300 \
--data-parallel-size-local 1 --enforce-eager --headless

Once both instances have completed the handshake, this example will
send a request to the instance with DP rank 1.
"""


async def main():
    engine_args = AsyncEngineArgs(
        model="ibm-research/PowerMoE-3b",
        data_parallel_size=2,
        dtype="auto",
        max_model_len=2048,
        data_parallel_address="127.0.0.1",
        data_parallel_rpc_port=62300,
        data_parallel_size_local=1,
        enforce_eager=True,
    )

    engine_client = AsyncLLMEngine.from_engine_args(engine_args)

    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=100,
    )

    prompt = "Who won the 2004 World Series?"
    final_output: Optional[RequestOutput] = None
    async for output in engine_client.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id="abcdef",
            data_parallel_rank=1,
    ):
        final_output = output
    if final_output:
        print(final_output.outputs[0].text)


if __name__ == "__main__":
asyncio.run(main())
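
A hedged extension of the example above (not part of this PR's diff): the same engine client could target each DP rank explicitly from one process. The helper name, prompts, and request IDs below are illustrative; it simply reuses the generate() call shown in the example.

# Illustrative helper (not part of this PR): send one request to a chosen
# DP rank and return the generated text.
async def query_rank(engine_client, prompt: str, rank: int) -> str:
    params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=100)
    text = ""
    async for output in engine_client.generate(
            prompt=prompt,
            sampling_params=params,
            request_id=f"req-rank-{rank}",
            data_parallel_rank=rank,
    ):
        # RequestOutput accumulates the generated text; keep the latest.
        text = output.outputs[0].text
    return text

# Inside main(), after the engine client is created, both ranks could then
# be queried concurrently:
#     rank0_text, rank1_text = await asyncio.gather(
#         query_rank(engine_client, "Prompt for rank 0", 0),
#         query_rank(engine_client, "Prompt for rank 1", 1),
#     )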
3 changes: 2 additions & 1 deletion tests/tokenization/test_detokenize.py
@@ -70,7 +70,8 @@ def _run_incremental_decode(tokenizer,
None,
0.0,
None,
cache_salt=None)
cache_salt=None,
data_parallel_rank=None)

if fast is None:
detokenizer = IncrementalDetokenizer.from_new_request(
1 change: 1 addition & 0 deletions tests/v1/engine/test_engine_core.py
@@ -42,6 +42,7 @@ def make_request() -> EngineCoreRequest:
arrival_time=time.time(),
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
)


1 change: 1 addition & 0 deletions tests/v1/engine/test_engine_core_client.py
@@ -56,6 +56,7 @@ def make_request(
arrival_time=time.time(),
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
)


5 changes: 5 additions & 0 deletions tests/v1/engine/test_output_processor.py
@@ -59,6 +59,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
eos_token_id=None,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
@@ -406,6 +407,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
eos_token_id=None,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
@@ -569,6 +571,7 @@ def test_stop_token(include_stop_str_in_output: bool,
eos_token_id=eos_token_id,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
@@ -666,6 +669,7 @@ def test_stop_string(include_stop_str_in_output: bool,
eos_token_id=None,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(
skip_special_tokens=False,
spaces_between_special_tokens=False,
@@ -780,6 +784,7 @@ def test_iteration_stats(dummy_test_vectors):
eos_token_id=None,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(),
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
12 changes: 11 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -442,6 +442,7 @@ async def add_request_async(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> None:
...

@@ -456,6 +457,7 @@ def add_request_async(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> None:
...

@@ -473,6 +475,7 @@ def add_request_async(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
@@ -902,6 +905,7 @@ def add_request(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, PoolingRequestOutput], None]]:
...
@@ -917,6 +921,7 @@ def add_request(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, PoolingRequestOutput], None]]:
...
@@ -935,6 +940,7 @@ async def add_request(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
@@ -967,6 +973,7 @@ async def add_request(
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
)

return stream.generator()
@@ -980,6 +987,7 @@ async def generate(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.

@@ -999,7 +1007,8 @@ async def generate(
for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.

data_parallel_rank: The (global) data parallel rank that must
handle this request. Only applicable if DP is enabled.
Yields:
The output `RequestOutput` objects from the LLMEngine
for the request.
@@ -1057,6 +1066,7 @@ async def generate(
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
):
yield LLMEngine.validate_output(output, RequestOutput)
except asyncio.CancelledError:
1 change: 1 addition & 0 deletions vllm/v1/engine/__init__.py
@@ -55,6 +55,7 @@ class EngineCoreRequest(
arrival_time: float
lora_request: Optional[LoRARequest]
cache_salt: Optional[str]
data_parallel_rank: Optional[int]

# Index of the client, used to ensure outputs are sent back to the same
# client for this request when scaling out the front-end.
5 changes: 4 additions & 1 deletion vllm/v1/engine/async_llm.py
@@ -229,6 +229,7 @@ async def add_request(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> RequestOutputCollector:
"""Add new request to the AsyncLLM."""

@@ -245,7 +246,7 @@ async def add_request(
prompt_str, request = self.processor.process_inputs(
request_id, prompt, params, arrival_time, lora_request,
tokenization_kwargs, trace_headers, prompt_adapter_request,
priority)
priority, data_parallel_rank)

if params.n == 1:
await self._add_request(request, prompt_str, None, 0, queue)
@@ -291,6 +292,7 @@ async def generate(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> AsyncGenerator[RequestOutput, None]:
"""
Main function called by the API server to kick off a request
@@ -321,6 +323,7 @@ async def generate(
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
data_parallel_rank=data_parallel_rank,
)

# The output_handler task pushes items into the queue.
14 changes: 12 additions & 2 deletions vllm/v1/engine/core_client.py
@@ -982,7 +982,16 @@ async def run_engine_stats_update_task():
resources.stats_update_task = asyncio.create_task(
run_engine_stats_update_task())

def get_core_engine_for_request(self) -> CoreEngine:
def get_core_engine_for_request(self,
dp_rank: Optional[int] = None
) -> CoreEngine:
if dp_rank is not None:
# engines are already in rank order
if dp_rank < 0 or dp_rank >= len(self.core_engines):
raise ValueError(f"Requested DP rank {dp_rank} is out of "
f"range [0, {len(self.core_engines)})")
return self.core_engines[dp_rank]

if not self.lb_engines:
return self.core_engines[0]
# TODO use P2C alg for larger DP sizes
@@ -1018,7 +1027,8 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
request.current_wave = self.current_wave
request.client_index = self.client_index

chosen_engine = self.get_core_engine_for_request()
chosen_engine = self.get_core_engine_for_request(
request.data_parallel_rank)
self.reqs_in_flight[request.request_id] = chosen_engine

to_await = self._send_input(EngineCoreRequestType.ADD, request,
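
To make the routing rule added to get_core_engine_for_request easier to see in isolation, here is a minimal standalone sketch. It is illustrative only: plain strings stand in for CoreEngine objects, and the existing load-balancing fallback is collapsed to "pick the first engine", which is not how the real client balances load.

from typing import Optional

def pick_engine(core_engines: list[str],
                dp_rank: Optional[int] = None) -> str:
    # Mirrors the new logic: a pinned DP rank indexes directly into the
    # rank-ordered engine list, with bounds checking.
    if dp_rank is not None:
        if dp_rank < 0 or dp_rank >= len(core_engines):
            raise ValueError(f"Requested DP rank {dp_rank} is out of "
                             f"range [0, {len(core_engines)})")
        return core_engines[dp_rank]
    # Without a pinned rank, fall back to load balancing (simplified here;
    # the real client keeps per-engine utilization state).
    return core_engines[0]

engines = ["engine-rank-0", "engine-rank-1"]
assert pick_engine(engines, dp_rank=1) == "engine-rank-1"
assert pick_engine(engines) == "engine-rank-0"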
2 changes: 2 additions & 0 deletions vllm/v1/engine/processor.py
@@ -212,6 +212,7 @@ def process_inputs(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
data_parallel_rank: Optional[int] = None,
) -> tuple[Optional[str], EngineCoreRequest]:

# TODO(woosuk): Support pooling models.
@@ -328,6 +329,7 @@ def process_inputs(
arrival_time=arrival_time,
lora_request=lora_request,
cache_salt=decoder_inputs.get("cache_salt"),
data_parallel_rank=data_parallel_rank,
)

def _validate_model_inputs(self,
Expand Down