@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable
-from copy import deepcopy
 from typing import Optional

 import torch
@@ -12,7 +11,6 @@
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
                                                     get_act_fn)
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -30,8 +28,6 @@
 from vllm.model_executor.models.utils import WeightsMapper
 from vllm.sequence import IntermediateTensors

-logger = init_logger(__name__)
-

 class BertWithRopeEmbedding(nn.Module):

@@ -408,17 +404,14 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.vllm_config = vllm_config
-        self.config = self.config_verify(vllm_config)
+        self.config = vllm_config.model_config.hf_config
         self.embeddings = BertWithRopeEmbedding(self.config)
         self.encoder = BertWithRopeEncoder(
             vllm_config=vllm_config,
             bias=getattr(self.config, "bias", True),
             rotary_kwargs=self.config.rotary_kwargs,
             prefix=f"{prefix}.encoder")

-    def config_verify(self, vllm_config):
-        raise NotImplementedError
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
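With the `config_verify` hook removed, `self.config` is the HF config exactly as loaded, and `BertWithRopeEncoder` still reads `rotary_kwargs` (and an optional `bias` flag) from it. The sketch below only illustrates that implicit precondition; the field names are taken from the `config_verify` bodies deleted further down, where this normalization now happens is not shown in this diff, and `check_bert_with_rope_config` is a hypothetical helper, not part of vLLM.

```python
def check_bert_with_rope_config(hf_config) -> None:
    """Sketch of what BertWithRope.__init__ now assumes about hf_config."""
    # rotary_kwargs must already be present; it used to be built in config_verify.
    assert hasattr(hf_config, "rotary_kwargs")
    expected = {"head_size", "rotary_dim", "max_position", "base", "rope_scaling"}
    assert expected <= set(hf_config.rotary_kwargs)
    # "bias" is optional: the model falls back to True via getattr(config, "bias", True).
```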
@@ -490,95 +483,6 @@ class NomicBertModel(BertWithRope):
             "norm2": "mlp_ln",
         })

-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "NomicBertConfig"
-        assert config.activation_function in ["swiglu", "gelu"]
-        config.position_embedding_type = getattr(config,
-                                                 "position_embedding_type",
-                                                 "rope")
-
-        if config.activation_function == "swiglu":
-            config.hidden_act = "silu"
-        else:
-            config.hidden_act = config.activation_function
-
-        assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
-                config.qkv_proj_bias)
-        config.bias = config.qkv_proj_bias
-
-        assert config.rotary_emb_scale_base is None
-        assert not config.rotary_emb_interleaved
-
-        config.layer_norm_eps = config.layer_norm_epsilon
-        config.intermediate_size = config.n_inner
-        config.hidden_size = config.n_embd
-        config.num_hidden_layers = config.n_layer
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        rotary_emb_dim = head_dim * config.rotary_emb_fraction
-        max_trained_positions = getattr(config, "max_trained_positions", 2048)
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": rotary_emb_dim,
-            "max_position": max_trained_positions,
-            "base": getattr(config, "rope_theta", config.rotary_emb_base),
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-
-        # we ignore config.rotary_scaling_factor so that for datasets shorter
-        # than max_trained_positions 2048, the results are consistent
-        # with SentenceTransformer.
-        # The context extension uses vllm style rope_theta and rope_scaling.
-        # See #17785 #18755
-        if (not vllm_config.model_config.hf_overrides
-                and vllm_config.model_config.original_max_model_len is None):
-            # Default
-            # Reset max_model_len to max_trained_positions.
-            # nomic-embed-text-v2-moe the length is set to 512
-            # by sentence_bert_config.json.
-            max_model_len_before = vllm_config.model_config.max_model_len
-            max_model_len = min(vllm_config.model_config.max_model_len,
-                                max_trained_positions)
-
-            vllm_config.recalculate_max_model_len(max_model_len)
-            logger.warning(
-                "Nomic context extension is disabled. "
-                "Changing max_model_len from %s to %s. "
-                "To enable context extension, see: "
-                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
-                max_model_len_before, vllm_config.model_config.max_model_len)
-        else:
-            # We need to re-verify max_model_len to avoid lengths
-            # greater than position_embedding.
-            model_config = vllm_config.model_config
-            hf_text_config = model_config.hf_text_config
-
-            if isinstance(model_config.hf_overrides, dict):
-                # hf_overrides_kw
-                max_model_len = model_config.hf_overrides.get(
-                    "max_model_len", vllm_config.model_config.max_model_len)
-            else:
-                # hf_overrides_fn
-                # This might be overridden by sentence_bert_config.json.
-                max_model_len = vllm_config.model_config.max_model_len
-
-            # reset hf_text_config for recalculate_max_model_len.
-            if hasattr(hf_text_config, "max_model_len"):
-                delattr(hf_text_config, "max_model_len")
-            hf_text_config.max_position_embeddings = max_trained_positions
-            hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]
-
-            # The priority of sentence_bert_config.json is higher
-            # than max_position_embeddings
-            encoder_config = deepcopy(model_config.encoder_config)
-            encoder_config.pop("max_seq_length", None)
-            model_config.encoder_config = encoder_config
-
-            vllm_config.recalculate_max_model_len(max_model_len)
-        return config
-

 class GteNewModel(BertWithRope):
     # for https://huggingface.co/Alibaba-NLP/new-impl
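For reference, the NomicBert hook removed above did two separate jobs: renaming NomicBert config fields to the names the shared code expects (`n_embd` to `hidden_size`, `swiglu` to `silu`, rotary sizes from `rotary_emb_fraction`), and clamping `max_model_len` to `max_trained_positions` unless context extension is explicitly configured. Below is a condensed sketch of just the clamping rule, distilled from the removed code; the helper name is illustrative and does not exist in vLLM, and where this logic lives after this commit is outside the diff.

```python
def clamp_nomic_max_model_len(requested_len: int,
                              max_trained_positions: int = 2048,
                              context_extension: bool = False) -> int:
    # Distilled from the removed NomicBertModel.config_verify: without
    # hf_overrides / an explicit original_max_model_len, the length is
    # clamped to the trained position count so results stay consistent
    # with SentenceTransformer.
    if context_extension:
        return requested_len
    return min(requested_len, max_trained_positions)


# e.g. clamp_nomic_max_model_len(8192) -> 2048
```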
@@ -600,24 +504,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             layer.mlp.gate_up_proj.bias = None
             layer.mlp.gate_up_proj.skip_bias_add = True

-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "NewConfig"
-        assert config.hidden_act == "gelu"
-
-        config.hidden_act = "geglu"
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-            "max_position": config.max_position_embeddings,
-            "base": config.rope_theta,
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-        return config
-
     def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]):
         n = "mlp.up_gate_proj"
         for name, weight in weights:
@@ -652,24 +538,6 @@ class SnowflakeGteNewModel(GteNewModel):
             "attention.o_proj": "attn.out_proj",
         })

-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "GteConfig"
-        assert config.hidden_act == "gelu"
-
-        config.hidden_act = "geglu"
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-            "max_position": config.max_position_embeddings,
-            "base": config.rope_theta,
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-        return config
-

 class JinaRobertaModel(BertWithRope):
     # for https://huggingface.co/jinaai/jina-embeddings-v3
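The three removed subclass hooks (GteNewModel, SnowflakeGteNewModel, and the JinaRobertaModel one deleted in the next hunk) built the same `rotary_kwargs` dictionary and differed only in the asserted config class and in how the rope base is resolved: the GTE variants read `rope_theta` directly, while the Jina variant falls back to `rotary_emb_base`. A shared-helper sketch distilled from that removed code; `build_rotary_kwargs` is illustrative, not a vLLM function.

```python
def build_rotary_kwargs(config) -> dict:
    # Distilled from the removed GTE/Snowflake/Jina config_verify methods.
    head_dim = config.hidden_size // config.num_attention_heads
    base = getattr(config, "rope_theta", None)
    if base is None:
        base = config.rotary_emb_base  # Jina-style fallback
    return {
        "head_size": head_dim,
        "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
        "max_position": config.max_position_embeddings,
        "base": base,
        "rope_scaling": getattr(config, "rope_scaling", None),
    }
```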
@@ -685,21 +553,6 @@ class JinaRobertaModel(BertWithRope):
             "norm2": "mlp_ln",
         })

-    def config_verify(self, vllm_config):
-        config = vllm_config.model_config.hf_config
-
-        assert config.__class__.__name__ == "XLMRobertaFlashConfig"
-
-        head_dim = config.hidden_size // config.num_attention_heads
-        config.rotary_kwargs = {
-            "head_size": head_dim,
-            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
-            "max_position": config.max_position_embeddings,
-            "base": getattr(config, "rope_theta", config.rotary_emb_base),
-            "rope_scaling": getattr(config, "rope_scaling", None)
-        }
-        return config
-
     def forward(
         self,
         input_ids: torch.Tensor,