Skip to content

Commit 71b1df4

Browse files
committed
+ GteNewForSequenceClassification support
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent 50cd102 commit 71b1df4

File tree

5 files changed

+126
-18
lines changed

5 files changed

+126
-18
lines changed

docs/models/supported_models.md

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -453,19 +453,24 @@ If your model is not in the above list, we will try to automatically convert the
453453

454454
Specified using `--task score`.
455455

456-
| Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
457-
|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|-----------------------|
458-
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
459-
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | |
460-
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
461-
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |
456+
| Architecture | Models | Example HF Models | [V1](gh-issue:8779) |
457+
|---------------------------------------|---------------------|--------------------------------------------------------------------------------------|---------------------|
458+
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | |
459+
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | |
460+
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | |
461+
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | |
462+
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | |
463+
464+
!!! note
465+
The second-generation GTE model (mGTE-TRM) is named `NewModel`. The name `NewModel` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
462466

463467
!!! note
464468
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: <gh-file:examples/offline_inference/qwen3_reranker.py>.
465469

466470
```bash
467471
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
468472
```
473+
469474
[](){ #supported-mm-models }
470475

471476
## List of Multimodal Language Models

tests/models/language/pooling/test_gte.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959
]
6060

6161
RERANK_MODELS = [
62+
RerankModelInfo("Alibaba-NLP/gte-multilingual-reranker-base",
63+
architecture="GteNewForSequenceClassification",
64+
enable_test=True),
6265
RerankModelInfo("Alibaba-NLP/gte-reranker-modernbert-base",
6366
architecture="ModernBertForSequenceClassification",
6467
enable_test=False),
@@ -93,10 +96,30 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
9396
@pytest.mark.parametrize("model_info", RERANK_MODELS)
9497
def test_rerank_models_mteb(hf_runner, vllm_runner,
9598
model_info: RerankModelInfo) -> None:
96-
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
99+
100+
vllm_extra_kwargs: dict[str, Any] = {}
101+
if model_info.architecture == "GteNewForSequenceClassification":
102+
vllm_extra_kwargs["hf_overrides"] = {
103+
"architectures": ["GteNewForSequenceClassification"]
104+
}
105+
106+
mteb_test_rerank_models(hf_runner,
107+
vllm_runner,
108+
model_info,
109+
vllm_extra_kwargs=vllm_extra_kwargs)
97110

98111

99112
@pytest.mark.parametrize("model_info", RERANK_MODELS)
100113
def test_rerank_models_correctness(hf_runner, vllm_runner,
101114
model_info: RerankModelInfo) -> None:
102-
ping_pong_test_score_models(hf_runner, vllm_runner, model_info)
115+
116+
vllm_extra_kwargs: dict[str, Any] = {}
117+
if model_info.architecture == "GteNewForSequenceClassification":
118+
vllm_extra_kwargs["hf_overrides"] = {
119+
"architectures": ["GteNewForSequenceClassification"]
120+
}
121+
122+
ping_pong_test_score_models(hf_runner,
123+
vllm_runner,
124+
model_info,
125+
vllm_extra_kwargs=vllm_extra_kwargs)

tests/models/registry.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,13 @@ def check_available_online(
301301
_CROSS_ENCODER_EXAMPLE_MODELS = {
302302
# [Text-only]
303303
"BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501
304+
"GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base", # noqa: E501
305+
hf_overrides={
306+
"architectures": ["GteNewForSequenceClassification"] # noqa: E501
307+
}),
308+
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501
304309
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501
305310
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501
306-
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base"), # noqa: E501
307311
}
308312

309313
_MULTIMODAL_EXAMPLE_MODELS = {

vllm/model_executor/models/bert_with_rope.py

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,21 @@
2020
QKVParallelLinear,
2121
ReplicatedLinear,
2222
RowParallelLinear)
23+
from vllm.model_executor.layers.pooler import ClassifierPooler
2324
from vllm.model_executor.layers.quantization import QuantizationConfig
2425
from vllm.model_executor.layers.rotary_embedding import get_rope
2526
from vllm.model_executor.layers.vocab_parallel_embedding import (
2627
VocabParallelEmbedding)
2728
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
2829
from vllm.model_executor.models import SupportsV0Only
29-
from vllm.model_executor.models.interfaces import SupportsQuant
30-
from vllm.model_executor.models.utils import WeightsMapper
31-
from vllm.sequence import IntermediateTensors
30+
from vllm.model_executor.models.bert import BertPooler
31+
from vllm.model_executor.models.interfaces import (SupportsCrossEncoding,
32+
SupportsQuant)
33+
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
34+
from vllm.model_executor.pooling_metadata import PoolingMetadata
35+
from vllm.sequence import IntermediateTensors, PoolerOutput
36+
from vllm.transformers_utils.config import (
37+
get_cross_encoder_activation_function)
3238

3339
logger = init_logger(__name__)
3440

@@ -405,16 +411,23 @@ def forward(
405411
class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
406412
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
407413

408-
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
414+
def __init__(self,
415+
*,
416+
vllm_config: VllmConfig,
417+
prefix: str = "",
418+
add_pooling_layer=False):
409419
super().__init__()
410420
self.vllm_config = vllm_config
421+
self.add_pooling_layer = add_pooling_layer
411422
self.config = self.config_verify(vllm_config)
412423
self.embeddings = BertWithRopeEmbedding(self.config)
413424
self.encoder = BertWithRopeEncoder(
414425
vllm_config=vllm_config,
415426
bias=getattr(self.config, "bias", True),
416427
rotary_kwargs=self.config.rotary_kwargs,
417428
prefix=f"{prefix}.encoder")
429+
if self.add_pooling_layer:
430+
self.pooler = BertPooler(self.config)
418431

419432
def config_verify(self, vllm_config):
420433
raise NotImplementedError
@@ -450,7 +463,7 @@ def load_weights(self, weights: Iterable[tuple[str,
450463
params_dict = dict(self.named_parameters())
451464
loaded_params: set[str] = set()
452465
for name, loaded_weight in weights:
453-
if "pooler" in name:
466+
if not self.add_pooling_layer and "pooler" in name:
454467
continue
455468
for (param_name, weight_name, shard_id) in stacked_params_mapping:
456469
if weight_name not in name:
@@ -591,8 +604,8 @@ class GteNewModel(BertWithRope):
591604
"attention.o_proj": "attn.out_proj",
592605
})
593606

594-
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
595-
super().__init__(vllm_config=vllm_config, prefix=prefix)
607+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
608+
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
596609

597610
# GteNewModel only gate_up_proj does not have bias.
598611
# Hack method learned from vllm/model_executor/models/glm.py
@@ -762,3 +775,65 @@ def load_weights(self, weights: Iterable[tuple[str,
762775
torch.Tensor]]) -> set[str]:
763776
weights = self.jina_merge_lora_weights(weights)
764777
return super().load_weights(weights)
778+
779+
780+
class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding,
781+
SupportsQuant):
782+
783+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
784+
super().__init__()
785+
config = vllm_config.model_config.hf_config
786+
787+
self.default_activation_function = \
788+
get_cross_encoder_activation_function(config)
789+
790+
self.num_labels = config.num_labels
791+
self.new = GteNewModel(vllm_config=vllm_config,
792+
prefix=maybe_prefix(prefix, "new"),
793+
add_pooling_layer=True)
794+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
795+
self._pooler = ClassifierPooler(vllm_config.model_config,
796+
self.classifier, self.new.pooler)
797+
798+
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
799+
800+
self_weights = []
801+
802+
def weight_filter():
803+
for name, weight in weights:
804+
if name.startswith("new."):
805+
yield (name[len("new."):], weight)
806+
else:
807+
self_weights.append((name, weight))
808+
809+
self.new.load_weights(weight_filter())
810+
811+
params_dict = dict(self.named_parameters())
812+
813+
for name, loaded_weight in self_weights:
814+
if name.startswith("classifier"):
815+
param = params_dict[name]
816+
weight_loader = getattr(param, "weight_loader",
817+
default_weight_loader)
818+
weight_loader(param, loaded_weight)
819+
820+
def pooler(
821+
self,
822+
hidden_states: torch.Tensor,
823+
pooling_metadata: PoolingMetadata,
824+
) -> Optional[PoolerOutput]:
825+
return self._pooler(hidden_states, pooling_metadata)
826+
827+
def forward(
828+
self,
829+
input_ids: Optional[torch.Tensor],
830+
positions: torch.Tensor,
831+
intermediate_tensors: Optional[IntermediateTensors] = None,
832+
inputs_embeds: Optional[torch.Tensor] = None,
833+
token_type_ids: Optional[torch.Tensor] = None,
834+
) -> torch.Tensor:
835+
return self.new(input_ids=input_ids,
836+
positions=positions,
837+
inputs_embeds=inputs_embeds,
838+
intermediate_tensors=intermediate_tensors,
839+
token_type_ids=token_type_ids)

vllm/model_executor/models/registry.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,10 @@
172172
"ModernBertForSequenceClassification": ("modernbert",
173173
"ModernBertForSequenceClassification"),
174174
# [Auto-converted (see adapters.py)]
175-
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"), # noqa: E501
176-
"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
177175
"GemmaForSequenceClassification": ("gemma", "GemmaForCausalLM"),
176+
"GteNewForSequenceClassification": ("bert_with_rope", "GteNewForSequenceClassification"), # noqa: E501
177+
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
178+
"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
178179
}
179180

180181
_MULTIMODAL_MODELS = {

0 commit comments

Comments
 (0)