[New Model]: Support Qwen3 Embedding & Reranker #19260

Merged · 28 commits · Jun 11, 2025
42 changes: 25 additions & 17 deletions docs/models/supported_models.md
@@ -387,18 +387,19 @@ See [this page](./pooling_models.md) for more information on how to use pooling models.

Specified using `--task embed`.

| Architecture                                           | Models              | Example HF Models                                                                                                   | [LoRA][lora-adapter] | [PP][distributed-serving] |
|--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|----------------------|---------------------------|
| `BertModel`                                            | BERT-based          | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc.                                                |                      |                           |
| `Gemma2Model`                                          | Gemma 2-based       | `BAAI/bge-multilingual-gemma2`, etc.                                                                                | ✅︎                   |                           |
| `GritLM`                                               | GritLM              | `parasail-ai/GritLM-7B-vllm`                                                                                        | ✅︎                   | ✅︎                        |
| `GteModel`                                             | Arctic-Embed-2.0-M  | `Snowflake/snowflake-arctic-embed-m-v2.0`                                                                           |                      |                           |
| `GteNewModel`                                          | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc.                                                                           |                      |                           |
| `ModernBertModel`                                      | ModernBERT-based    | `Alibaba-NLP/gte-modernbert-base`, etc.                                                                             |                      |                           |
| `NomicBertModel`                                       | Nomic BERT          | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. |                      |                           |
| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based         | `intfloat/e5-mistral-7b-instruct`, etc.                                                                             | ✅︎                   | ✅︎                        |
| `Qwen2Model`, `Qwen2ForCausalLM`                       | Qwen2-based         | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.              | ✅︎                   | ✅︎                        |
| `Qwen3Model`, `Qwen3ForCausalLM`                       | Qwen3-based         | `Qwen/Qwen3-Embedding-0.6B`, etc.                                                                                   | ✅︎                   | ✅︎                        |
| `RobertaModel`, `RobertaForMaskedLM`                   | RoBERTa-based       | `sentence-transformers/all-roberta-large-v1`, etc.                                                                  |                      |                           |

!!! note
    `ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
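
For a quick offline check of the newly added Qwen3 embedding support, a minimal sketch using the `LLM.embed` pooling API (the prompt and printout are illustrative only):

```python
from vllm import LLM

# Load the Qwen3 embedding model with the embed task.
model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")

# One embedding vector is returned per input prompt.
outputs = model.embed(["What is the capital of China?"])
print(len(outputs[0].outputs.embedding))
```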
@@ -450,12 +450,19 @@ If your model is not in the above list, we will try to automatically convert the model.

Specified using `--task score`.

| Architecture | Models | Example HF Models |
|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. |
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. |

!!! note
    Load the official original `Qwen3 Reranker` with the following command. More information can be found at <gh-file:examples/offline_inference/qwen3_reranker.py>.

    ```bash
    vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
    ```
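
Once the server is up, the OpenAI-compatible `/score` endpoint can be queried directly; a sketch assuming the default local port 8000:

```bash
curl -X POST http://localhost:8000/score \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen3-Reranker-0.6B",
    "text_1": "What is the capital of France?",
    "text_2": ["The capital of France is Paris.",
               "The capital of Brazil is Brasilia."]
  }'
```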
[](){ #supported-mm-models }

## List of Multimodal Language Models
77 changes: 77 additions & 0 deletions examples/offline_inference/qwen3_reranker.py
@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

from vllm import LLM

model_name = "Qwen/Qwen3-Reranker-0.6B"

# What is the difference between the official original version and one
# that has been converted into a sequence classification model?
# Qwen3-Reranker is a language model that performs reranking from the
# logits of the "no" and "yes" tokens.
# That requires computing the logits of all 151,669 vocabulary tokens,
# which is extremely inefficient and incompatible with the vLLM score API.
# A method for converting the original model into a sequence classification
# model was proposed; see
# https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
# Models converted offline this way are not only more efficient and
# compatible with the vLLM score API, but also have more concise init
# parameters, for example:
# model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")

# If you want to load the official original version, the init parameters are
# as follows.

model = LLM(
    model=model_name,
    task="score",
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    },
)

# Why hf_overrides are needed for the official original version:
# vLLM converts it to Qwen3ForSequenceClassification when loaded, for
# better performance.
# - First, `"architectures": ["Qwen3ForSequenceClassification"]` manually
#   routes loading to Qwen3ForSequenceClassification.
# - Then, `"classifier_from_token": ["no", "yes"]` extracts the vectors
#   corresponding to those tokens from lm_head.
# - Third, the two vectors are combined into a single classifier vector;
#   this conversion logic is enabled by `"is_original_qwen3_reranker": True`.
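# Why a single vector suffices (a sketch of the math, not runtime code):
# with hidden state h and lm_head rows E_yes and E_no,
#   softmax([h @ E_no, h @ E_yes])[1] == sigmoid(h @ (E_yes - E_no)),
# so the "no"/"yes" pair collapses into one difference vector for the
# converted sequence classification head.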

# Use the query_template and document_template below to format queries and
# documents; this yields better reranking results.

prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"

query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
document_template = "<Document>: {doc}{suffix}"

if __name__ == "__main__":
    instruction = (
        "Given a web search query, retrieve relevant passages that answer the query"
    )

    queries = [
        "What is the capital of China?",
        "Explain gravity",
    ]

    documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    ]

    queries = [
        query_template.format(prefix=prefix, instruction=instruction, query=query)
        for query in queries
    ]
    documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]

    outputs = model.score(queries, documents)

    print([output.outputs.score for output in outputs])
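
# Both documents above are relevant to their paired queries, so both printed
# scores should be close to 1.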
9 changes: 9 additions & 0 deletions tests/models/language/pooling/test_gte.py
@@ -45,6 +45,15 @@
EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
architecture="ModernBertModel",
enable_test=True),
########## Qwen3ForCausalLM
EmbedModelInfo("Qwen/Qwen3-Embedding-0.6B",
architecture="Qwen3ForCausalLM",
dtype="float32",
enable_test=True),
EmbedModelInfo("Qwen/Qwen3-Embedding-4B",
architecture="Qwen3ForCausalLM",
dtype="float32",
enable_test=False),
]


87 changes: 87 additions & 0 deletions tests/models/language/pooling/test_qwen3_reranker.py
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

model_name = "Qwen/Qwen3-Reranker-4B"

text_1 = "What is the capital of France?"
texts_2 = [
    "The capital of Brazil is Brasilia.",
    "The capital of France is Paris.",
]


def vllm_reranker(model_name):
    from vllm import LLM

    model = LLM(model=model_name,
                task="score",
                hf_overrides={
                    "architectures": ["Qwen3ForSequenceClassification"],
                    "classifier_from_token": ["no", "yes"],
                    "is_original_qwen3_reranker": True,
                },
                dtype="float32")

    # text_1 and texts_2 are defined at module level.
    outputs = model.score(text_1, texts_2)

    return [output.outputs.score for output in outputs]


def hf_reranker(model_name):
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
    model = AutoModelForCausalLM.from_pretrained(model_name).eval()

    token_false_id = tokenizer.convert_tokens_to_ids("no")
    token_true_id = tokenizer.convert_tokens_to_ids("yes")

    max_length = 8192

    def process_inputs(pairs):
        inputs = tokenizer(pairs,
                           padding=False,
                           truncation='longest_first',
                           return_attention_mask=False,
                           max_length=max_length)
        inputs = tokenizer.pad(inputs,
                               padding=True,
                               return_tensors="pt",
                               max_length=max_length)
        for key in inputs:
            inputs[key] = inputs[key].to(model.device)
        return inputs

    @torch.no_grad()
    def compute_logits(inputs, **kwargs):
        # The last-token logits of "no" and "yes" act as the two class
        # logits; softmax over them yields P(yes), the relevance score.
        batch_scores = model(**inputs).logits[:, -1, :]
        true_vector = batch_scores[:, token_true_id]
        false_vector = batch_scores[:, token_false_id]
        batch_scores = torch.stack([false_vector, true_vector], dim=1)
        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
        scores = batch_scores[:, 1].exp().tolist()
        return scores

    pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
    inputs = process_inputs(pairs)
    scores = compute_logits(inputs)

    return scores


@pytest.mark.parametrize("model_name", [model_name])
def test_model(model_name):
    hf_outputs = hf_reranker(model_name)
    vllm_outputs = vllm_reranker(model_name)

    assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
    assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
73 changes: 73 additions & 0 deletions tests/models/language/pooling/test_qwen3_reranker_seq_cls.py
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

model_name = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"

text_1 = "What is the capital of France?"
texts_2 = [
    "The capital of Brazil is Brasilia.",
    "The capital of France is Paris.",
]


def vllm_reranker(model_name):
    from vllm import LLM

    model = LLM(model=model_name, task="score")
    outputs = model.score(text_1, texts_2)

    return [output.outputs.score for output in outputs]


def hf_reranker(model_name):
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    # The converted checkpoint is a single-label sequence classification
    # model, so the HF reference scores with a sigmoid over its one logit
    # instead of the "no"/"yes" token logits of the original causal LM.
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
    model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()

    max_length = 8192

    pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
    inputs = tokenizer(pairs,
                       padding=True,
                       truncation='longest_first',
                       return_tensors="pt",
                       max_length=max_length)
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits.squeeze(-1)
        scores = logits.sigmoid().tolist()

    return scores


@pytest.mark.parametrize("model_name", [model_name])
def test_model(model_name):
    hf_outputs = hf_reranker(model_name)
    vllm_outputs = vllm_reranker(model_name)

    assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
    assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
1 change: 1 addition & 0 deletions tests/models/registry.py
@@ -238,6 +238,7 @@ def check_available_online(
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
v0_only=True),