
[Model] Refactor BLIP/BLIP-2 to support composite model loading #8407


Merged: 7 commits, Sep 22, 2024
Showing changes from 2 commits
61 changes: 58 additions & 3 deletions vllm/model_executor/models/blip.py
@@ -1,7 +1,7 @@
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
from array import array
from typing import Optional, Union
from typing import Iterable, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -17,6 +17,7 @@
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
@@ -343,6 +344,10 @@ def __init__(self,
num_hidden_layers_override: Optional[int] = None):
super().__init__()

tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0

self.config = config

self.embeddings = BlipVisionEmbeddings(config)
@@ -351,11 +356,61 @@ def __init__(self,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
)
self.post_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)

if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(
f"The original encoder only has {config.num_hidden_layers} "
f"layers, but you requested {len(self.encoder.layers)} layers."
)
elif len(self.encoder.layers) == config.num_hidden_layers:
self.post_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
else:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self.post_layernorm = None

def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values)
hidden_states = self.encoder(inputs_embeds=hidden_states)

if self.post_layernorm is None:
return hidden_states

return self.post_layernorm(hidden_states)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
layer_count = len(self.encoder.layers)

for name, loaded_weight in weights:
# post_layernorm is not needed in BlipVisionModel
if (name.startswith("post_layernorm")
and self.post_layernorm is None):
continue

# omit layers when num_hidden_layers_override is set
if name.startswith("encoder.layers"):
layer_idx = int(name.split(".")[2])
if layer_idx >= layer_count:
continue

for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue

param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
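
A note on the stacked_params_mapping used above: HuggingFace BLIP checkpoints store q_proj, k_proj and v_proj as separate tensors, while vLLM fuses them into a single qkv_proj when the attention heads are sharded across tensor-parallel ranks, so the loader routes each checkpoint tensor into the right slice of the fused parameter. A minimal standalone sketch of that routing idea (toy tensors and a toy weight loader, not vLLM's actual QKVParallelLinear machinery):

import torch

# Toy fused parameter: rows 0-3 hold q, 4-7 hold k, 8-11 hold v.
fused_qkv = torch.zeros(12, 4)

def toy_qkv_weight_loader(param, loaded, shard_id):
    # Copy a per-projection checkpoint tensor into its slice of the fused parameter.
    offset = {"q": 0, "k": 4, "v": 8}[shard_id]
    param[offset:offset + loaded.shape[0]].copy_(loaded)

# Checkpoint-style weights keyed by their original (unfused) names.
checkpoint = {
    "encoder.layers.0.self_attn.q_proj.weight": torch.full((4, 4), 1.0),
    "encoder.layers.0.self_attn.k_proj.weight": torch.full((4, 4), 2.0),
    "encoder.layers.0.self_attn.v_proj.weight": torch.full((4, 4), 3.0),
}

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

for name, weight in checkpoint.items():
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # vLLM would look up params_dict[name.replace(weight_name, param_name)];
        # here everything lands in the single toy fused tensor.
        toy_qkv_weight_loader(fused_qkv, weight, shard_id)
        break

print(fused_qkv[:, 0])  # rows 0-3 -> 1.0, rows 4-7 -> 2.0, rows 8-11 -> 3.0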
133 changes: 59 additions & 74 deletions vllm/model_executor/models/blip2.py
@@ -1,3 +1,4 @@
import itertools
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
@@ -11,11 +12,9 @@
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.opt import OPTModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
@@ -24,12 +23,8 @@
from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings

_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
from .utils import (filter_weights, init_vllm_registered_model,
merge_multimodal_embeddings)

# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
@@ -494,9 +489,6 @@ def __init__(self,

super().__init__()

# currently all existing BLIP-2 models have `tie_word_embeddings`
# enabled
assert config.tie_word_embeddings
self.config = config
self.multimodal_config = multimodal_config

@@ -517,17 +509,8 @@ def __init__(self,
bias=True,
)

self.quant_config = quant_config

self.language_model = OPTModel(config.text_config, cache_config,
quant_config)

self.unpadded_vocab_size = config.text_config.vocab_size
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
self.sampler = Sampler()

def get_lm_head(self):
return self.language_model.decoder.embed_tokens
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)

def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
@@ -656,7 +639,8 @@ def forward(

if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)

inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
@@ -666,11 +650,11 @@
else:
inputs_embeds = None

hidden_states = self.language_model(input_ids,
positions,
kv_caches,
attn_metadata,
inputs_embeds=inputs_embeds)
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
inputs_embeds=inputs_embeds)

return hidden_states

@@ -679,56 +663,57 @@ def compute_logits(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.get_lm_head(), hidden_states,
sampling_metadata)
return logits
return self.language_model.compute_logits(hidden_states,
sampling_metadata)

def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
return self.language_model.sample(logits, sampling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())

for name, loaded_weight in weights:
if "lm_head.weight" in name:
continue
if "rotary_emb.inv_freq" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_model is not None:
# BlipVisionModel does not need sharding
use_default_weight_loading = True
else:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# prepare weight iterators for components
(
vit_weights,
query_weights,
qformer_weights,
mlp_weights,
llm_weights,
) = itertools.tee(weights, 5)

# load vision encoder
vit_weights = filter_weights(vit_weights, "vision_model")
self.vision_model.load_weights(vit_weights)

# load query tokens
query_weights = filter_weights(query_weights, "query_tokens")
for name, loaded_weight in query_weights:
assert name == ""
param = self.query_tokens
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load qformer
qformer_weights = filter_weights(qformer_weights, "qformer")
qformer_params_dict = dict(self.qformer.named_parameters())
for name, loaded_weight in qformer_weights:
param = qformer_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load mlp projector
mlp_weights = filter_weights(mlp_weights, "language_projection")
mlp_params_dict = dict(self.language_projection.named_parameters())
for name, loaded_weight in mlp_weights:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# load llm backbone
llm_weights = filter_weights(llm_weights, "language_model")
self.language_model.load_weights(llm_weights)
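
For readers new to the composite-loading pattern above: the incoming weight iterator is duplicated with itertools.tee, each copy is narrowed to a single component's prefix, and the actual loading is delegated to that component (its own load_weights, or default_weight_loader for plain parameters). A minimal standalone sketch of the idea, where filter_by_prefix is a simplified stand-in for vLLM's filter_weights helper (its exact prefix-stripping behavior here is an assumption):

import itertools
import torch

def filter_by_prefix(weights, prefix):
    # Yield only the weights under `prefix`, with the prefix stripped.
    for name, tensor in weights:
        if name == prefix or name.startswith(prefix + "."):
            yield name[len(prefix):].lstrip("."), tensor

checkpoint = [
    ("vision_model.encoder.w", torch.zeros(2)),
    ("query_tokens", torch.zeros(2)),
    ("qformer.layer.w", torch.zeros(2)),
    ("language_model.decoder.w", torch.zeros(2)),
]

# One independent pass over the checkpoint per component.
vit_w, query_w, qformer_w, llm_w = itertools.tee(checkpoint, 4)

print(list(filter_by_prefix(vit_w, "vision_model")))    # [("encoder.w", ...)]
print(list(filter_by_prefix(query_w, "query_tokens")))  # [("", ...)] -> matches the assert above
print(list(filter_by_prefix(qformer_w, "qformer")))     # [("layer.w", ...)]
print(list(filter_by_prefix(llm_w, "language_model")))  # [("decoder.w", ...)]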
11 changes: 4 additions & 7 deletions vllm/model_executor/models/clip.py
@@ -393,6 +393,7 @@ def __init__(self,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
super().__init__()

tp_size = get_tensor_model_parallel_world_size()
num_heads = config.num_attention_heads
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
@@ -402,10 +403,6 @@ def __init__(self,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override)

@property
def _require_post_layernorm(self) -> bool:
return self.vision_model.post_layernorm is not None

def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
return self.vision_model(pixel_values)

@@ -427,12 +424,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

for name, loaded_weight in weights:
# post_layernorm is not needed in CLIPVisionModel
if ("vision_model.post_layernorm" in name
and not self._require_post_layernorm):
if (name.startswith("vision_model.post_layernorm")
and self.vision_model.post_layernorm is None):
continue

Member (Author) commented:
Since _require_post_layernorm is only used here, the post_layernorm layer can be directly referenced here to reduce a layer of indirection, making the code easier to understand.

Member (Author) commented:
Using startswith to check the weight names is more robust and performant.
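
A toy illustration of that point (not code from this PR; the aux.* key is hypothetical): a substring check also matches keys that merely contain the module path, whereas a prefix check only matches keys that actually start with it, and it avoids scanning the whole string.

candidates = [
    "vision_model.post_layernorm.weight",      # the weight we actually want to skip
    "aux.vision_model.post_layernorm.weight",  # hypothetical nested key
]

for name in candidates:
    loose = "vision_model.post_layernorm" in name             # True, True
    strict = name.startswith("vision_model.post_layernorm")   # True, False
    print(name, loose, strict)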

ywang96 (Member) commented on Sep 22, 2024:
I thought with the composite weight loading now that the vision_model prefix will be dropped - is that not the case?

nvm - I was thinking vision_tower. Ignore my comment here


# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
if name.startswith("vision_model.encoder.layers"):
layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count:
continue
13 changes: 6 additions & 7 deletions vllm/model_executor/models/llava_next.py
@@ -35,11 +35,6 @@

logger = init_logger(__name__)

_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}

Member (Author) commented on original lines 37-40:
This code is unused so I removed it.


# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448

@@ -635,8 +630,12 @@ def sample(

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
weights, 4)
(
vit_weights,
mlp_weights,
newline_weights,
llm_weights,
) = itertools.tee(weights, 4)

# load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower")
3 changes: 1 addition & 2 deletions vllm/model_executor/models/llava_next_video.py
@@ -450,8 +450,7 @@ def sample(

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators
vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
weights, 4)
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

# load vision encoder
vit_weights = filter_weights(vit_weights, "vision_tower")
11 changes: 4 additions & 7 deletions vllm/model_executor/models/siglip.py
@@ -503,6 +503,7 @@ def __init__(
num_hidden_layers_override: Optional[int] = None,
):
super().__init__()

num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
@@ -513,10 +514,6 @@ def __init__(
num_hidden_layers_override=num_hidden_layers_override,
)

@property
def _require_post_layernorm(self) -> bool:
return self.vision_model.post_layernorm is not None

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding

@@ -542,12 +539,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

for name, loaded_weight in weights:
# post_layernorm is optional in SiglipVisionModel
if ("vision_model.post_layernorm" in name
and not self._require_post_layernorm):
if (name.startswith("vision_model.post_layernorm")
and self.vision_model.post_layernorm is None):
continue

# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
if name.startswith("vision_model.encoder.layers"):
layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count:
continue