
[V1][ModelRunner] Support pooling model for v1 engine #1359

Merged (4 commits) on Jun 30, 2025
10 changes: 7 additions & 3 deletions .github/workflows/vllm_ascend_test.yaml
@@ -86,7 +86,7 @@ jobs:
- name: Run codespell check
run: |
CODESPELL_EXCLUDES=('--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**')
CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn')
CODESPELL_IGNORE_WORDS=('-L' 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,assertIn')

codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" "${CODESPELL_IGNORE_WORDS[@]}"
- name: Analysing the code with ruff
@@ -262,11 +262,13 @@ jobs:
pytest -sv tests/e2e/singlecard/test_ilama_lora.py
pytest -sv tests/e2e/singlecard/test_guided_decoding.py
pytest -sv tests/e2e/singlecard/test_camem.py
pytest -sv tests/e2e/singlecard/test_embedding.py
pytest -sv tests/e2e/singlecard/ \
--ignore=tests/e2e/singlecard/test_offline_inference.py \
--ignore=tests/e2e/singlecard/test_ilama_lora.py \
--ignore=tests/e2e/singlecard/test_guided_decoding.py \
--ignore=tests/e2e/singlecard/test_camem.py
--ignore=tests/e2e/singlecard/test_camem.py \
--ignore=tests/e2e/singlecard/test_embedding.py

- name: Run e2e test on V0 engine
if: ${{ github.event_name == 'schedule' }}
@@ -281,14 +283,16 @@
pytest -sv tests/e2e/singlecard/test_guided_decoding.py
pytest -sv tests/e2e/singlecard/test_camem.py
pytest -sv tests/e2e/singlecard/test_prompt_embedding.py
pytest -sv tests/e2e/singlecard/test_embedding.py
pytest -sv tests/e2e/singlecard/ \
--ignore=tests/e2e/singlecard/test_offline_inference.py \
--ignore=tests/e2e/singlecard/test_ilama_lora.py \
--ignore=tests/e2e/singlecard/test_guided_decoding.py \
--ignore=tests/e2e/singlecard/test_camem.py \
--ignore=tests/e2e/singlecard/test_prompt_embedding.py \
--ignore=tests/e2e/singlecard/core/test_ascend_scheduler.py \
--ignore=tests/e2e/singlecard/core/test_ascend_scheduler_e2e.py
--ignore=tests/e2e/singlecard/core/test_ascend_scheduler_e2e.py \
--ignore=tests/e2e/singlecard/test_embedding.py

e2e-4-cards:
needs: [e2e]
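
For running the newly wired-up suite outside CI, a minimal sketch (assuming a working vllm-ascend environment with pytest installed) that mirrors the `pytest -sv tests/e2e/singlecard/test_embedding.py` command added to the workflow above:

# Local-run sketch; equivalent to the pytest invocation added to the workflow.
import sys

import pytest

if __name__ == "__main__":
    # -s: show stdout, -v: verbose test names, same flags as the workflow uses.
    sys.exit(pytest.main(["-sv", "tests/e2e/singlecard/test_embedding.py"]))
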
53 changes: 53 additions & 0 deletions examples/offline_embed.py
@@ -0,0 +1,53 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from https://www.modelscope.cn/models/Qwen/Qwen3-Embedding-0.6B
#

import os

import torch
from vllm import LLM

os.environ["VLLM_USE_MODELSCOPE"] = "True"


def get_detailed_instruct(task_description: str, query: str) -> str:
return f'Instruct: {task_description}\nQuery:{query}'


# Each query must come with a one-sentence instruction that describes the task
task = 'Given a web search query, retrieve relevant passages that answer the query'

queries = [
get_detailed_instruct(task, 'What is the capital of China?'),
get_detailed_instruct(task, 'Explain gravity')
]
# No need to add instruction for retrieval documents
documents = [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]
input_texts = queries + documents

model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")

outputs = model.embed(input_texts)
embeddings = torch.tensor([o.outputs.embedding for o in outputs])
# Calculate the similarity scores between the first two queries and the last two documents
scores = (embeddings[:2] @ embeddings[2:].T)
print(scores.tolist())
# [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]]
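
The raw dot products above are only meaningful as similarity scores if the returned embeddings are already unit-norm. A hedged continuation of the same script that normalizes explicitly before scoring, so the result is a true cosine similarity either way:

# Continuation of the script above; normalize explicitly so the scores are
# cosine similarities even if the embeddings are not unit-norm.
import torch.nn.functional as F

normalized = F.normalize(embeddings, p=2, dim=1)
print((normalized[:2] @ normalized[2:].T).tolist())
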
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -12,3 +12,4 @@ xgrammar
zmq
types-psutil
pytest-cov
sentence_transformers
138 changes: 136 additions & 2 deletions tests/conftest.py
@@ -19,18 +19,23 @@

import contextlib
import gc
from typing import List, Optional, Tuple, TypeVar, Union
from typing import Any, List, Optional, Tuple, TypeVar, Union

import numpy as np
import pytest
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from torch import nn
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BatchEncoding, BatchFeature)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm import LLM, SamplingParams
from vllm.config import TaskOption
from vllm.config import TaskOption, _get_and_verify_dtype
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import is_list_of

from tests.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs,
@@ -45,6 +50,7 @@
from vllm.distributed.parallel_state import ( # noqa E402
destroy_distributed_environment, destroy_model_parallel)

_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_M = TypeVar("_M")

_PromptMultiModalInput = Union[List[_M], List[List[_M]]]
@@ -364,3 +370,131 @@
@pytest.fixture(scope="session")
def ilama_lora_files():
return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")


class HfRunner:

    def get_default_device(self):
        from vllm.platforms import current_platform

        return ("cpu"
                if current_platform.is_cpu() else current_platform.device_type)

    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
        if x is None or isinstance(x, (bool, )):
            return x

        if device is None:
            device = self.device

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: Optional[dict[str, Any]] = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        model_name = maybe_model_redirect(model_name)
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        model_kwargs = model_kwargs if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if (getattr(model, "quantization_method", None) is None
                    and any(p.dtype != self.dtype
                            for p in model.parameters())):
                model = model.to(dtype=self.dtype)

            if (getattr(model, "quantization_method", None) != "bitsandbytes"
                    and len({p.device
                             for p in model.parameters()}) < 2):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401
        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

    def encode(self, prompts: list[str], *args,
               **kwargs) -> list[list[torch.Tensor]]:
        return self.model.encode(prompts, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup_dist_env_and_memory()


@pytest.fixture(scope="session")
def hf_runner():
    return HfRunner
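
The wrap_device helper above is carried over from upstream vLLM's conftest and is not exercised by the embedding test in this PR. A hedged, illustrative sketch of the kind of call it supports (the model name and the specific device strings here are examples, not something this PR adds):

# Illustrative only: move a tokenized batch onto the runner's default device
# ("npu" on Ascend, "cpu" otherwise). transformers.BatchEncoding exposes .to(),
# so the final x.to(device) branch of wrap_device handles it.
runner = HfRunner("Qwen/Qwen3-Embedding-0.6B", is_sentence_transformer=True)
batch = runner.tokenizer(["What is the capital of China?"],
                         return_tensors="pt")
batch = runner.wrap_device(batch)
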
72 changes: 72 additions & 0 deletions tests/e2e/singlecard/test_embedding.py
@@ -0,0 +1,72 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
from collections.abc import Sequence
from typing import Optional

import pytest
from modelscope import snapshot_download # type: ignore[import-untyped]

from tests.conftest import HfRunner
from tests.utils import check_embeddings_close, matryoshka_fy
from vllm_ascend.utils import vllm_version_is


def run_embedding_correctness_test(
hf_model: "HfRunner",
inputs: list[str],
vllm_outputs: Sequence[list[float]],
dimensions: Optional[int] = None,
):
hf_outputs = hf_model.encode(inputs)
if dimensions:
hf_outputs = matryoshka_fy(hf_outputs, dimensions)

check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)


# Dummy test so pytest does not collect nothing and exit with code 5.
def test_dummy():
assert True


@pytest.mark.skipif(vllm_version_is("0.9.1"),
reason="vLLM 0.9.1 does not support embed task for v1")
def test_embed_models_correctness(hf_runner, vllm_runner):
queries = ['What is the capital of China?', 'Explain gravity']

model_name = snapshot_download("Qwen/Qwen3-Embedding-0.6B")
with vllm_runner(
model_name,
task="embed",
enforce_eager=True,
) as vllm_model:
vllm_outputs = vllm_model.encode(queries)

with hf_runner(
model_name,
dtype="float32",
is_sentence_transformer=True,
) as hf_model:
run_embedding_correctness_test(hf_model, queries, vllm_outputs)
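
The check_embeddings_close and matryoshka_fy helpers imported from tests.utils are not part of this diff. A rough sketch of what they are assumed to do (cosine-similarity comparison and Matryoshka-style truncation); the actual implementations in the repository may differ:

# Assumed behaviour of the tests.utils helpers used above; not the actual
# implementations from this repository.
from collections.abc import Sequence

import torch
import torch.nn.functional as F


def check_embeddings_close(*, embeddings_0_lst: Sequence[Sequence[float]],
                           embeddings_1_lst: Sequence[Sequence[float]],
                           name_0: str, name_1: str,
                           tol: float = 1e-3) -> None:
    # Two sets of embeddings match if every pair has cosine similarity
    # within tol of 1.
    assert len(embeddings_0_lst) == len(embeddings_1_lst)
    for prompt_idx, (emb_0, emb_1) in enumerate(
            zip(embeddings_0_lst, embeddings_1_lst)):
        sim = F.cosine_similarity(torch.tensor(emb_0, dtype=torch.float32),
                                  torch.tensor(emb_1, dtype=torch.float32),
                                  dim=0)
        assert sim >= 1 - tol, (
            f"prompt {prompt_idx}: {name_0} and {name_1} disagree "
            f"(cosine similarity {sim:.4f})")


def matryoshka_fy(embeddings, dimensions: int) -> torch.Tensor:
    # Keep only the first `dimensions` components, then re-normalize.
    tensor = torch.as_tensor(embeddings, dtype=torch.float32)
    tensor = tensor[..., :dimensions]
    return F.normalize(tensor, p=2, dim=-1)
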