7 changes: 2 additions & 5 deletions tests/entrypoints/pooling/llm/test_classify.py
@@ -59,11 +59,8 @@ def get_outputs(activation):


@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
# chunked prefill does not support all pooling
err_msg = "pooling_task must be one of.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
def test_token_classify(llm: LLM):
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)


def test_score_api(llm: LLM):
2 changes: 1 addition & 1 deletion tests/entrypoints/pooling/llm/test_embedding.py
@@ -36,7 +36,7 @@ def llm():


@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
def test_token_embed(llm: LLM):
outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
multi_vector = outputs[0].outputs.data
assert multi_vector.shape == (11, 384)
13 changes: 9 additions & 4 deletions tests/entrypoints/pooling/openai/test_classification.py
@@ -7,7 +7,7 @@
import torch.nn.functional as F

from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import ClassificationResponse
from vllm.entrypoints.openai.protocol import ClassificationResponse, PoolingResponse

MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
DTYPE = "float32" # Use float32 to avoid NaN issue
@@ -192,12 +192,17 @@ async def get_outputs(activation):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_pooling(server: RemoteOpenAIServer, model_name: str):
# pooling api uses ALL pooling, which does not support chunked prefill.
input_text = ["This product was excellent and exceeded my expectations"]
response = requests.post(
server.url_for("pooling"),
json={"model": model_name, "input": "test", "encoding_format": "float"},
json={"model": model_name, "input": input_text, "encoding_format": "float"},
)
assert response.json()["error"]["type"] == "BadRequestError"
poolings = PoolingResponse.model_validate(response.json())

# token_classify
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 8
assert len(poolings.data[0].data[0]) == 2


@pytest.mark.asyncio
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModel

from tests.models.utils import check_embeddings_close
from vllm import TokensPrompt


@pytest.mark.parametrize(
"model",
["Qwen/Qwen3-Embedding-0.6B"],
)
@torch.inference_mode
def test_embed_models(hf_runner, vllm_runner, model: str):
chunk_size = 10
n_prompt_tokens = [55, 56, 57]
token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]

with vllm_runner(
model,
runner="pooling",
max_model_len=128,
max_num_batched_tokens=chunk_size,
enforce_eager=True,
# VllmRunner defaults `enable_chunked_prefill` to `False` (rather than `None`), so enable it explicitly
enable_chunked_prefill=True,
# If prefix caching is enabled,
# ALL pooling returns fewer than n_prompt_tokens outputs,
# so we need a way to disable prefix caching at the request level.
enable_prefix_caching=False,
) as vllm_model:
vllm_outputs = vllm_model.token_embed(
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
)

with hf_runner(
model,
auto_cls=AutoModel,
) as hf_model:
hf_outputs = []
for token_prompt in token_prompts:
inputs = hf_model.wrap_device({"input_ids": torch.tensor([token_prompt])})
input_ids = inputs["input_ids"]
output = hf_model.model(input_ids)
hf_outputs.append(output.last_hidden_state.cpu().float()[0])

for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
check_embeddings_close(
embeddings_0_lst=hf_output,
embeddings_1_lst=vllm_output,
name_0="hf",
name_1="vllm",
tol=1e-2,
)
11 changes: 8 additions & 3 deletions tests/models/language/pooling/test_extract_hidden_states.py
@@ -12,16 +12,19 @@
)
@torch.inference_mode
def test_embed_models(hf_runner, vllm_runner, model: str):
chunk_size = 10
n_prompt_tokens = [55, 56, 57]
token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]

with vllm_runner(
model,
max_model_len=128,
max_num_batched_tokens=chunk_size,
enforce_eager=True,
runner="pooling",
enable_chunked_prefill=False,
enable_prefix_caching=False,
# VllmRunner defaults `enable_chunked_prefill` to `False` (rather than `None`), so enable it explicitly
enable_chunked_prefill=True,
enable_prefix_caching=True,
) as vllm_model:
pooling_outputs = vllm_model.llm.encode(
[TokensPrompt(prompt_token_ids=t) for t in token_prompts],
@@ -30,4 +33,6 @@ def test_embed_models(hf_runner, vllm_runner, model: str):

for n, output in zip(n_prompt_tokens, pooling_outputs):
assert len(output.prompt_token_ids) == n
assert output.num_cached_tokens == 0
# We should ensure that every ALL pooling request reports output.num_cached_tokens == 0
# even when prefix caching is enabled
assert output.num_cached_tokens >= 0
75 changes: 62 additions & 13 deletions vllm/model_executor/layers/pooler.py
@@ -204,22 +204,60 @@ def forward_all(


class AllPool(PoolingMethod):
def __init__(self):
super().__init__()

vllm_config = get_current_vllm_config()
self.enable_chunked_prefill = (
vllm_config.scheduler_config.enable_chunked_prefill
)

def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify"}

def forward_all(
self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
) -> list[torch.Tensor] | torch.Tensor:
raise NotImplementedError(
"forward_all is not implemented for AllPool. Use forward instead."
)

def forward(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor] | torch.Tensor:
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with ALL pooling"
)
pooling_cursor = pooling_metadata.pooling_cursor
pooling_params = get_pooling_params(pooling_metadata)
is_finished = pooling_cursor.is_finished()

hidden_states_lst = list(
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
)
return [hidden_states_lst[i] for i in pooling_cursor.index]
hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index]

if not self.enable_chunked_prefill:
return hidden_states_lst

# If chunked prefill is enabled:
# 1. First, store each chunk of hidden_states in pooling_param.hidden_states_cache
for pooling_param, hs_chunk in zip(pooling_params, hidden_states_lst):
pooling_param.hidden_states_cache.append(hs_chunk)

# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list = []
for pooling_param, finished in zip(pooling_params, is_finished):
if finished:
hidden_states_cache = pooling_param.hidden_states_cache
if len(hidden_states_cache) == 1:
output_list.append(hidden_states_cache[0])
else:
output_list.append(torch.concat(hidden_states_cache, dim=0))
pooling_param.hidden_states_cache.clear()
else:
output_list.append(None)

return output_list


class MeanPool(PoolingMethod):
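The new AllPool.forward buffers each chunk's hidden states in pooling_param.hidden_states_cache and only hands a tensor to the PoolerHead once prefill finishes (unfinished requests yield None). A minimal standalone sketch of this accumulate-then-concatenate pattern, using hypothetical names rather than the vLLM classes:

```python
import torch


class ChunkedAllPoolCache:
    """Hypothetical helper illustrating the caching pattern, not part of vLLM."""

    def __init__(self) -> None:
        self.chunks: list[torch.Tensor] = []

    def add_chunk(self, hidden_states: torch.Tensor) -> None:
        # hidden_states: [n_tokens_in_chunk, hidden_size]
        self.chunks.append(hidden_states)

    def finish(self) -> torch.Tensor:
        # Concatenate along the token dimension once prefill is complete.
        out = self.chunks[0] if len(self.chunks) == 1 else torch.cat(self.chunks, dim=0)
        self.chunks.clear()
        return out


cache = ChunkedAllPoolCache()
for chunk in torch.randn(7, 16).split(3):  # prompt prefilled in chunks of up to 3 tokens
    cache.add_chunk(chunk)
print(cache.finish().shape)  # torch.Size([7, 16])
```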
@@ -622,8 +660,12 @@ def forward(

class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
def forward(
self, pooled_data: torch.Tensor, pooling_param: PoolingParams
) -> torch.Tensor:
self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
) -> PoolerOutput:
# for unfinished chunked prefill
if pooled_data is None:
return None

pooled_data = pooled_data.to(self.head_dtype)
# pooled_data shape: [n_tokens, hidden_dimension]

@@ -666,9 +708,13 @@ def get_supported_tasks(self) -> Set[PoolingTask]:

def forward(
self,
hidden_states: torch.Tensor,
hidden_states: torch.Tensor | None,
pooling_param: PoolingParams,
) -> torch.Tensor:
) -> PoolerOutput:
# for unfinished chunked prefill
if hidden_states is None:
return None

hidden_states = hidden_states.to(self.head_dtype)
# hidden_states shape: [n_token, hidden_size]

@@ -722,17 +768,20 @@ def extract_states(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> torch.Tensor | list[torch.Tensor]:
) -> PoolerOutput:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = get_prompt_token_ids(pooling_metadata)

pooled_data = list[torch.Tensor]()

pooling_params = get_pooling_params(pooling_metadata)

pooled_data: list[torch.Tensor | None] = []
for data, token_id, pooling_param in zip(
pooled_data_lst, prompt_token_ids, pooling_params
):
# for unfinished chunked prefill
if data is None:
pooled_data.append(data)
continue

step_tag_id = pooling_param.step_tag_id
returned_token_ids = pooling_param.returned_token_ids

4 changes: 4 additions & 0 deletions vllm/pooling_params.py
@@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Annotated, Any, Optional

import msgspec
import torch

from vllm.sampling_params import RequestOutputKind
from vllm.tasks import PoolingTask
@@ -57,6 +58,9 @@ class PoolingParams(
extra_kwargs: dict[str, Any] | None = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

# for chunked prefill with ALL pooling
hidden_states_cache: list[torch.Tensor] = msgspec.field(default_factory=list)

@property
def all_parameters(self) -> list[str]:
return ["dimensions", "normalize", "activation"]
8 changes: 8 additions & 0 deletions vllm/v1/core/sched/scheduler.py
@@ -285,6 +285,14 @@ def schedule(self) -> SchedulerOutput:

self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)

# The hidden_states_cache is used by requests that
# combine ALL pooling with chunked prefill.
# If the request is preempted, the cache must be cleared
# so the hidden states are recomputed from scratch.
if preempted_req.pooling_params is not None:
preempted_req.pooling_params.hidden_states_cache.clear()

preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_preemptions += 1
3 changes: 3 additions & 0 deletions vllm/v1/engine/core.py
@@ -267,6 +267,9 @@ def add_request(self, request: Request, request_wave: int = 0):
f"Supported tasks: {supported_pooling_tasks}"
)

# Ensure that multiple requests never share the same pooling_params object
request.pooling_params = request.pooling_params.clone()

if request.kv_transfer_params is not None and (
not self.scheduler.get_kv_connector()
):
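Because hidden_states_cache is mutable per-request state carried on the params object, two requests that shared one PoolingParams instance would interleave their cached chunks; the clone() above avoids that. A small sketch of the hazard, using a hypothetical FakePoolingParams rather than the real class:

```python
from copy import deepcopy
from dataclasses import dataclass, field

import torch


@dataclass
class FakePoolingParams:
    """Hypothetical stand-in for PoolingParams, only to show the aliasing issue."""

    hidden_states_cache: list[torch.Tensor] = field(default_factory=list)


shared = FakePoolingParams()
req_a = shared
req_b = shared                          # two requests reusing one params object
req_a.hidden_states_cache.append(torch.zeros(2, 4))
print(len(req_b.hidden_states_cache))   # 1 -- request B now sees A's cached chunk

req_b = deepcopy(shared)                # cloning isolates per-request state
req_b.hidden_states_cache.clear()
print(len(req_a.hidden_states_cache))   # still 1 -- A's cache is untouched
```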
2 changes: 1 addition & 1 deletion vllm/v1/outputs.py
@@ -81,7 +81,7 @@ def empty_cpu(

# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
PoolerOutput = torch.Tensor | list[torch.Tensor]
PoolerOutput = torch.Tensor | list[torch.Tensor] | None


@dataclass
18 changes: 15 additions & 3 deletions vllm/v1/pool/metadata.py
@@ -16,6 +16,7 @@ class PoolingCursor:
first_token_indices_gpu: torch.Tensor
last_token_indices_gpu: torch.Tensor
prompt_lens_cpu: torch.Tensor
seq_lens_cpu: torch.Tensor
num_scheduled_tokens_cpu: torch.Tensor

def __getitem__(self, indices: slice):
@@ -24,12 +25,16 @@ def __getitem__(self, indices: slice):
first_token_indices_gpu=self.first_token_indices_gpu[indices],
last_token_indices_gpu=self.last_token_indices_gpu[indices],
prompt_lens_cpu=self.prompt_lens_cpu[indices],
seq_lens_cpu=self.seq_lens_cpu[indices],
num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
)

def is_partial_prefill(self):
return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu)

def is_finished(self):
return self.prompt_lens_cpu == self.seq_lens_cpu


@dataclass
class PoolingMetadata:
@@ -53,15 +58,21 @@ def __getitem__(self, indices: slice):
)

def build_pooling_cursor(
self, num_scheduled_tokens: list[int], device: torch.device
self,
num_scheduled_tokens: list[int],
seq_lens_cpu: torch.Tensor,
device: torch.device,
):
self.pooling_cursor = build_pooling_cursor(
num_scheduled_tokens, self.prompt_lens, device
num_scheduled_tokens, seq_lens_cpu, self.prompt_lens, device
)


def build_pooling_cursor(
num_scheduled_tokens: list[int], prompt_lens: torch.Tensor, device: torch.device
num_scheduled_tokens: list[int],
seq_lens_cpu: torch.Tensor,
prompt_lens: torch.Tensor,
device: torch.device,
):
assert len(prompt_lens) == len(num_scheduled_tokens)

@@ -78,5 +89,6 @@ def build_pooling_cursor(
first_token_indices_gpu=cumsum[:n_seq],
last_token_indices_gpu=cumsum[1:] - 1,
prompt_lens_cpu=prompt_lens,
seq_lens_cpu=seq_lens_cpu,
num_scheduled_tokens_cpu=num_scheduled_tokens,
)
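A quick worked example of the is_partial_prefill and is_finished predicates above, with illustrative tensors only (not taken from a real schedule):

```python
import torch

# Per-request prompt lengths, tokens scheduled in the current step,
# and tokens processed so far including this step.
prompt_lens = torch.tensor([55, 56, 57])
num_scheduled = torch.tensor([10, 56, 57])
seq_lens = torch.tensor([30, 56, 57])

# is_partial_prefill(): True because request 0 only scheduled part of its prompt.
print(not torch.all(prompt_lens == num_scheduled))  # True

# is_finished(): per-request flags; only requests 1 and 2 completed prefill.
print(prompt_lens == seq_lens)  # tensor([False,  True,  True])
```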