Skip to content

[Core] Raise when non-multi-instance DP clients target a DP rank #19227

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tests/async_engine/test_async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,25 @@ async def test_delayed_generator(async_engine, stop):
assert final_output is not None
assert len(final_output.outputs[0].token_ids) == 10
assert final_output.finished


@pytest.mark.asyncio(scope="module")
async def test_invalid_argument(async_engine):
    """The v0 async engine must reject requests that target a DP rank."""
    scheduler_config = await async_engine.get_scheduler_config()

    # Multistep scheduling doesn't affect this check; avoid a duplicate run.
    if scheduler_config.num_scheduler_steps != 1:
        pytest.skip("no need to test this one with multistep")

    params = SamplingParams(temperature=0, min_tokens=10, max_tokens=10)

    # Targeting specific DP rank only supported in v1 multi-instance DP
    with pytest.raises(ValueError):
        stream = async_engine.generate("test",
                                       params,
                                       request_id=uid(),
                                       data_parallel_rank=0)
        async for _ in stream:
            pass
29 changes: 29 additions & 0 deletions tests/v1/engine/test_async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,32 @@ async def test_customize_loggers(monkeypatch):
assert len(engine.stat_loggers) == 1
assert len(engine.stat_loggers[0]) == 1
engine.stat_loggers[0][0].log.assert_called_once()


@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
    """An in-range data_parallel_rank is accepted; out-of-range raises."""
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        params = SamplingParams(max_tokens=100,
                                output_kind=RequestOutputKind.DELTA,
                                temperature=1.0,
                                seed=33)

        # Rank 0 always exists, so this request should complete normally.
        async for _ in engine.generate(request_id="request-34",
                                       prompt=TEXT_PROMPT,
                                       sampling_params=params,
                                       data_parallel_rank=0):
            pass

        # Rank 1 is beyond the configured data-parallel size and must be
        # rejected with a ValueError.
        with pytest.raises(ValueError):
            async for _ in engine.generate(request_id="request-35",
                                           prompt=TEXT_PROMPT,
                                           sampling_params=params,
                                           data_parallel_rank=1):
                pass
25 changes: 16 additions & 9 deletions tests/v1/test_async_llm_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
allow_module_level=True)


async def generate(engine: AsyncLLM,
request_id: str,
prompt: PromptType,
output_kind: RequestOutputKind,
max_tokens: int,
prompt_logprobs: Optional[int] = None) -> tuple[int, str]:
async def generate(
engine: AsyncLLM,
request_id: str,
prompt: PromptType,
output_kind: RequestOutputKind,
max_tokens: int,
prompt_logprobs: Optional[int] = None,
data_parallel_rank: Optional[int] = None) -> tuple[int, str]:
# Ensure generate doesn't complete too fast for cancellation test.
await asyncio.sleep(0.2)

Expand All @@ -46,7 +48,8 @@ async def generate(engine: AsyncLLM,
prompt_logprobs=prompt_logprobs)
async for out in engine.generate(request_id=request_id,
prompt=prompt,
sampling_params=sampling_params):
sampling_params=sampling_params,
data_parallel_rank=data_parallel_rank):

num_tokens = len(out.outputs[0].token_ids)
if output_kind == RequestOutputKind.DELTA:
Expand Down Expand Up @@ -89,8 +92,12 @@ async def test_load(output_kind: RequestOutputKind,
for request_id in request_ids:
tasks.append(
asyncio.create_task(
generate(engine, request_id, prompt, output_kind,
NUM_EXPECTED_TOKENS)))
generate(engine,
request_id,
prompt,
output_kind,
NUM_EXPECTED_TOKENS,
data_parallel_rank=0)))
# Confirm that we got all the EXPECTED tokens from the requests.
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION)
Expand Down
4 changes: 4 additions & 0 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,10 @@ async def add_request_async(
if arrival_time is None:
arrival_time = time.time()

if data_parallel_rank is not None:
raise ValueError("Targeting data_parallel_rank only supported "
"in v1 client.")

if (isinstance(prompt, dict)
and prompt.get("prompt_embeds", None) is not None
and not prompt.get("prompt_token_ids", None)):
Expand Down
3 changes: 0 additions & 3 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,9 +1000,6 @@ def get_core_engine_for_request(self,
) -> CoreEngine:
if dp_rank is not None:
# engines are already in rank order
if dp_rank < 0 or dp_rank >= len(self.core_engines):
raise ValueError(f"Requested DP rank {dp_rank} is out of "
f"range [0, {len(self.core_engines)})")
return self.core_engines[dp_rank]

if not self.lb_engines:
Expand Down
6 changes: 6 additions & 0 deletions vllm/v1/engine/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ def process_inputs(
if prompt_adapter_request is not None:
raise ValueError("V1 does not support prompt_adapter_request.")

data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
if data_parallel_rank is not None and not (0 <= data_parallel_rank <
data_parallel_size):
raise ValueError(f"data_parallel_rank {data_parallel_rank} "
f"is out of range [0, {data_parallel_size}).")

if arrival_time is None:
arrival_time = time.time()

Expand Down