diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index b6f4275fbc948..5fd39b5e35be6 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,3 +1,4 @@ +import itertools from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -13,7 +14,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.gemma import GemmaModel +from vllm.model_executor.models.gemma import GemmaForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer @@ -22,14 +23,10 @@ from .interfaces import SupportsMultiModal from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens) -from .utils import merge_multimodal_embeddings +from .utils import filter_weights, merge_multimodal_embeddings logger = init_logger(__name__) -_KEYS_TO_MODIFY_MAPPING = { - "language_model.model": "language_model", -} - class PaliGemmaImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -151,8 +148,8 @@ def __init__(self, projection_dim=config.vision_config.projection_dim) self.quant_config = quant_config - self.language_model = GemmaModel(config.text_config, cache_config, - quant_config) + self.language_model = GemmaForCausalLM(config.text_config, + cache_config, quant_config) self.unpadded_vocab_size = config.text_config.vocab_size logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, @@ -252,7 +249,8 @@ def forward(self, vision_embeddings = vision_embeddings * (self.config.hidden_size** -0.5) - inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, @@ -262,87 +260,47 @@ def forward(self, else: inputs_embeds = None - hidden_states = self.language_model(input_ids, - positions, - kv_caches, - attn_metadata, - None, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + None, + inputs_embeds=inputs_embeds) return hidden_states - # Copied from vllm/model_executor/models/gemma.py def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.language_model.embed_tokens, - hidden_states, sampling_metadata) - return logits + return self.language_model.compute_logits(hidden_states, + sampling_metadata) - # Copied from vllm/model_executor/models/gemma.py def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) - # Adapted from vllm/model_executor/models/gemma.py def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params = set() - for name, loaded_weight in weights: - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - use_default_weight_loading = False - if "vision" not in name or self.vision_tower.shard_weight: - for (param_name, shard_name, - shard_id) in stacked_params_mapping: - if shard_name not in name: - continue - name = name.replace(shard_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # lm_head is not used in vllm as it is tied with - # embed_token. To prevent errors, skip loading - # lm_head.weight. - if "lm_head.weight" in name: - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - use_default_weight_loading = True - else: - use_default_weight_loading = True - - if use_default_weight_loading: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - loaded_params.add(name) - - unloaded_params = params_dict.keys() - loaded_params - if unloaded_params: - logger.warning( - "Some weights are not initialized from checkpoints: %s", - unloaded_params) + # prepare weight iterators for components + vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) + + # load vision tower + vit_weights = filter_weights(vit_weights, "vision_tower") + self.vision_tower.load_weights(vit_weights) + + # load mlp projector + mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") + mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) + for name, loaded_weight in mlp_weights: + param = mlp_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + llm_weights = filter_weights(llm_weights, "language_model") + self.language_model.load_weights(llm_weights) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index fb4c30c1a13f9..13d09e4cd4c23 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -529,6 +529,12 @@ def forward( ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] if self.shard_weight else [] params_dict = dict(self.named_parameters()) layer_count = len(self.vision_model.encoder.layers) @@ -544,7 +550,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if layer_idx >= layer_count: continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight)