From 06ed2815e2be50e527839c7ab09ce2639b7910b6 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 22 Sep 2024 20:24:21 +0800 Subject: [PATCH] [Model] Refactor BLIP/BLIP-2 to support composite model loading (#8407) --- vllm/model_executor/models/blip.py | 61 ++++++++- vllm/model_executor/models/blip2.py | 121 +++++++----------- vllm/model_executor/models/chameleon.py | 3 - vllm/model_executor/models/clip.py | 11 +- vllm/model_executor/models/fuyu.py | 3 - vllm/model_executor/models/llava_next.py | 8 -- .../model_executor/models/llava_next_video.py | 3 - vllm/model_executor/models/minicpmv.py | 3 - vllm/model_executor/models/siglip.py | 11 +- vllm/model_executor/models/ultravox.py | 3 - 10 files changed, 113 insertions(+), 114 deletions(-) diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index e943427eda8e1..7c8e76461dd67 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -1,6 +1,6 @@ """Minimal implementation of BlipVisionModel intended to be only used within a vision language model.""" -from typing import Optional, Union +from typing import Iterable, Optional, Tuple, Union import torch import torch.nn as nn @@ -16,6 +16,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData @@ -342,6 +343,10 @@ def __init__(self, num_hidden_layers_override: Optional[int] = None): super().__init__() + tp_size = get_tensor_model_parallel_world_size() + num_heads = config.num_attention_heads + self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 + self.config = config self.embeddings = BlipVisionEmbeddings(config) @@ -350,11 +355,61 @@ def __init__(self, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, ) - self.post_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_eps) + + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {config.num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." 
+ ) + elif len(self.encoder.layers) == config.num_hidden_layers: + self.post_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + else: + # post_layernorm is unused when we extract intermediate features + # In this case, we can skip it to conserve memory + self.post_layernorm = None def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: hidden_states = self.embeddings(pixel_values) hidden_states = self.encoder(inputs_embeds=hidden_states) + if self.post_layernorm is None: + return hidden_states + return self.post_layernorm(hidden_states) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] if self.shard_weight else [] + params_dict = dict(self.named_parameters()) + layer_count = len(self.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is not needed in BlipVisionModel + if (name.startswith("post_layernorm") + and self.post_layernorm is None): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("encoder.layers"): + layer_idx = int(name.split(".")[2]) + if layer_idx >= layer_count: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 37fabf3f3f9a8..b28d7699afa01 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -10,11 +10,9 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.opt import OPTModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, SequenceData @@ -22,12 +20,8 @@ from .blip import (BlipVisionModel, dummy_image_for_blip, get_max_blip_image_tokens) from .interfaces import SupportsMultiModal -from .utils import merge_multimodal_embeddings - -_KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", -} +from .utils import (group_weights_with_prefix, init_vllm_registered_model, + merge_multimodal_embeddings) # We use this internally as placeholders since there is no image token # defined on the HuggingFace repo @@ -491,9 +485,6 @@ def __init__(self, super().__init__() - # currently all existing BLIP-2 models have `tie_word_embeddings` - # enabled - assert config.tie_word_embeddings self.config = config self.multimodal_config = multimodal_config @@ -514,17 +505,8 @@ def __init__(self, bias=True, ) - self.quant_config = quant_config - - 
self.language_model = OPTModel(config.text_config, cache_config, - quant_config) - - self.unpadded_vocab_size = config.text_config.vocab_size - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size) - self.sampler = Sampler() - - def get_lm_head(self): - return self.language_model.decoder.embed_tokens + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size @@ -653,7 +635,8 @@ def forward( if image_input is not None: vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, @@ -663,11 +646,11 @@ def forward( else: inputs_embeds = None - hidden_states = self.language_model(input_ids, - positions, - kv_caches, - attn_metadata, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + inputs_embeds=inputs_embeds) return hidden_states @@ -676,56 +659,46 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.get_lm_head(), hidden_states, - sampling_metadata) - return logits + return self.language_model.compute_logits(hidden_states, + sampling_metadata) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # only doing this for language model part for now. 
- stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - - for name, loaded_weight in weights: - if "lm_head.weight" in name: - continue - if "rotary_emb.inv_freq" in name: - continue - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - use_default_weight_loading = False - if "vision" in name: - if self.vision_model is not None: - # BlipVisionModel does not need sharding - use_default_weight_loading = True - else: - for (param_name, weight_name, - shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - param = params_dict[name.replace(weight_name, param_name)] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - use_default_weight_loading = True - if use_default_weight_loading: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + # prepare weight iterators for components + weights_group = group_weights_with_prefix(weights) + + # load vision encoder + self.vision_model.load_weights(weights_group["vision_model"]) + + # load query tokens + for name, loaded_weight in weights_group["query_tokens"]: + assert name == "" + param = self.query_tokens + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load qformer + qformer_params_dict = dict(self.qformer.named_parameters()) + for name, loaded_weight in weights_group["qformer"]: + param = qformer_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load mlp projector + mlp_params_dict = dict(self.language_projection.named_parameters()) + for name, loaded_weight in weights_group["language_projection"]: + param = mlp_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 51a61485caf65..973e47f5f0ccd 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -12,7 +12,6 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -36,8 +35,6 @@ from .interfaces import SupportsMultiModal -logger = init_logger(__name__) - # These configs are not part of the model config but the preprocessor # and processor files, so we hardcode them in the model file for now. 
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512 diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index a7754f70e2786..c353635404d9a 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -391,6 +391,7 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None, num_hidden_layers_override: Optional[int] = None): super().__init__() + tp_size = get_tensor_model_parallel_world_size() num_heads = config.num_attention_heads self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 @@ -400,10 +401,6 @@ def __init__(self, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override) - @property - def _require_post_layernorm(self) -> bool: - return self.vision_model.post_layernorm is not None - def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return self.vision_model(pixel_values) @@ -425,12 +422,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: # post_layernorm is not needed in CLIPVisionModel - if ("vision_model.post_layernorm" in name - and not self._require_post_layernorm): + if (name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None): continue # omit layers when num_hidden_layers_override is set - if "vision_model.encoder.layers." in name: + if name.startswith("vision_model.encoder.layers"): layer_idx = int(name.split(".")[3]) if layer_idx >= layer_count: continue diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index beeae14229575..4cf3b0b93dcf5 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -28,7 +28,6 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.logger import init_logger from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput @@ -45,8 +44,6 @@ from .interfaces import SupportsMultiModal from .utils import merge_multimodal_embeddings -logger = init_logger(__name__) - # Cannot find the following 2 numbers from hf config. 
_IMAGE_TOKEN_ID = 71011 _NEWLINE_TOKEN_ID = 71019 diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 96034b254e49b..4341cc38bdd28 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -12,7 +12,6 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -32,13 +31,6 @@ from .utils import (flatten_bn, group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) -logger = init_logger(__name__) - -_KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", -} - # Result in the max possible feature size (2x2 grid of 336x336px tiles) MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index a8b5176dc43cf..397a6cce5af2c 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -11,7 +11,6 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -32,8 +31,6 @@ from .utils import (group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) -logger = init_logger(__name__) - # For profile run _MAX_FRAMES_PER_VIDEO = 32 _MAX_NUM_VIDEOS = 1 diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 5579205832aa8..c0fb6fef78bab 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -37,7 +37,6 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.logger import init_logger from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -59,8 +58,6 @@ from .idefics2_vision_model import Idefics2VisionTransformer -logger = init_logger(__name__) - _KEYS_TO_MODIFY_MAPPING = { "llm.lm_head": "lm_head", "llm.model": "llm", diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 5b332fa1a24d7..6cf7df4e6ac63 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -501,6 +501,7 @@ def __init__( num_hidden_layers_override: Optional[int] = None, ): super().__init__() + num_heads = config.num_attention_heads tp_size = get_tensor_model_parallel_world_size() self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 @@ -511,10 +512,6 @@ def __init__( num_hidden_layers_override=num_hidden_layers_override, ) - @property - def _require_post_layernorm(self) -> bool: - return self.vision_model.post_layernorm is not None - def get_input_embeddings(self) -> nn.Module: return 
self.vision_model.embeddings.patch_embedding @@ -540,12 +537,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: # post_layernorm is optional in SiglipVisionModel - if ("vision_model.post_layernorm" in name - and not self._require_post_layernorm): + if (name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None): continue # omit layers when num_hidden_layers_override is set - if "vision_model.encoder.layers." in name: + if name.startswith("vision_model.encoder.layers"): layer_idx = int(name.split(".")[3]) if layer_idx >= layer_count: continue diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b89c9dafd9cd8..32a0e895005cb 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -20,7 +20,6 @@ from vllm.inputs import INPUT_REGISTRY from vllm.inputs.data import LLMInputs from vllm.inputs.registry import InputContext -from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.base_config import ( @@ -43,8 +42,6 @@ _AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 -logger = init_logger(__name__) - class UltravoxAudioFeatureInputs(TypedDict): type: Literal["audio_features"]
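
Note (illustrative only, not part of the patch): the composite loading added in
blip2.py works by splitting the flat checkpoint into per-component groups and
letting each sub-module load its own group. Below is a minimal sketch of the
grouping step, inferred from how group_weights_with_prefix() is used above; the
actual helper in vllm/model_executor/models/utils.py may differ in signature
and may yield lazy iterators rather than lists.

    from collections import defaultdict
    from typing import Dict, Iterable, List, Tuple

    import torch


    def group_weights_by_prefix_sketch(
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Dict[str, List[Tuple[str, torch.Tensor]]]:
        # Split each checkpoint name on its first dot: the head selects the
        # sub-module ("vision_model", "qformer", "language_projection",
        # "language_model", ...) and the tail is the name handed to that
        # sub-module's own load_weights(). A top-level parameter such as
        # "query_tokens" ends up under its own key with an empty tail,
        # which is why blip2.py asserts `name == ""` for it.
        grouped: Dict[str, List[Tuple[str, torch.Tensor]]] = defaultdict(list)
        for name, tensor in weights:
            prefix, _, rest = name.partition(".")
            grouped[prefix].append((rest, tensor))
        return grouped


    weights = [
        ("vision_model.embeddings.class_embedding", torch.zeros(1)),
        ("query_tokens", torch.zeros(1, 32, 768)),
        ("language_model.model.embed_tokens.weight", torch.zeros(4, 4)),
    ]
    grouped = group_weights_by_prefix_sketch(weights)
    assert [n for n, _ in grouped["query_tokens"]] == [""]

Each component then consumes its own group, e.g.
self.vision_model.load_weights(grouped["vision_model"]) and
self.language_model.load_weights(grouped["language_model"]), mirroring the new
Blip2ForConditionalGeneration.load_weights above.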
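
A second sketch (toy code, not vLLM's actual implementation) of the
stacked-parameter path in the new BlipVisionModel.load_weights: checkpoint
q_proj/k_proj/v_proj weights are routed into a single fused qkv_proj parameter
via stacked_params_mapping and that parameter's weight_loader. The toy loader
below only copies each shard into its slice; the real
QKVParallelLinear.weight_loader also handles tensor-parallel sharding.

    import torch

    hidden_size = 8
    # Fused parameter standing in for the model's qkv_proj weight.
    qkv_weight = torch.empty(3 * hidden_size, hidden_size)


    def toy_qkv_weight_loader(param: torch.Tensor, loaded: torch.Tensor,
                              shard_id: str) -> None:
        # Copy one q/k/v shard into its slice of the fused parameter.
        offset = {"q": 0, "k": 1, "v": 2}[shard_id] * hidden_size
        param[offset:offset + hidden_size].copy_(loaded)


    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    checkpoint = {
        f"encoder.layers.0.self_attn.{proj}.weight":
        torch.randn(hidden_size, hidden_size)
        for proj in ("q_proj", "k_proj", "v_proj")
    }

    for name, loaded_weight in checkpoint.items():
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            # In blip.py the fused parameter is looked up via
            # params_dict[name.replace(weight_name, param_name)]; here we
            # use the single toy parameter directly.
            toy_qkv_weight_loader(qkv_weight, loaded_weight, shard_id)
            break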