[Model][1/N] Automatic conversion of CrossEncoding model #20012

Merged
merged 13 commits on Jun 27, 2025
11 changes: 7 additions & 4 deletions tests/models/language/pooling/mteb_utils.py
@@ -43,7 +43,7 @@ def encode(
# issues by randomizing the order.
r = self.rng.permutation(len(sentences))
sentences = [sentences[i] for i in r]
outputs = self.model.encode(sentences, use_tqdm=False)
outputs = self.model.embed(sentences, use_tqdm=False)
embeds = np.array(outputs)
embeds = embeds[np.argsort(r)]
return embeds
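For reference, the shuffle-and-restore pattern used by `encode` above can be reproduced in isolation. This is a minimal sketch; `fake_embed` is a purely illustrative stand-in for `self.model.embed`, not part of vLLM.

```python
import numpy as np

# Illustrative stand-in for self.model.embed(...): one vector per sentence.
def fake_embed(batch):
    return [np.array([len(s)], dtype=np.float32) for s in batch]

rng = np.random.default_rng(0)
sentences = ["a", "bb", "ccc", "dddd"]

r = rng.permutation(len(sentences))      # shuffle to dodge order-dependent issues
shuffled = [sentences[i] for i in r]
embeds = np.array(fake_embed(shuffled))
embeds = embeds[np.argsort(r)]           # undo the shuffle: row i matches sentences[i]
assert [float(e[0]) for e in embeds] == [1.0, 2.0, 3.0, 4.0]
```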
@@ -250,16 +250,19 @@ def mteb_test_rerank_models(hf_runner,
with vllm_runner(model_info.name,
task="score",
max_model_len=None,
max_num_seqs=8,
**vllm_extra_kwargs) as vllm_model:

model_config = vllm_model.model.llm_engine.model_config

if model_info.architecture:
assert (model_info.architecture
in vllm_model.model.llm_engine.model_config.architectures)
assert (model_info.architecture in model_config.architectures)
assert model_config.hf_config.num_labels == 1

vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model),
tasks=MTEB_RERANK_TASKS,
languages=MTEB_RERANK_LANGS)
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
vllm_dtype = model_config.dtype

with hf_runner(model_info.name, is_cross_encoder=True,
dtype="float32") as hf_model:
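The new `num_labels == 1` assertion reflects what the converted cross-encoders expose: a single relevance score per query/document pair. A hedged usage sketch of the `score` task follows; the checkpoint name is only an example and is not something this PR pins down.

```python
from vllm import LLM

# Example reranker checkpoint; any single-label cross-encoder should behave similarly.
llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
outputs = llm.score(
    "What is the capital of France?",
    ["Paris is the capital of France.",
     "The giant panda is a bear native to China."],
)
print([o.outputs.score for o in outputs])  # one float per candidate document
```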
29 changes: 28 additions & 1 deletion vllm/config.py
@@ -569,6 +569,10 @@ def __post_init__(self) -> None:
else:
self.truncation_side = "right"

model_info, arch = self.registry.inspect_model_cls(self.architectures)
self._model_info = model_info
self._architecture = arch

self.pooler_config = self._init_pooler_config()

self.dtype = _get_and_verify_dtype(
@@ -660,8 +664,18 @@ def registry(self):

@property
def architectures(self) -> list[str]:
# Architectures declared in the Hugging Face model config.
return getattr(self.hf_config, "architectures", [])

@property
def architecture(self) -> str:
# The architecture vLLM actually resolved and uses.
return self._architecture

@property
def model_info(self) -> dict[str, Any]:
return self._model_info

def maybe_pull_model_tokenizer_for_s3(self, model: str,
tokenizer: str) -> None:
"""Pull model/tokenizer from S3 to temporary directory when needed.
@@ -4450,6 +4464,9 @@ def with_hf_config(
def __post_init__(self):
"""Verify configs are valid & consistent with each other.
"""

self.try_verify_and_update_config()

if self.model_config is not None:
self.model_config.verify_async_output_proc(self.parallel_config,
self.speculative_config,
@@ -4694,11 +4711,21 @@ def _set_cudagraph_sizes(self):
batch_size_capture_list)

def recalculate_max_model_len(self, max_model_len: int):
# Only intended to be called from try_verify_and_update_config.
model_config = self.model_config
max_model_len = model_config.get_and_verify_max_len(max_model_len)
self.model_config.max_model_len = max_model_len
self.scheduler_config.max_model_len = max_model_len
self.compute_hash()

def try_verify_and_update_config(self):
architecture = getattr(self.model_config, "architecture", None)
if architecture is None:
return

from vllm.model_executor.models.config import MODELS_CONFIG_MAP
cls = MODELS_CONFIG_MAP.get(architecture, None)
if cls is not None:
cls.verify_and_update_config(self)

def __str__(self):
return (
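For orientation, `try_verify_and_update_config` dispatches on the resolved architecture to a per-model entry in `MODELS_CONFIG_MAP` (defined in `vllm/model_executor/models/config.py`, which is not shown in this diff). The sketch below is an assumption about the shape of such an entry, mirroring the per-model `config_verify()` hooks removed from `bert_with_rope.py` further down; it is not the code this PR actually adds.

```python
# Hypothetical MODELS_CONFIG_MAP entry; names and details are illustrative.
class NomicBertModelConfig:

    @staticmethod
    def verify_and_update_config(vllm_config) -> None:
        hf_config = vllm_config.model_config.hf_config
        # Normalize HF-specific fields into what the vLLM model expects,
        # e.g. activation names or rotary-embedding kwargs.
        if getattr(hf_config, "activation_function", None) == "swiglu":
            hf_config.hidden_act = "silu"
        # Clamp the context length, as the old in-model hook did.
        max_trained = getattr(hf_config, "max_trained_positions", 2048)
        vllm_config.recalculate_max_model_len(
            min(vllm_config.model_config.max_model_len, max_trained))


MODELS_CONFIG_MAP = {"NomicBertModel": NomicBertModelConfig}
```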
149 changes: 1 addition & 148 deletions vllm/model_executor/models/bert_with_rope.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from copy import deepcopy
from typing import Optional

import torch
@@ -12,7 +11,6 @@
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
get_act_fn)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
Expand All @@ -30,8 +28,6 @@
from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors

logger = init_logger(__name__)


class BertWithRopeEmbedding(nn.Module):

@@ -408,17 +404,14 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.vllm_config = vllm_config
self.config = self.config_verify(vllm_config)
self.config = vllm_config.model_config.hf_config
self.embeddings = BertWithRopeEmbedding(self.config)
self.encoder = BertWithRopeEncoder(
vllm_config=vllm_config,
bias=getattr(self.config, "bias", True),
rotary_kwargs=self.config.rotary_kwargs,
prefix=f"{prefix}.encoder")

def config_verify(self, vllm_config):
raise NotImplementedError

def forward(
self,
input_ids: Optional[torch.Tensor],
@@ -490,95 +483,6 @@ class NomicBertModel(BertWithRope):
"norm2": "mlp_ln",
})

def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config

assert config.__class__.__name__ == "NomicBertConfig"
assert config.activation_function in ["swiglu", "gelu"]
config.position_embedding_type = getattr(config,
"position_embedding_type",
"rope")

if config.activation_function == "swiglu":
config.hidden_act = "silu"
else:
config.hidden_act = config.activation_function

assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
config.qkv_proj_bias)
config.bias = config.qkv_proj_bias

assert config.rotary_emb_scale_base is None
assert not config.rotary_emb_interleaved

config.layer_norm_eps = config.layer_norm_epsilon
config.intermediate_size = config.n_inner
config.hidden_size = config.n_embd
config.num_hidden_layers = config.n_layer

head_dim = config.hidden_size // config.num_attention_heads
rotary_emb_dim = head_dim * config.rotary_emb_fraction
max_trained_positions = getattr(config, "max_trained_positions", 2048)
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": rotary_emb_dim,
"max_position": max_trained_positions,
"base": getattr(config, "rope_theta", config.rotary_emb_base),
"rope_scaling": getattr(config, "rope_scaling", None)
}

# we ignore config.rotary_scaling_factor so that for datasets shorter
# than max_trained_positions 2048, the results are consistent
# with SentenceTransformer.
# The context extension uses vllm style rope_theta and rope_scaling.
# See #17785 #18755
if (not vllm_config.model_config.hf_overrides
and vllm_config.model_config.original_max_model_len is None):
# Default
# Reset max_model_len to max_trained_positions.
# nomic-embed-text-v2-moe the length is set to 512
# by sentence_bert_config.json.
max_model_len_before = vllm_config.model_config.max_model_len
max_model_len = min(vllm_config.model_config.max_model_len,
max_trained_positions)

vllm_config.recalculate_max_model_len(max_model_len)
logger.warning(
"Nomic context extension is disabled. "
"Changing max_model_len from %s to %s. "
"To enable context extension, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
max_model_len_before, vllm_config.model_config.max_model_len)
else:
# We need to re-verify max_model_len to avoid lengths
# greater than position_embedding.
model_config = vllm_config.model_config
hf_text_config = model_config.hf_text_config

if isinstance(model_config.hf_overrides, dict):
# hf_overrides_kw
max_model_len = model_config.hf_overrides.get(
"max_model_len", vllm_config.model_config.max_model_len)
else:
# hf_overrides_fn
# This might be overridden by sentence_bert_config.json.
max_model_len = vllm_config.model_config.max_model_len

# reset hf_text_config for recalculate_max_model_len.
if hasattr(hf_text_config, "max_model_len"):
delattr(hf_text_config, "max_model_len")
hf_text_config.max_position_embeddings = max_trained_positions
hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]

# The priority of sentence_bert_config.json is higher
# than max_position_embeddings
encoder_config = deepcopy(model_config.encoder_config)
encoder_config.pop("max_seq_length", None)
model_config.encoder_config = encoder_config

vllm_config.recalculate_max_model_len(max_model_len)
return config
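Each of the `config_verify` hooks being removed in this file builds the same `rotary_kwargs` dict from the head size and rotary fraction. A worked numeric illustration of that recipe follows; the values are examples only and are not taken from any particular checkpoint.

```python
# Illustrative numbers only: hidden_size=768, 12 heads, rotary_emb_fraction=1.0.
hidden_size = 768
num_attention_heads = 12
rotary_emb_fraction = 1.0
max_trained_positions = 2048
rotary_emb_base = 1000

head_dim = hidden_size // num_attention_heads        # 768 // 12 == 64
rotary_dim = int(head_dim * rotary_emb_fraction)     # 64: rotate the full head

rotary_kwargs = {
    "head_size": head_dim,
    "rotary_dim": rotary_dim,
    "max_position": max_trained_positions,
    "base": rotary_emb_base,
    "rope_scaling": None,
}
print(rotary_kwargs)
```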


class GteNewModel(BertWithRope):
# for https://huggingface.co/Alibaba-NLP/new-impl
Expand All @@ -600,24 +504,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
layer.mlp.gate_up_proj.bias = None
layer.mlp.gate_up_proj.skip_bias_add = True

def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config

assert config.__class__.__name__ == "NewConfig"
assert config.hidden_act == "gelu"

config.hidden_act = "geglu"

head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
"rope_scaling": getattr(config, "rope_scaling", None)
}
return config

def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]):
n = "mlp.up_gate_proj"
for name, weight in weights:
@@ -652,24 +538,6 @@ class SnowflakeGteNewModel(GteNewModel):
"attention.o_proj": "attn.out_proj",
})

def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config

assert config.__class__.__name__ == "GteConfig"
assert config.hidden_act == "gelu"

config.hidden_act = "geglu"

head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": config.rope_theta,
"rope_scaling": getattr(config, "rope_scaling", None)
}
return config


class JinaRobertaModel(BertWithRope):
# for https://huggingface.co/jinaai/jina-embeddings-v3
Expand All @@ -685,21 +553,6 @@ class JinaRobertaModel(BertWithRope):
"norm2": "mlp_ln",
})

def config_verify(self, vllm_config):
config = vllm_config.model_config.hf_config

assert config.__class__.__name__ == "XLMRobertaFlashConfig"

head_dim = config.hidden_size // config.num_attention_heads
config.rotary_kwargs = {
"head_size": head_dim,
"rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
"max_position": config.max_position_embeddings,
"base": getattr(config, "rope_theta", config.rotary_emb_base),
"rope_scaling": getattr(config, "rope_scaling", None)
}
return config

def forward(
self,
input_ids: torch.Tensor,