[V1][ModelRunner] Support pooling model for v1 engine #1359


Merged: 4 commits merged into vllm-project:main on Jun 30, 2025

Conversation

@Potabk (Contributor) commented Jun 23, 2025

What this PR does / why we need it?

This PR adds support for the v1 pooling task while changing as little existing code as possible. Note that I copied vllm.v1.worker.gpu_input_batch down into vllm-ascend: the upstream interface changes frequently, so keeping a local copy decouples us from those changes.

Does this PR introduce any user-facing change?

How was this patch tested?

CI passed with the newly added and existing tests. I also ran a simple local test, adapted from https://www.modelscope.cn/models/Qwen/Qwen3-Embedding-0.6B, shown below:

import os

import torch
from vllm import LLM


os.environ["VLLM_USE_MODELSCOPE"]="True"

def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery:{query}'

# Each query must come with a one-sentence instruction that describes the task
task = 'Given a web search query, retrieve relevant passages that answer the query'

queries = [
    get_detailed_instruct(task, 'What is the capital of China?'),
    get_detailed_instruct(task, 'Explain gravity')
]
# No need to add instruction for retrieval documents
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]
input_texts = queries + documents

model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")

outputs = model.embed(input_texts)
embeddings = torch.tensor([o.outputs.embedding for o in outputs])
scores = (embeddings[:2] @ embeddings[2:].T)
print(scores.tolist())
# [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]]
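# Illustrative sanity check (an addition for clarity, not part of the original
# test script): the pooler runs with normalize=True, so every embedding has
# unit norm and the matrix product above is already the cosine similarity
# between queries and documents. The same values can be recomputed explicitly:
import torch.nn.functional as F
cosine = F.cosine_similarity(embeddings[:2].unsqueeze(1),
                             embeddings[2:].unsqueeze(0), dim=-1)
assert torch.allclose(cosine, scores, atol=1e-4)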

and the result looks good:

VLLM_USE_V1=1 python offline_embed.py 
INFO 06-23 02:00:16 [__init__.py:39] Available plugins for group vllm.platform_plugins:
INFO 06-23 02:00:16 [__init__.py:41] - ascend -> vllm_ascend:register
INFO 06-23 02:00:16 [__init__.py:44] All plugins in this group will be loaded. Set `VLLM_PLUGINS` to control which plugins to load.
INFO 06-23 02:00:16 [__init__.py:235] Platform plugin ascend is activated
WARNING 06-23 02:00:20 [_custom_ops.py:21] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")
INFO 06-23 02:00:23 [importing.py:63] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 06-23 02:00:24 [registry.py:405] Model architecture DeepSeekMTPModel is already registered, and will be overwritten by the new model class vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP.
WARNING 06-23 02:00:24 [registry.py:405] Model architecture Qwen2VLForConditionalGeneration is already registered, and will be overwritten by the new model class vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration.
WARNING 06-23 02:00:24 [registry.py:405] Model architecture Qwen2_5_VLForConditionalGeneration is already registered, and will be overwritten by the new model class vllm_ascend.models.qwen2_5_vl:AscendQwen2_5_VLForConditionalGeneration.
WARNING 06-23 02:00:24 [registry.py:405] Model architecture DeepseekV2ForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM.
WARNING 06-23 02:00:24 [registry.py:405] Model architecture DeepseekV3ForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM.
WARNING 06-23 02:00:24 [registry.py:405] Model architecture Qwen3MoeForCausalLM is already registered, and will be overwritten by the new model class vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM.
Downloading Model from https://www.modelscope.cn to directory: /root/.cache/modelscope/hub/models/Qwen/Qwen3-Embedding-0.6B
2025-06-23 02:00:26,409 - modelscope - INFO - Target directory already exists, skipping creation.
Downloading Model from https://www.modelscope.cn to directory: /root/.cache/modelscope/hub/models/Qwen/Qwen3-Embedding-0.6B
2025-06-23 02:00:27,659 - modelscope - INFO - Target directory already exists, skipping creation.
INFO 06-23 02:00:47 [config.py:484] Found sentence-transformers modules configuration.
INFO 06-23 02:00:48 [config.py:504] Found pooling configuration.
INFO 06-23 02:00:58 [config.py:1444] Using max model len 32768
INFO 06-23 02:00:58 [arg_utils.py:1565] (Enabling) chunked prefill by default
INFO 06-23 02:00:58 [arg_utils.py:1568] (Enabling) prefix caching by default
INFO 06-23 02:00:58 [config.py:2188] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 06-23 02:00:58 [platform.py:177] PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode
INFO 06-23 02:00:58 [utils.py:297] Calculated maximum supported batch sizes for ACL graph: 66
INFO 06-23 02:00:58 [utils.py:312] Adjusted ACL graph batch sizes for Qwen3ForCausalLM model (layers: 28): 67 → 66 sizes
Downloading Model from https://www.modelscope.cn to directory: /root/.cache/modelscope/hub/models/Qwen/Qwen3-Embedding-0.6B
2025-06-23 02:00:59,655 - modelscope - INFO - Target directory already exists, skipping creation.
Downloading Model from https://www.modelscope.cn to directory: /root/.cache/modelscope/hub/models/Qwen/Qwen3-Embedding-0.6B
2025-06-23 02:01:01,611 - modelscope - INFO - Target directory already exists, skipping creation.
INFO 06-23 02:01:01 [core.py:459] Waiting for init message from front-end.
INFO 06-23 02:01:01 [platform.py:177] PIECEWISE compilation enabled on NPU. use_inductor not supported - using only ACL Graph mode
INFO 06-23 02:01:01 [utils.py:297] Calculated maximum supported batch sizes for ACL graph: 66
INFO 06-23 02:01:01 [utils.py:323] No adjustment needed for ACL graph batch sizes: Qwen3ForCausalLM model (layers: 28) with 66 sizes
INFO 06-23 02:01:01 [core.py:69] Initializing a V1 LLM engine (v0.9.1) with config: model='Qwen/Qwen3-Embedding-0.6B', speculative_config=None, tokenizer='Qwen/Qwen3-Embedding-0.6B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=npu, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=Qwen/Qwen3-Embedding-0.6B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, pooler_config=PoolerConfig(pooling_type='LAST', normalize=True, softmax=None, step_tag_id=None, returned_token_ids=None), compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":["all"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output","vllm.unified_ascend_attention_with_output","vllm.unified_ascend_attention_with_output"],"use_inductor":false,"compile_sizes":[],"inductor_compile_config":{},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}
WARNING 06-23 02:01:02 [utils.py:2753] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes not implemented in <vllm_ascend.worker.worker_v1.NPUWorker object at 0xfffd2b88af80>
INFO 06-23 02:01:08 [parallel_state.py:1072] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
INFO 06-23 02:01:08 [model_runner_v1.py:1925] Starting to load model Qwen/Qwen3-Embedding-0.6B...
Downloading Model from https://www.modelscope.cn to directory: /root/.cache/modelscope/hub/models/Qwen/Qwen3-Embedding-0.6B
2025-06-23 02:01:10,861 - modelscope - INFO - Target directory already exists, skipping creation.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.62it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.61it/s]

INFO 06-23 02:01:11 [default_loader.py:272] Loading weights took 0.45 seconds
INFO 06-23 02:01:12 [model_runner_v1.py:1944] Loading model weights took 1.1177 GB
INFO 06-23 02:01:20 [backends.py:508] Using cache directory: /root/.cache/vllm/torch_compile_cache/4ef0a924ff/rank_0_0/backbone for vLLM's torch.compile
INFO 06-23 02:01:20 [backends.py:519] Dynamo bytecode transform time: 7.24 s
INFO 06-23 02:01:22 [backends.py:193] Compiling a graph for general shape takes 1.76 s
INFO 06-23 02:01:30 [monitor.py:34] torch.compile takes 8.99 s in total
INFO 06-23 02:01:31 [kv_cache_utils.py:716] GPU KV cache size: 495,360 tokens
INFO 06-23 02:01:31 [kv_cache_utils.py:720] Maximum concurrency for 32,768 tokens per request: 15.12x
INFO 06-23 02:02:20 [model_runner_v1.py:2177] Graph capturing finished in 48 secs, took 0.20 GiB
INFO 06-23 02:02:20 [core.py:172] init engine (profile, create kv cache, warmup model) took 68.14 seconds
Downloading Model from https://www.modelscope.cn to directory: /root/.cache/modelscope/hub/models/Qwen/Qwen3-Embedding-0.6B
2025-06-23 02:02:21,510 - modelscope - INFO - Target directory already exists, skipping creation.
Adding requests: 100%|█████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 219.59it/s]
Processed prompts: 100%|████████████████████| 4/4 [00:00<00:00, 26.31it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
[[0.7653326988220215, 0.14090833067893982], [0.1335897445678711, 0.596314013004303]]


codecov bot commented Jun 23, 2025

Codecov Report

Attention: Patch coverage is 75.81967% with 118 lines in your changes missing coverage. Please review.

Project coverage is 30.89%. Comparing base (c30ddb8) to head (4296599).
Report is 27 commits behind head on main.

Files with missing lines Patch % Lines
vllm_ascend/worker/npu_input_batch.py 77.84% 76 Missing ⚠️
tests/conftest.py 28.07% 41 Missing ⚠️
tests/ut/worker/test_input_batch.py 98.73% 1 Missing ⚠️

❌ Your patch check has failed because the patch coverage (75.81%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1359      +/-   ##
==========================================
+ Coverage   27.39%   30.89%   +3.49%     
==========================================
  Files          56       59       +3     
  Lines        6191     6756     +565     
==========================================
+ Hits         1696     2087     +391     
- Misses       4495     4669     +174     
Flag Coverage Δ
unittests 30.89% <75.81%> (+3.49%) ⬆️


@Potabk Potabk changed the title [V1][Worker][ModelRunner][WIP] Support pooling model for v1 engine [V1][Worker][ModelRunner] Support pooling model for v1 engine Jun 23, 2025
@Potabk (Contributor, Author) commented Jun 23, 2025

@Yikun @wangxiyuan I think this should be ready for review

@wangxiyuan (Collaborator) left a comment

Please add the UT for npu_input_batch as well.

@@ -80,6 +80,14 @@
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer

if vllm_version_is("0.9.1"):
Collaborator commented:

Is this necessary? Can't vllm_ascend.worker.npu_input_batch support vLLM 0.9.1?

@Potabk (Contributor, Author) commented Jun 24, 2025:

input_batch in vLLM v0.9.1 has already been verified and will not see breaking changes. If we used the vllm-ascend customized interface there as well, we would also need version-compatibility handling for vLLM v0.9.1.

@Potabk (Contributor, Author) commented:

For example, from vllm.v1.pool.metadata import PoolingMetadata in input_batch is one place where version-compatibility handling would be needed.
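A minimal sketch of the kind of version gate being discussed, based on the vllm_version_is("0.9.1") check visible in the diff above (illustrative only; the helper's import path and the exact gated imports are assumptions, not the PR's actual code):

from vllm_ascend.utils import vllm_version_is  # helper path assumed

if vllm_version_is("0.9.1"):
    # vLLM v0.9.1 is already verified and its InputBatch will not break,
    # so keep importing it from upstream on that version.
    from vllm.v1.worker.gpu_input_batch import InputBatch
else:
    # Newer vLLM: use the local copy, which can rely on interfaces such as
    # vllm.v1.pool.metadata.PoolingMetadata that v0.9.1 does not provide.
    from vllm_ascend.worker.npu_input_batch import InputBatch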

@Potabk Potabk changed the title [V1][Worker][ModelRunner] Support pooling model for v1 engine [V1][ModelRunner] Support pooling model for v1 engine Jun 25, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Potabk and others added 4 commits June 25, 2025 22:09
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: wangli <858794774@qq.com>
@wangxiyuan (Collaborator) commented:

This has been ready for review for a long time. I'll merge it first.

@wangxiyuan wangxiyuan merged commit 5f8241c into vllm-project:main Jun 30, 2025
24 checks passed