[Core] [Bugfix]: tensor parallel with prompt embeds #18171

Merged
3 commits merged on May 20, 2025
75 changes: 65 additions & 10 deletions tests/basic_correctness/test_basic_correctness.py
@@ -8,12 +8,13 @@
from unittest.mock import Mock

import pytest
import torch

from vllm import LLM
from vllm import LLM, envs
from vllm.platforms import current_platform
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1

from ..conftest import VllmRunner
from ..conftest import HfRunner, VllmRunner
from ..models.utils import check_outputs_equal
from ..utils import multi_gpu_test

@@ -43,11 +44,26 @@ def test_vllm_gc_ed():
assert weak_llm() is None


def _fix_prompt_embed_outputs(
vllm_outputs: list[tuple[list[int], str]], hf_model: HfRunner,
example_prompts: list[str]) -> list[tuple[list[int], str]]:
fixed_vllm_outputs = []
for vllm_output, hf_input, prompt in zip(
vllm_outputs, hf_model.get_inputs(example_prompts),
example_prompts):
hf_input_ids = hf_input["input_ids"].tolist()[0]
fixed_vllm_outputs.append(
(hf_input_ids + vllm_output[0][len(hf_input_ids):],
prompt + vllm_output[1]))
return fixed_vllm_outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models(
monkeypatch: pytest.MonkeyPatch,
hf_runner,
@@ -56,8 +72,13 @@ def test_models(
dtype: str,
max_tokens: int,
enforce_eager: bool,
enable_prompt_embeds: bool,
) -> None:

if enable_prompt_embeds and envs.is_set(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")

if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")

@@ -78,14 +99,25 @@

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if enable_prompt_embeds:
with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings(
example_prompts)

with VllmRunner(model,
max_model_len=8192,
dtype=dtype,
enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)
if enable_prompt_embeds:
vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts)
else:
vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
@@ -108,6 +140,7 @@
("distilbert/distilgpt2", "mp", "FLASHINFER", "A100"),
("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_models_distributed(
monkeypatch: pytest.MonkeyPatch,
hf_runner,
@@ -117,14 +150,22 @@ def test_models_distributed(
distributed_executor_backend: str,
attention_backend: str,
test_suite: str,
enable_prompt_embeds: bool,
) -> None:

if enable_prompt_embeds and envs.is_set(
"VLLM_USE_V1") and envs.VLLM_USE_V1:
pytest.skip("enable_prompt_embeds is not supported in v1.")

if test_suite != TARGET_TEST_SUITE:
pytest.skip(f"Skip test for {test_suite}")

with monkeypatch.context() as monkeypatch_context:
if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa
# test Ray Compiled Graph
if enable_prompt_embeds:
pytest.skip(
"enable_prompt_embeds does not work with ray compiled dag."
)
monkeypatch_context.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")

@@ -147,12 +188,26 @@
dtype=dtype,
tensor_parallel_size=2,
distributed_executor_backend=distributed_executor_backend,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens)

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if enable_prompt_embeds:
with hf_runner(model, dtype=dtype) as hf_model:
with torch.no_grad():
prompt_embeds = hf_model.get_prompt_embeddings(
example_prompts)
vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens)
vllm_outputs = _fix_prompt_embed_outputs(
vllm_outputs, hf_model, example_prompts)
hf_outputs = hf_model.generate_greedy(
example_prompts, max_tokens)
else:
vllm_outputs = vllm_model.generate_greedy(
example_prompts, max_tokens)
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(
example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
9 changes: 9 additions & 0 deletions tests/conftest.py
@@ -430,6 +430,15 @@ def get_inputs(

return all_inputs

def get_prompt_embeddings(self, prompts: list[str]) -> list[torch.Tensor]:
all_inputs = self.get_inputs(prompts)
embeddings = []
for inputs in all_inputs:
input_ids = self.wrap_device(inputs)["input_ids"]
embedding = self.model.get_input_embeddings()(input_ids).squeeze(0)
embeddings.append(embedding)
return embeddings

def classify(self, prompts: list[str]) -> list[str]:
# output is final logits
all_inputs = self.get_inputs(prompts)
14 changes: 7 additions & 7 deletions vllm/sequence.py
@@ -112,12 +112,12 @@ class RequestMetrics:
will include model forward, block/sync across
workers, cpu-gpu sync time and sampling time.
spec_token_acceptance_counts: number of accepted speculative tokens at
each position; the first token is from
the target model and is always accepted;
e.g., when it's [10, 8, 4, 2] for a req,
it means there were 10 forward passes in
total, and there were 8, 4, 2 accepted
tokens at 1st, 2nd, 3rd speculation step.
"""
arrival_time: float
last_token_time: float
@@ -714,9 +714,9 @@ class SequenceGroup:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request.
draft_size: The number of speculative tokens plus one from the target
model; equal to max number of tokens a step can generate
for single-draft speculative decoding but larger than
that for multi-draft SD (currently not supported).
"""

@@ -1123,7 +1123,7 @@ def __repr__(self) -> str:
self.output_embed.shape if self.output_embed is not None else None
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}, "
f"output_embed.shape={output_embed_shape}"
f"output_embed.shape={output_embed_shape}, "
f"logprobs={self.logprobs})")

def __eq__(self, other: object) -> bool:
100 changes: 55 additions & 45 deletions vllm/worker/model_runner.py
@@ -23,7 +23,7 @@
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import CompilationLevel, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_pp_group
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.distributed.kv_transfer import get_kv_transfer_group
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
graph_capture)
@@ -872,23 +872,23 @@ def build(self) -> ModelInputForGPU:
"""
# Combine and flatten intermediate data.
input_tokens = list[int]()
inputs_embeds_lst = list[torch.Tensor]()
inputs_embeds_list = list[torch.Tensor]()
token_types = list[int]()
for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens)
for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None:
inputs_embeds_lst.append(
inputs_embeds_list.append(
inter_data.inputs_embeds.to(
dtype=self.runner.model_config.dtype,
device=self.runner.device))
inputs_embeds: Optional[torch.Tensor]
if len(inputs_embeds_lst) == 0:
if len(inputs_embeds_list) == 0:
inputs_embeds = None
else:
inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
dtype=self.runner.model_config.dtype,
device=self.runner.device)
assert len(inputs_embeds) == len(input_tokens)
@@ -1893,50 +1893,60 @@ def execute_model(
logits = self.model.compute_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)

if not self.is_driver_worker:
return []
if self.is_driver_worker:
if model_input.async_callback is not None:
model_input.async_callback()

if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
assert isinstance(self.sampler, Sampler)
orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True

# Sample the next token.
assert isinstance(self.sampler, Sampler)
orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True

output: SamplerOutput = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
# If there are multiple workers, we are still tracking the latency
# from the start time of the driver worker to the end time of the
# driver worker. The model forward time will then end up covering
# the communication time as well.
output.model_forward_time = (orig_model_forward_time +
model_forward_time)
output: SamplerOutput = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
# If there are multiple workers, we are still tracking the
# latency from the start time of the driver worker to the end
# time of the driver worker. The model forward time will then
# end up covering the communication time as well.
output.model_forward_time = (orig_model_forward_time +
model_forward_time)

if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs_tensor
if output.sampled_token_ids is not None:
output.sampled_token_embeds = self.model.get_input_embeddings(
output.sampled_token_ids.squeeze(1))

for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[0].output_embed = token_embed
if self.is_driver_worker:
sampled = broadcast_tensor_dict(
{"token_ids": output.sampled_token_ids})
else:
sampled = broadcast_tensor_dict()
if sampled["token_ids"] is not None:
sampled_token_embeds = self.model.get_input_embeddings(
sampled["token_ids"].squeeze(1))
if self.is_driver_worker:
self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs

output.sampled_token_embeds = sampled_token_embeds

for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[
0].output_embed = token_embed
@DarkLight1337 (Member) commented on May 19, 2025:

QQ: Not really related to this PR, but why do we need output_embed? This may lead to additional communication overhead, and IMO it's not needed since you already have it in the input

@Nan2018 (Contributor, Author) replied on May 19, 2025:

we need the embedding of the sampled output for the next model execution. currently we can't have a mix of input embeds and token ids for the same sequence

A Member replied:

I see, thanks for the explanation!
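
To make the exchange above concrete: a request that arrives as prompt embeddings has no token-ID prompt, so each newly sampled token must be converted back into an embedding before the next decode step; the PR stores that vector as SequenceOutput.output_embed. A minimal sketch with toy layers (not vLLM code, names chosen only for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, vocab_size = 16, 100
embed = nn.Embedding(vocab_size, hidden_size)  # stands in for model.get_input_embeddings()
lm_head = nn.Linear(hidden_size, vocab_size)   # stands in for compute_logits + greedy sampling

prompt_embeds = torch.randn(5, hidden_size)    # request arrives as embeddings only; no token IDs
seq_embeds = prompt_embeds
with torch.no_grad():
    for _ in range(3):                         # greedy decode loop
        next_id = lm_head(seq_embeds)[-1].argmax()
        # The sampled ID is turned back into an embedding (what the PR stores as
        # SequenceOutput.output_embed) so this sequence stays embeddings-only.
        next_embed = embed(next_id).unsqueeze(0)
        seq_embeds = torch.cat([seq_embeds, next_embed], dim=0)
```

Under tensor parallelism, producing next_embed on every rank is what requires the broadcast added in model_runner.py above.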


if not self.is_driver_worker:
return []

if self.return_hidden_states:
# we only need to pass hidden states of most recent token
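
Stepping back from the diff: the core of the tensor-parallel fix in execute_model is that only the driver worker runs the sampler, while every rank has to take part in the input-embedding lookup for the sampled tokens (the embedding table is typically sharded under tensor parallelism, so the lookup is a collective). Below is a simplified stand-in for that broadcast-then-embed pattern using plain torch.distributed rather than vllm.distributed.broadcast_tensor_dict; the helper name is hypothetical and it assumes the process group is initialized and every rank knows num_seqs for the step.

```python
from typing import Optional

import torch
import torch.distributed as dist


def embed_sampled_tokens(embed_layer: torch.nn.Module,
                         sampled_ids: Optional[torch.Tensor],
                         num_seqs: int,
                         driver_rank: int = 0) -> torch.Tensor:
    if dist.get_rank() == driver_rank:
        # Only the driver ran the sampler; flatten (num_seqs, 1) -> (num_seqs,).
        ids = sampled_ids.reshape(num_seqs).to("cuda")
    else:
        ids = torch.empty(num_seqs, dtype=torch.long, device="cuda")
    dist.broadcast(ids, src=driver_rank)  # every rank now holds the same IDs
    # Run the (possibly tensor-parallel) embedding lookup on all ranks with
    # identical inputs.
    return embed_layer(ids)
```

In the actual change, the driver calls broadcast_tensor_dict({"token_ids": output.sampled_token_ids}) and the other ranks call broadcast_tensor_dict() with no arguments, which also transports the shape and dtype metadata, so no explicit num_seqs bookkeeping is needed.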