Merged
28 changes: 24 additions & 4 deletions tests/models/language/pooling/test_auto_prefix_cache_support.py
@@ -19,14 +19,25 @@ def test_classify_models(
     model: str,
     dtype: str,
 ) -> None:
-    example_prompts = example_prompts * 2
+    # example_prompts is too short for testing prefix_caching
+    example_prompts = [s * 10 for s in example_prompts]
 
     with vllm_runner(
         model, max_model_len=512, dtype=dtype, enable_prefix_caching=True
     ) as vllm_model:
         cache_config = vllm_model.llm.llm_engine.cache_config
         assert cache_config.enable_prefix_caching
-        vllm_outputs = vllm_model.classify(example_prompts)
+
+        # First Run
+        vllm_model.classify(example_prompts)

Review comment (Member): Should we check that initially the number of cached tokens is zero?


+        # assert prefix_caching works
+        pooling_outputs = vllm_model.llm.encode(
+            example_prompts, pooling_task="classify"
+        )
+        for output in pooling_outputs:
+            assert output.num_cached_tokens > 0
+        vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs]
 
     with hf_runner(
         model, dtype=dtype, auto_cls=AutoModelForSequenceClassification

@@ -54,7 +65,8 @@ def test_embed_models(
     model: str,
     dtype: str,
 ):
-    example_prompts = [str(s).strip() for s in example_prompts] * 2
+    # example_prompts is too short for testing prefix_caching
+    example_prompts = [str(s).strip() * 10 for s in example_prompts]
 
     with vllm_runner(
         model,

@@ -64,7 +76,15 @@
     ) as vllm_model:
         cache_config = vllm_model.llm.llm_engine.cache_config
         assert cache_config.enable_prefix_caching
-        vllm_outputs = vllm_model.embed(example_prompts)
+
+        # First Run
+        vllm_model.embed(example_prompts)
+
+        # assert prefix_caching works
+        pooling_outputs = vllm_model.llm.encode(example_prompts, pooling_task="embed")
+        for output in pooling_outputs:
+            assert output.num_cached_tokens > 0
+        vllm_outputs = [req_output.outputs.data for req_output in pooling_outputs]
 
     with hf_runner(
         model,
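Regarding the reviewer question above about verifying the initial state: a minimal sketch of what that extra check might look like (hypothetical, not part of this PR), reusing the same encode call that the diff adds:

# Hypothetical first-pass check: with a cold prefix cache, the reviewer's
# expectation is that no tokens are reported as cached yet.
first_outputs = vllm_model.llm.encode(example_prompts, pooling_task="classify")
for output in first_outputs:
    assert output.num_cached_tokens == 0

The second pass would then keep the existing assert output.num_cached_tokens > 0 check from the diff.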
33 changes: 33 additions & 0 deletions tests/models/language/pooling/test_extract_hidden_states.py
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm import TokensPrompt


@pytest.mark.parametrize(
    "model",
    ["Qwen/Qwen3-0.6B"],
)
@torch.inference_mode
def test_embed_models(hf_runner, vllm_runner, model: str):
    n_prompt_tokens = [55, 56, 57]
    token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]

    with vllm_runner(
        model,
        max_model_len=128,
        enforce_eager=True,
        runner="pooling",
        enable_chunked_prefill=False,
        enable_prefix_caching=False,
    ) as vllm_model:
        pooling_outputs = vllm_model.llm.encode(
            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
            pooling_task="token_embed",
        )

        for n, output in zip(n_prompt_tokens, pooling_outputs):
            assert len(output.prompt_token_ids) == n
            assert output.num_cached_tokens == 0
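For context, a short sketch of how these per-token outputs might be inspected (an assumption on my part: with pooling_task="token_embed", outputs.data holds one embedding row per prompt token, matching how the classification test above reads req_output.outputs.data):

# Assumes pooling_outputs comes from the encode call in the test above.
first = pooling_outputs[0]
hidden_states = first.outputs.data  # assumed shape: (len(prompt_token_ids), hidden_size)
print(len(first.prompt_token_ids), tuple(hidden_states.shape), first.num_cached_tokens)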
3 changes: 3 additions & 0 deletions vllm/entrypoints/llm.py
@@ -1078,6 +1078,9 @@ def encode(
 PoolingRequestOutput[Any](
     request_id="",
     outputs=processed_outputs,
+    num_cached_tokens=getattr(
+        processed_outputs, "num_cached_tokens", 0
+    ),
     prompt_token_ids=[],
     finished=True,
 )

Review thread on the num_cached_tokens=getattr(...) addition:

Member: Why do we need getattr here? In what case is that not available?

Collaborator (Author): The result of io_processor might not have this value.

noooop (Collaborator, Author), Oct 23, 2025: Please unblock Language Models Test (Extended Pooling) and Language Models Test (MTEB) to check for CI failures in the main branch that still need to be fixed.

Member: Hmm... I think we should make this a property of PoolingRequestOutput itself?

DarkLight1337 (Member), Oct 23, 2025: Something like

    @property
    def num_cached_tokens(self) -> int:
        return getattr(self.processed_outputs, "num_cached_tokens", 0)

noooop (Collaborator, Author), Oct 23, 2025: Can we first merge this and discuss the issue in #26973? This PR is actually intended to fix CI failures in the main branch for #27329.

Member: ok

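To illustrate the getattr discussion above, a hypothetical example (PluginPoolingOutput is invented for illustration and is not a vLLM class): an io_processor may return an object that simply lacks a num_cached_tokens attribute, so the call falls back to 0 instead of raising AttributeError.

# Hypothetical output type produced by an io_processor plugin.
class PluginPoolingOutput:
    def __init__(self, data):
        self.data = data

processed_outputs = PluginPoolingOutput(data=[0.1, 0.2, 0.3])
num_cached = getattr(processed_outputs, "num_cached_tokens", 0)  # -> 0, no error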
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/serving_embedding.py
@@ -583,6 +583,7 @@ async def _collect_batch(
     request_id=aggregator["request_id"],
     prompt_token_ids=original_token_ids,
     outputs=pooling_output_data,
+    num_cached_tokens=0,
     finished=True,
 )
1 change: 1 addition & 0 deletions vllm/entrypoints/score_utils.py
@@ -66,6 +66,7 @@ def _cosine_similarity(
         request_id=f"{emb_1.request_id}_{emb_2.request_id}",
         outputs=pair_score,
         prompt_token_ids=tokens,
+        num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
         finished=True,
     )
 )
13 changes: 12 additions & 1 deletion vllm/outputs.py
@@ -201,14 +201,21 @@ class PoolingRequestOutput(Generic[_O]):
         request_id (str): A unique identifier for the pooling request.
         outputs (PoolingOutput): The pooling results for the given input.
         prompt_token_ids (list[int]): A list of token IDs used in the prompt.
+        num_cached_tokens: The number of tokens with prefix cache hit.
         finished (bool): A flag indicating whether the pooling is completed.
     """
 
     def __init__(
-        self, request_id: str, outputs: _O, prompt_token_ids: list[int], finished: bool
+        self,
+        request_id: str,
+        outputs: _O,
+        prompt_token_ids: list[int],
+        num_cached_tokens: int,
+        finished: bool,
     ):
         self.request_id = request_id
         self.prompt_token_ids = prompt_token_ids
+        self.num_cached_tokens = num_cached_tokens
         self.finished = finished
         self.outputs = outputs
 
@@ -217,6 +224,7 @@ def __repr__(self):
             f"{type(self).__name__}(request_id={self.request_id!r}, "
             f"outputs={self.outputs!r}, "
             f"prompt_token_ids={self.prompt_token_ids}, "
+            f"num_cached_tokens={self.num_cached_tokens}, "
             f"finished={self.finished})"
         )
 
@@ -255,6 +263,7 @@ def from_base(request_output: PoolingRequestOutput):
             request_id=request_output.request_id,
             outputs=EmbeddingOutput.from_base(request_output.outputs),
             prompt_token_ids=request_output.prompt_token_ids,
+            num_cached_tokens=request_output.num_cached_tokens,
             finished=request_output.finished,
         )
 
@@ -294,6 +303,7 @@ def from_base(request_output: PoolingRequestOutput):
             request_id=request_output.request_id,
             outputs=ClassificationOutput.from_base(request_output.outputs),
             prompt_token_ids=request_output.prompt_token_ids,
+            num_cached_tokens=request_output.num_cached_tokens,
             finished=request_output.finished,
         )
 
@@ -330,5 +340,6 @@ def from_base(request_output: PoolingRequestOutput):
             request_id=request_output.request_id,
             outputs=ScoringOutput.from_base(request_output.outputs),
             prompt_token_ids=request_output.prompt_token_ids,
+            num_cached_tokens=request_output.num_cached_tokens,
             finished=request_output.finished,
         )
1 change: 1 addition & 0 deletions vllm/v1/engine/output_processor.py
@@ -230,6 +230,7 @@ def _new_request_output(
         return PoolingRequestOutput(
             request_id=request_id,
             outputs=first_output,
+            num_cached_tokens=self.num_cached_tokens,
             prompt_token_ids=self.prompt_token_ids,
             finished=finished,
         )
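Taken together, a minimal usage sketch of how the new field surfaces to callers of LLM.encode (model and prompts are placeholders; the exact number of cached tokens depends on block size and scheduling, but it should be positive on the second pass):

from vllm import LLM

llm = LLM(model="Qwen/Qwen3-0.6B", runner="pooling", enable_prefix_caching=True)

prompts = ["a sufficiently long shared prefix, repeated to span a cache block. " * 20] * 2
llm.encode(prompts, pooling_task="embed")            # first run warms the prefix cache
outputs = llm.encode(prompts, pooling_task="embed")  # second run should hit cached blocks

for out in outputs:
    print(out.num_cached_tokens)                     # expected to be > 0 after the warm-up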