 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.opt import OPTModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SequenceData
 
 from .blip import (BlipVisionModel, dummy_image_for_blip,
                    get_max_blip_image_tokens)
 from .interfaces import SupportsMultiModal
-from .utils import merge_multimodal_embeddings
-
-_KEYS_TO_MODIFY_MAPPING = {
-    "language_model.lm_head": "lm_head",
-    "language_model.model": "language_model",
-}
+from .utils import (group_weights_with_prefix, init_vllm_registered_model,
+                    merge_multimodal_embeddings)
 
 # We use this internally as placeholders since there is no image token
 # defined on the HuggingFace repo
@@ -491,9 +485,6 @@ def __init__(self,
 
         super().__init__()
 
-        # currently all existing BLIP-2 models have `tie_word_embeddings`
-        # enabled
-        assert config.tie_word_embeddings
         self.config = config
         self.multimodal_config = multimodal_config
 
@@ -514,17 +505,8 @@ def __init__(self,
             bias=True,
         )
 
-        self.quant_config = quant_config
-
-        self.language_model = OPTModel(config.text_config, cache_config,
-                                       quant_config)
-
-        self.unpadded_vocab_size = config.text_config.vocab_size
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
-        self.sampler = Sampler()
-
-    def get_lm_head(self):
-        return self.language_model.decoder.embed_tokens
+        self.language_model = init_vllm_registered_model(
+            config.text_config, cache_config, quant_config)
 
     def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
         h = w = self.config.vision_config.image_size
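
init_vllm_registered_model is one of the two new helpers imported from .utils. Judging from its use here, it looks up the decoder class registered for the architecture named in config.text_config and constructs it, instead of hard-coding OPTModel; this also explains why the later hunks route calls through self.language_model.model and delegate compute_logits/sample to the wrapped model. A minimal, self-contained sketch of that assumed behavior (toy names, not vLLM's actual internals):

# Illustrative stand-in for the registry lookup assumed above; the registry
# contents and class names are toy placeholders, not vLLM code.
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class ToyTextConfig:
    architectures: List[str] = field(default_factory=lambda: ["OPTForCausalLM"])


class ToyOPTDecoder:
    def __init__(self, config, cache_config=None, quant_config=None):
        self.config = config


_TOY_REGISTRY: Dict[str, Callable] = {"OPTForCausalLM": ToyOPTDecoder}


def init_registered_model_sketch(hf_config, cache_config=None, quant_config=None):
    # resolve the HF architecture name to the registered implementation
    model_cls = _TOY_REGISTRY[hf_config.architectures[0]]
    return model_cls(hf_config, cache_config=cache_config,
                     quant_config=quant_config)


print(type(init_registered_model_sketch(ToyTextConfig())).__name__)  # ToyOPTDecoder
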
@@ -653,7 +635,8 @@ def forward(
 
         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
-            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+            inputs_embeds = self.language_model.model.get_input_embeddings(
+                input_ids)
 
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
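
merge_multimodal_embeddings itself is untouched by this diff; for context, here is a self-contained sketch of the merge it is assumed to perform: rows of inputs_embeds at the image-placeholder positions are overwritten with the projected vision features. The function name, the placeholder-id argument, and the id 50265 below are illustrative assumptions, not taken from this excerpt.

import torch


def merge_multimodal_embeddings_sketch(input_ids: torch.Tensor,
                                       inputs_embeds: torch.Tensor,
                                       vision_embeddings: torch.Tensor,
                                       placeholder_token_id: int) -> torch.Tensor:
    mask = input_ids == placeholder_token_id
    # flatten (num_images, num_query_tokens, hidden) so it lines up with the
    # flattened placeholder positions
    inputs_embeds[mask] = vision_embeddings.reshape(-1, inputs_embeds.shape[-1])
    return inputs_embeds


ids = torch.tensor([1, 50265, 50265, 2])   # 50265: assumed placeholder token id
embeds = torch.zeros(4, 8)
vision = torch.ones(1, 2, 8)               # one image, two query tokens
merged = merge_multimodal_embeddings_sketch(ids, embeds, vision, 50265)
print(merged[1].sum(), merged[0].sum())    # tensor(8.) tensor(0.)
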
@@ -663,11 +646,11 @@ def forward(
         else:
             inputs_embeds = None
 
-        hidden_states = self.language_model(input_ids,
-                                            positions,
-                                            kv_caches,
-                                            attn_metadata,
-                                            inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(input_ids,
+                                                  positions,
+                                                  kv_caches,
+                                                  attn_metadata,
+                                                  inputs_embeds=inputs_embeds)
 
         return hidden_states
 
@@ -676,56 +659,46 @@ def compute_logits(
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.get_lm_head(), hidden_states,
-                                       sampling_metadata)
-        return logits
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)
 
     def sample(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # only doing this for language model part for now.
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-
-        for name, loaded_weight in weights:
-            if "lm_head.weight" in name:
-                continue
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
-                if key_to_modify in name:
-                    name = name.replace(key_to_modify, new_key)
-            use_default_weight_loading = False
-            if "vision" in name:
-                if self.vision_model is not None:
-                    # BlipVisionModel does not need sharding
-                    use_default_weight_loading = True
-            else:
-                for (param_name, weight_name,
-                     shard_id) in stacked_params_mapping:
-                    if weight_name not in name:
-                        continue
-                    param = params_dict[name.replace(weight_name, param_name)]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
-                else:
-                    use_default_weight_loading = True
-            if use_default_weight_loading:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        # prepare weight iterators for components
+        weights_group = group_weights_with_prefix(weights)
+
+        # load vision encoder
+        self.vision_model.load_weights(weights_group["vision_model"])
+
+        # load query tokens
+        for name, loaded_weight in weights_group["query_tokens"]:
+            assert name == ""
+            param = self.query_tokens
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load qformer
+        qformer_params_dict = dict(self.qformer.named_parameters())
+        for name, loaded_weight in weights_group["qformer"]:
+            param = qformer_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load mlp projector
+        mlp_params_dict = dict(self.language_projection.named_parameters())
+        for name, loaded_weight in weights_group["language_projection"]:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        self.language_model.load_weights(weights_group["language_model"])
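
group_weights_with_prefix is the other helper added to the .utils import. From its use in load_weights above, it appears to bucket checkpoint entries by their leading dotted component ("vision_model", "query_tokens", "qformer", "language_projection", "language_model") and strip that prefix, so each submodule consumes only its own weights. An illustrative stand-in, not vLLM's actual helper:

from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def group_weights_with_prefix_sketch(
        weights: Iterable[Tuple[str, object]]
) -> Dict[str, List[Tuple[str, object]]]:
    groups: Dict[str, List[Tuple[str, object]]] = defaultdict(list)
    for name, tensor in weights:
        prefix, _, rest = name.partition(".")
        # a top-level parameter such as "query_tokens" keeps an empty
        # remainder, which is what the assert name == "" above relies on
        groups[prefix].append((rest, tensor))
    return dict(groups)


ckpt = [("vision_model.encoder.layers.0.mlp.fc1.weight", None),
        ("query_tokens", None),
        ("language_model.model.decoder.embed_tokens.weight", None)]
for prefix, entries in group_weights_with_prefix_sketch(ckpt).items():
    print(prefix, [n for n, _ in entries])
# vision_model ['encoder.layers.0.mlp.fc1.weight']
# query_tokens ['']
# language_model ['model.decoder.embed_tokens.weight']

The real helper may well hand back lazy iterators rather than lists so the checkpoint streams through once; the per-prefix grouping is the only behavior assumed here.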