
Commit b1a74c9

DarkLight1337 authored and garg-amit committed
[Model] Refactor BLIP/BLIP-2 to support composite model loading (vllm-project#8407)
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
1 parent 6b3f347 commit b1a74c9

File tree: 10 files changed, +113, -114 lines


vllm/model_executor/models/blip.py (+58, -3)

@@ -1,6 +1,6 @@
 """Minimal implementation of BlipVisionModel intended to be only used
 within a vision language model."""
-from typing import Optional, Union
+from typing import Iterable, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -16,6 +16,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData
@@ -342,6 +343,10 @@ def __init__(self,
                  num_hidden_layers_override: Optional[int] = None):
         super().__init__()

+        tp_size = get_tensor_model_parallel_world_size()
+        num_heads = config.num_attention_heads
+        self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
+
         self.config = config

         self.embeddings = BlipVisionEmbeddings(config)
@@ -350,11 +355,61 @@ def __init__(self,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
         )
-        self.post_layernorm = nn.LayerNorm(config.hidden_size,
-                                           eps=config.layer_norm_eps)
+
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {config.num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+        elif len(self.encoder.layers) == config.num_hidden_layers:
+            self.post_layernorm = nn.LayerNorm(config.hidden_size,
+                                               eps=config.layer_norm_eps)
+        else:
+            # post_layernorm is unused when we extract intermediate features
+            # In this case, we can skip it to conserve memory
+            self.post_layernorm = None

     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         hidden_states = self.embeddings(pixel_values)
         hidden_states = self.encoder(inputs_embeds=hidden_states)

+        if self.post_layernorm is None:
+            return hidden_states
+
         return self.post_layernorm(hidden_states)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ] if self.shard_weight else []
+        params_dict = dict(self.named_parameters())
+        layer_count = len(self.encoder.layers)
+
+        for name, loaded_weight in weights:
+            # post_layernorm is not needed in BlipVisionModel
+            if (name.startswith("post_layernorm")
+                    and self.post_layernorm is None):
+                continue
+
+            # omit layers when num_hidden_layers_override is set
+            if name.startswith("encoder.layers"):
+                layer_idx = int(name.split(".")[2])
+                if layer_idx >= layer_count:
+                    continue
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
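
Note on the new BlipVisionModel.load_weights above: the for/else is the key control-flow detail. The else branch runs only when the loop over stacked_params_mapping completes without hitting break, i.e. when the checkpoint tensor is not a q/k/v shard. A standalone sketch of the same dispatch logic (toy function and example names, for illustration only; not vLLM code):

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def describe_routing(name: str) -> str:
    """Describe how a checkpoint tensor named `name` would be loaded."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # Matched a q/k/v shard: it is written into the fused qkv_proj
        # parameter at the given shard slot.
        target = f"{name.replace(weight_name, param_name)} [shard {shard_id}]"
        break
    else:
        # Reached only if the loop never hit `break`: keep the original name
        # and fall back to the default weight loader.
        target = f"{name} [default loader]"
    return target


print(describe_routing("encoder.layers.0.self_attn.k_proj.weight"))
# -> encoder.layers.0.self_attn.qkv_proj.weight [shard k]
print(describe_routing("encoder.layers.0.mlp.fc1.weight"))
# -> encoder.layers.0.mlp.fc1.weight [default loader]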

vllm/model_executor/models/blip2.py (+47, -74)

@@ -10,24 +10,18 @@
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.opt import OPTModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SequenceData

 from .blip import (BlipVisionModel, dummy_image_for_blip,
                    get_max_blip_image_tokens)
 from .interfaces import SupportsMultiModal
-from .utils import merge_multimodal_embeddings
-
-_KEYS_TO_MODIFY_MAPPING = {
-    "language_model.lm_head": "lm_head",
-    "language_model.model": "language_model",
-}
+from .utils import (group_weights_with_prefix, init_vllm_registered_model,
+                    merge_multimodal_embeddings)

 # We use this internally as placeholders since there is no image token
 # defined on the HuggingFace repo
@@ -491,9 +485,6 @@ def __init__(self,

         super().__init__()

-        # currently all existing BLIP-2 models have `tie_word_embeddings`
-        # enabled
-        assert config.tie_word_embeddings
         self.config = config
         self.multimodal_config = multimodal_config

@@ -514,17 +505,8 @@ def __init__(self,
             bias=True,
         )

-        self.quant_config = quant_config
-
-        self.language_model = OPTModel(config.text_config, cache_config,
-                                       quant_config)
-
-        self.unpadded_vocab_size = config.text_config.vocab_size
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
-        self.sampler = Sampler()
-
-    def get_lm_head(self):
-        return self.language_model.decoder.embed_tokens
+        self.language_model = init_vllm_registered_model(
+            config.text_config, cache_config, quant_config)

     def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
         h = w = self.config.vision_config.image_size
@@ -653,7 +635,8 @@ def forward(

         if image_input is not None:
             vision_embeddings = self._process_image_input(image_input)
-            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+            inputs_embeds = self.language_model.model.get_input_embeddings(
+                input_ids)

             inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
@@ -663,11 +646,11 @@ def forward(
         else:
             inputs_embeds = None

-        hidden_states = self.language_model(input_ids,
-                                            positions,
-                                            kv_caches,
-                                            attn_metadata,
-                                            inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(input_ids,
+                                                  positions,
+                                                  kv_caches,
+                                                  attn_metadata,
+                                                  inputs_embeds=inputs_embeds)

         return hidden_states

@@ -676,56 +659,46 @@ def compute_logits(
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.get_lm_head(), hidden_states,
-                                       sampling_metadata)
-        return logits
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)

     def sample(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # only doing this for language model part for now.
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-
-        for name, loaded_weight in weights:
-            if "lm_head.weight" in name:
-                continue
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
-                if key_to_modify in name:
-                    name = name.replace(key_to_modify, new_key)
-            use_default_weight_loading = False
-            if "vision" in name:
-                if self.vision_model is not None:
-                    # BlipVisionModel does not need sharding
-                    use_default_weight_loading = True
-            else:
-                for (param_name, weight_name,
-                     shard_id) in stacked_params_mapping:
-                    if weight_name not in name:
-                        continue
-                    param = params_dict[name.replace(weight_name, param_name)]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
-                else:
-                    use_default_weight_loading = True
-            if use_default_weight_loading:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        # prepare weight iterators for components
+        weights_group = group_weights_with_prefix(weights)
+
+        # load vision encoder
+        self.vision_model.load_weights(weights_group["vision_model"])
+
+        # load query tokens
+        for name, loaded_weight in weights_group["query_tokens"]:
+            assert name == ""
+            param = self.query_tokens
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load qformer
+        qformer_params_dict = dict(self.qformer.named_parameters())
+        for name, loaded_weight in weights_group["qformer"]:
+            param = qformer_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load mlp projector
+        mlp_params_dict = dict(self.language_projection.named_parameters())
+        for name, loaded_weight in weights_group["language_projection"]:
+            param = mlp_params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        # load llm backbone
+        self.language_model.load_weights(weights_group["language_model"])
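
The composite loader above relies on group_weights_with_prefix from .utils, which splits the flat (name, tensor) stream by top-level prefix and strips that prefix, so each sub-module's load_weights sees names relative to itself; that is also why the query_tokens entry arrives with an empty name. A rough, illustrative approximation of such a helper (not the actual vLLM implementation, which may return lazy iterators rather than lists):

from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

import torch


def group_by_prefix(
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> Dict[str, List[Tuple[str, torch.Tensor]]]:
    """Bucket (name, tensor) pairs by their first dotted component."""
    groups: Dict[str, List[Tuple[str, torch.Tensor]]] = defaultdict(list)
    for name, tensor in weights:
        prefix, _, rest = name.partition(".")
        # "query_tokens" has no ".", so `rest` is "" -- matching the
        # `assert name == ""` check in the loader above.
        groups[prefix].append((rest, tensor))
    return groups

# e.g. "vision_model.encoder.layers.0.mlp.fc1.weight" would land in
# groups["vision_model"] as "encoder.layers.0.mlp.fc1.weight".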

vllm/model_executor/models/chameleon.py (-3)

@@ -12,7 +12,6 @@
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -36,8 +35,6 @@

 from .interfaces import SupportsMultiModal

-logger = init_logger(__name__)
-
 # These configs are not part of the model config but the preprocessor
 # and processor files, so we hardcode them in the model file for now.
 CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512

vllm/model_executor/models/clip.py (+4, -7)

@@ -391,6 +391,7 @@ def __init__(self,
                  quant_config: Optional[QuantizationConfig] = None,
                  num_hidden_layers_override: Optional[int] = None):
         super().__init__()
+
         tp_size = get_tensor_model_parallel_world_size()
         num_heads = config.num_attention_heads
         self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
@@ -400,10 +401,6 @@ def __init__(self,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override)

-    @property
-    def _require_post_layernorm(self) -> bool:
-        return self.vision_model.post_layernorm is not None
-
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return self.vision_model(pixel_values)

@@ -425,12 +422,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

         for name, loaded_weight in weights:
             # post_layernorm is not needed in CLIPVisionModel
-            if ("vision_model.post_layernorm" in name
-                    and not self._require_post_layernorm):
+            if (name.startswith("vision_model.post_layernorm")
+                    and self.vision_model.post_layernorm is None):
                 continue

             # omit layers when num_hidden_layers_override is set
-            if "vision_model.encoder.layers." in name:
+            if name.startswith("vision_model.encoder.layers"):
                 layer_idx = int(name.split(".")[3])
                 if layer_idx >= layer_count:
                     continue
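
One subtle difference from the BLIP loader above: here the layer index is parsed with name.split(".")[3] rather than [2], because the names CLIPVisionModel processes still carry the leading vision_model. prefix, whereas BlipVisionModel.load_weights receives names with that prefix already stripped (the parent passes it weights_group["vision_model"]). A quick check (example names are assumed checkpoint-style keys, for illustration only):

# Assumed checkpoint-style names, for illustration.
clip_name = "vision_model.encoder.layers.12.self_attn.q_proj.weight"
blip_name = "encoder.layers.12.self_attn.q_proj.weight"

# CLIP keeps the leading "vision_model." prefix -> layer index at position 3;
# BLIP sees prefix-stripped names -> layer index at position 2.
assert int(clip_name.split(".")[3]) == 12
assert int(blip_name.split(".")[2]) == 12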

vllm/model_executor/models/fuyu.py (-3)

@@ -28,7 +28,6 @@
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
-from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -45,8 +44,6 @@
 from .interfaces import SupportsMultiModal
 from .utils import merge_multimodal_embeddings

-logger = init_logger(__name__)
-
 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
 _NEWLINE_TOKEN_ID = 71019

vllm/model_executor/models/llava_next.py (-8)

@@ -12,7 +12,6 @@
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
-from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -32,13 +31,6 @@
 from .utils import (flatten_bn, group_weights_with_prefix,
                     init_vllm_registered_model, merge_multimodal_embeddings)

-logger = init_logger(__name__)
-
-_KEYS_TO_MODIFY_MAPPING = {
-    "language_model.lm_head": "lm_head",
-    "language_model.model": "language_model",
-}
-
 # Result in the max possible feature size (2x2 grid of 336x336px tiles)
 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448

vllm/model_executor/models/llava_next_video.py (-3)

@@ -11,7 +11,6 @@
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -32,8 +31,6 @@
 from .utils import (group_weights_with_prefix, init_vllm_registered_model,
                     merge_multimodal_embeddings)

-logger = init_logger(__name__)
-
 # For profile run
 _MAX_FRAMES_PER_VIDEO = 32
 _MAX_NUM_VIDEOS = 1

vllm/model_executor/models/minicpmv.py (-3)

@@ -37,7 +37,6 @@
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
-from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -59,8 +58,6 @@

 from .idefics2_vision_model import Idefics2VisionTransformer

-logger = init_logger(__name__)
-
 _KEYS_TO_MODIFY_MAPPING = {
     "llm.lm_head": "lm_head",
     "llm.model": "llm",

0 commit comments