Commit cd4cfee

[Model][1/N] Automatic conversion of CrossEncoding model (#20012)
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent e110930 commit cd4cfee

File tree: 5 files changed (+239, -167 lines)

tests/models/language/pooling/mteb_utils.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -43,7 +43,7 @@ def encode(
         # issues by randomizing the order.
         r = self.rng.permutation(len(sentences))
         sentences = [sentences[i] for i in r]
-        outputs = self.model.encode(sentences, use_tqdm=False)
+        outputs = self.model.embed(sentences, use_tqdm=False)
         embeds = np.array(outputs)
         embeds = embeds[np.argsort(r)]
         return embeds
@@ -250,16 +250,19 @@ def mteb_test_rerank_models(hf_runner,
     with vllm_runner(model_info.name,
                      task="score",
                      max_model_len=None,
+                     max_num_seqs=8,
                      **vllm_extra_kwargs) as vllm_model:
 
+        model_config = vllm_model.model.llm_engine.model_config
+
         if model_info.architecture:
-            assert (model_info.architecture
-                    in vllm_model.model.llm_engine.model_config.architectures)
+            assert (model_info.architecture in model_config.architectures)
+            assert model_config.hf_config.num_labels == 1
 
         vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model),
                                           tasks=MTEB_RERANK_TASKS,
                                           languages=MTEB_RERANK_LANGS)
-        vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
+        vllm_dtype = model_config.dtype
 
     with hf_runner(model_info.name, is_cross_encoder=True,
                    dtype="float32") as hf_model:
```

vllm/config.py

Lines changed: 28 additions & 1 deletion
```diff
@@ -569,6 +569,10 @@ def __post_init__(self) -> None:
         else:
             self.truncation_side = "right"
 
+        model_info, arch = self.registry.inspect_model_cls(self.architectures)
+        self._model_info = model_info
+        self._architecture = arch
+
         self.pooler_config = self._init_pooler_config()
 
         self.dtype = _get_and_verify_dtype(
@@ -660,8 +664,18 @@ def registry(self):
 
     @property
     def architectures(self) -> list[str]:
+        # architectures in the model config.
         return getattr(self.hf_config, "architectures", [])
 
+    @property
+    def architecture(self) -> str:
+        # The architecture vllm actually used.
+        return self._architecture
+
+    @property
+    def model_info(self) -> dict[str, Any]:
+        return self._model_info
+
     def maybe_pull_model_tokenizer_for_s3(self, model: str,
                                           tokenizer: str) -> None:
         """Pull model/tokenizer from S3 to temporary directory when needed.
@@ -4450,6 +4464,9 @@ def with_hf_config(
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
         """
+
+        self.try_verify_and_update_config()
+
         if self.model_config is not None:
             self.model_config.verify_async_output_proc(self.parallel_config,
                                                        self.speculative_config,
@@ -4694,11 +4711,21 @@ def _set_cudagraph_sizes(self):
             batch_size_capture_list)
 
     def recalculate_max_model_len(self, max_model_len: int):
+        # Can only be called in try_verify_and_update_config
         model_config = self.model_config
         max_model_len = model_config.get_and_verify_max_len(max_model_len)
         self.model_config.max_model_len = max_model_len
         self.scheduler_config.max_model_len = max_model_len
-        self.compute_hash()
+
+    def try_verify_and_update_config(self):
+        architecture = getattr(self.model_config, "architecture", None)
+        if architecture is None:
+            return
+
+        from vllm.model_executor.models.config import MODELS_CONFIG_MAP
+        cls = MODELS_CONFIG_MAP.get(architecture, None)
+        if cls is not None:
+            cls.verify_and_update_config(self)
 
     def __str__(self):
         return (
```
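
`try_verify_and_update_config` is the hook that replaces the per-model `config_verify` methods removed below: it looks up the resolved architecture in `MODELS_CONFIG_MAP` (defined in `vllm/model_executor/models/config.py`, which this commit also changes but is not shown in this view) and passes the whole `VllmConfig` to that entry. Only the call site is visible here, so the sketch below infers the shape of an entry; the class name and registration are illustrative, with the body adapted from the `GteNewModel.config_verify` hook removed later in this commit.

```python
# Illustrative only: inferred from the call site cls.verify_and_update_config(self)
# above. Names are placeholders, not copied from vllm/model_executor/models/config.py.
from vllm.config import VllmConfig


class GteNewModelConfig:

    @staticmethod
    def verify_and_update_config(vllm_config: VllmConfig) -> None:
        # Normalize the HF config before the model is built, mirroring the
        # removed GteNewModel.config_verify hook.
        config = vllm_config.model_config.hf_config
        assert config.hidden_act == "gelu"
        config.hidden_act = "geglu"

        head_dim = config.hidden_size // config.num_attention_heads
        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "max_position": config.max_position_embeddings,
            "base": config.rope_theta,
            "rope_scaling": getattr(config, "rope_scaling", None),
        }


# Hypothetical registration, keyed by the architecture string that
# ModelConfig.architecture now exposes.
MODELS_CONFIG_MAP = {"GteNewModel": GteNewModelConfig}
```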

vllm/model_executor/models/bert_with_rope.py

Lines changed: 1 addition & 148 deletions
```diff
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable
-from copy import deepcopy
 from typing import Optional
 
 import torch
@@ -12,7 +11,6 @@
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
                                                    get_act_fn)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -30,8 +28,6 @@
 from vllm.model_executor.models.utils import WeightsMapper
 from vllm.sequence import IntermediateTensors
 
-logger = init_logger(__name__)
-
 
 class BertWithRopeEmbedding(nn.Module):
 
@@ -408,17 +404,14 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.vllm_config = vllm_config
-        self.config = self.config_verify(vllm_config)
+        self.config = vllm_config.model_config.hf_config
         self.embeddings = BertWithRopeEmbedding(self.config)
         self.encoder = BertWithRopeEncoder(
             vllm_config=vllm_config,
             bias=getattr(self.config, "bias", True),
             rotary_kwargs=self.config.rotary_kwargs,
             prefix=f"{prefix}.encoder")
 
-    def config_verify(self, vllm_config):
-        raise NotImplementedError
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -490,95 +483,6 @@ class NomicBertModel(BertWithRope):
             "norm2": "mlp_ln",
         })
 
-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "NomicBertConfig"
-        assert config.activation_function in ["swiglu", "gelu"]
-        config.position_embedding_type = getattr(config,
-                                                 "position_embedding_type",
-                                                 "rope")
-
-        if config.activation_function == "swiglu":
-            config.hidden_act = "silu"
-        else:
-            config.hidden_act = config.activation_function
-
-        assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
-                config.qkv_proj_bias)
-        config.bias = config.qkv_proj_bias
-
-        assert config.rotary_emb_scale_base is None
-        assert not config.rotary_emb_interleaved
-
-        config.layer_norm_eps = config.layer_norm_epsilon
-        config.intermediate_size = config.n_inner
-        config.hidden_size = config.n_embd
-        config.num_hidden_layers = config.n_layer
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        rotary_emb_dim = head_dim * config.rotary_emb_fraction
-        max_trained_positions = getattr(config, "max_trained_positions", 2048)
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": rotary_emb_dim,
-            "max_position": max_trained_positions,
-            "base": getattr(config, "rope_theta", config.rotary_emb_base),
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-
-        # we ignore config.rotary_scaling_factor so that for datasets shorter
-        # than max_trained_positions 2048, the results are consistent
-        # with SentenceTransformer.
-        # The context extension uses vllm style rope_theta and rope_scaling.
-        # See #17785 #18755
-        if (not vllm_config.model_config.hf_overrides
-                and vllm_config.model_config.original_max_model_len is None):
-            # Default
-            # Reset max_model_len to max_trained_positions.
-            # nomic-embed-text-v2-moe the length is set to 512
-            # by sentence_bert_config.json.
-            max_model_len_before = vllm_config.model_config.max_model_len
-            max_model_len = min(vllm_config.model_config.max_model_len,
-                                max_trained_positions)
-
-            vllm_config.recalculate_max_model_len(max_model_len)
-            logger.warning(
-                "Nomic context extension is disabled. "
-                "Changing max_model_len from %s to %s. "
-                "To enable context extension, see: "
-                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
-                max_model_len_before, vllm_config.model_config.max_model_len)
-        else:
-            # We need to re-verify max_model_len to avoid lengths
-            # greater than position_embedding.
-            model_config = vllm_config.model_config
-            hf_text_config = model_config.hf_text_config
-
-            if isinstance(model_config.hf_overrides, dict):
-                # hf_overrides_kw
-                max_model_len = model_config.hf_overrides.get(
-                    "max_model_len", vllm_config.model_config.max_model_len)
-            else:
-                # hf_overrides_fn
-                # This might be overridden by sentence_bert_config.json.
-                max_model_len = vllm_config.model_config.max_model_len
-
-            # reset hf_text_config for recalculate_max_model_len.
-            if hasattr(hf_text_config, "max_model_len"):
-                delattr(hf_text_config, "max_model_len")
-            hf_text_config.max_position_embeddings = max_trained_positions
-            hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]
-
-            # The priority of sentence_bert_config.json is higher
-            # than max_position_embeddings
-            encoder_config = deepcopy(model_config.encoder_config)
-            encoder_config.pop("max_seq_length", None)
-            model_config.encoder_config = encoder_config
-
-            vllm_config.recalculate_max_model_len(max_model_len)
-        return config
-
 
 class GteNewModel(BertWithRope):
     # for https://huggingface.co/Alibaba-NLP/new-impl
@@ -600,24 +504,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             layer.mlp.gate_up_proj.bias = None
             layer.mlp.gate_up_proj.skip_bias_add = True
 
-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "NewConfig"
-        assert config.hidden_act == "gelu"
-
-        config.hidden_act = "geglu"
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-            "max_position": config.max_position_embeddings,
-            "base": config.rope_theta,
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-        return config
-
     def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]):
         n = "mlp.up_gate_proj"
         for name, weight in weights:
@@ -652,24 +538,6 @@ class SnowflakeGteNewModel(GteNewModel):
             "attention.o_proj": "attn.out_proj",
         })
 
-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "GteConfig"
-        assert config.hidden_act == "gelu"
-
-        config.hidden_act = "geglu"
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-            "max_position": config.max_position_embeddings,
-            "base": config.rope_theta,
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-        return config
-
 
 class JinaRobertaModel(BertWithRope):
     # for https://huggingface.co/jinaai/jina-embeddings-v3
@@ -685,21 +553,6 @@ class JinaRobertaModel(BertWithRope):
             "norm2": "mlp_ln",
         })
 
-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "XLMRobertaFlashConfig"
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-            "max_position": config.max_position_embeddings,
-            "base": getattr(config, "rope_theta", config.rotary_emb_base),
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-        return config
-
     def forward(
         self,
         input_ids: torch.Tensor,
```
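
With the per-model `config_verify` hooks gone, `BertWithRope.__init__` now reads an `hf_config` that was already normalized while `VllmConfig.__post_init__` ran `try_verify_and_update_config()`. Below is a simplified sketch of that ordering, assuming the dispatch shown in the `vllm/config.py` hunks above; the real engine wiring is omitted.

```python
# Sketch only: simplified ordering, not actual engine code.
from vllm.config import VllmConfig


def read_normalized_config(vllm_config: VllmConfig):
    # VllmConfig.__post_init__ has already invoked
    # try_verify_and_update_config(), which dispatched to the per-architecture
    # entry in MODELS_CONFIG_MAP, so the derived fields the encoder needs
    # (hidden_act, bias, rotary_kwargs, ...) are already in place.
    hf_config = vllm_config.model_config.hf_config
    rotary_kwargs = hf_config.rotary_kwargs
    bias = getattr(hf_config, "bias", True)
    return rotary_kwargs, bias
```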
