[Bugfix] Refactor composite weight loading logic (vllm-project#8656)
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
Isotr0py authored and garg-amit committed Oct 28, 2024
1 parent 7837ed9 commit 08321e1
Showing 7 changed files with 70 additions and 61 deletions.
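
At a glance, each composite multimodal model previously duplicated the incoming weight stream with itertools.tee and filtered each copy by prefix via filter_weights; this commit centralizes that into a single group_weights_with_prefix call. The sketch below contrasts the two call-site patterns on a hypothetical model with vision_tower and language_model sub-modules; everything other than the two utility functions is illustrative, not part of the commit.

```python
# Sketch only: a hypothetical composite model, not code from this commit.
import itertools
from typing import Iterable, Tuple

import torch

from vllm.model_executor.models.utils import (filter_weights,
                                              group_weights_with_prefix)


def load_weights_before(model, weights: Iterable[Tuple[str, torch.Tensor]]):
    # Old pattern: one tee'd copy of the full stream per component,
    # each copy filtered down to its own prefix at the call site.
    vit_weights, llm_weights = itertools.tee(weights, 2)
    model.vision_tower.load_weights(filter_weights(vit_weights, "vision_tower"))
    model.language_model.load_weights(filter_weights(llm_weights, "language_model"))


def load_weights_after(model, weights: Iterable[Tuple[str, torch.Tensor]]):
    # New pattern: a single helper groups the stream by top-level prefix,
    # and each component pulls its group out of the resulting mapping.
    weights_group = group_weights_with_prefix(weights)
    model.vision_tower.load_weights(weights_group["vision_tower"])
    model.language_model.load_weights(weights_group["language_model"])
```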
16 changes: 6 additions & 10 deletions vllm/model_executor/models/internvl.py
```diff
@@ -4,7 +4,6 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-import itertools
 import re
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
@@ -33,8 +32,8 @@
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsMultiModal
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)
 
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -518,21 +517,18 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_model")
-        self.vision_model.load_weights(vit_weights)
+        self.vision_model.load_weights(weights_group["vision_model"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "mlp1")
         mlp_params_dict = dict(self.mlp1.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["mlp1"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
```
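
One detail worth keeping in mind when reading these call sites: the grouped entries appear to be lazy, tee-backed iterators rather than materialized lists, so each component's group is meant to be consumed exactly once. A small illustration follows; the checkpoint entries and tensor shapes are invented, and only the prefixes mirror the InternVL code above.

```python
import torch

from vllm.model_executor.models.utils import group_weights_with_prefix

# Invented checkpoint entries using the prefixes from the InternVL diff above.
weights = [
    ("vision_model.embeddings.patch_embedding.weight", torch.zeros(2, 2)),
    ("mlp1.0.weight", torch.zeros(2, 2)),
    ("language_model.model.embed_tokens.weight", torch.zeros(2, 2)),
]

weights_group = group_weights_with_prefix(weights)
vit_entries = list(weights_group["vision_model"])
print(len(vit_entries))                          # 1 entry for the vision encoder
print(len(list(weights_group["vision_model"])))  # 0: the iterator is already exhausted
```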
16 changes: 6 additions & 10 deletions vllm/model_executor/models/llava.py
```diff
@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -26,8 +25,8 @@
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)
 
 
 class LlavaImagePixelInputs(TypedDict):
@@ -393,21 +392,18 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
```
20 changes: 7 additions & 13 deletions vllm/model_executor/models/llava_next.py
```diff
@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -30,8 +29,8 @@
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
 
@@ -637,31 +636,26 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
-            weights, 4)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load newline
-        newline_weights = filter_weights(newline_weights, "image_newline")
-        for name, loaded_weight in newline_weights:
+        for name, loaded_weight in weights_group["image_newline"]:
             assert name == ""
             param = self.image_newline
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
```
17 changes: 6 additions & 11 deletions vllm/model_executor/models/llava_next_video.py
```diff
@@ -1,4 +1,3 @@
-import itertools
 import math
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
@@ -30,7 +29,7 @@
 from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (group_weights_with_prefix, init_vllm_registered_model,
                     merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
@@ -449,23 +448,19 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators
-        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
-            weights, 4)
+        # prepare weight iterators for components
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
```
14 changes: 5 additions & 9 deletions vllm/model_executor/models/paligemma.py
```diff
@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -23,7 +22,7 @@
 from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
-from .utils import filter_weights, merge_multimodal_embeddings
+from .utils import group_weights_with_prefix, merge_multimodal_embeddings
 
 logger = init_logger(__name__)
 
@@ -286,21 +285,18 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision tower
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
```
12 changes: 5 additions & 7 deletions vllm/model_executor/models/ultravox.py
```diff
@@ -1,7 +1,6 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""
 
-import itertools
 import math
 from array import array
 from functools import lru_cache
@@ -29,7 +28,8 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsMultiModal
-from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
+from vllm.model_executor.models.utils import (flatten_bn,
+                                              group_weights_with_prefix,
                                               init_vllm_registered_model,
                                               merge_multimodal_embeddings)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -467,11 +467,10 @@ def sample(
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        projector_weights, llm_weights = itertools.tee(weights, 2)
+        weights_group = group_weights_with_prefix(weights)
 
         # load projector weights
-        projector_weights = filter_weights(projector_weights,
-                                           "multi_modal_projector")
+        projector_weights = weights_group["multi_modal_projector"]
         projector_params_dict = dict(
             self.multi_modal_projector.named_parameters())
         for name, loaded_weight in projector_weights:
@@ -481,5 +480,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
```
36 changes: 35 additions & 1 deletion vllm/model_executor/models/utils.py
```diff
@@ -1,3 +1,5 @@
+import itertools
+from collections import UserDict
 from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
                     Union, overload)
 
@@ -16,7 +18,23 @@
 from vllm.utils import is_pin_memory_available
 
 
-def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
+class WeightsGroup(UserDict):
+    """
+    Wraps grouped weights dictionary for a more informative error message
+    when attempting to access a weight component that does not exist.
+    """
+
+    def __getitem__(self, key: str) -> int:
+        try:
+            return super().__getitem__(key)
+        except KeyError as exc:
+            msg = (f"There is no weights named with the prefix: {key}. "
+                   f"Available prefix: {set(self.keys())}")
+            raise KeyError(msg) from exc
+
+
+def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
+                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
     """
     Helper function to load weights for inner vLLM models.
 
@@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
             yield name, loaded_weight
 
 
+def group_weights_with_prefix(
+    weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
+    """
+    Helper function to group weights with prefix
+    """
+    init_weights, repeated_weights = itertools.tee(weights, 2)
+    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
+    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
+
+    return WeightsGroup({
+        prefix: filter_weights(component, prefix)
+        for component, prefix in zip(repeated_weights, weights_prefix)
+    })
+
+
 def init_vllm_registered_model(
     hf_config: PretrainedConfig,
     cache_config: Optional[CacheConfig],
```
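
The WeightsGroup wrapper added above exists only to improve the failure mode when a caller asks for a prefix that is not present in the checkpoint: instead of a bare KeyError, it reports which prefixes are available. A minimal usage sketch follows; the single weight entry is invented for illustration.

```python
import torch

from vllm.model_executor.models.utils import group_weights_with_prefix

# A toy, single-entry "checkpoint" whose only top-level prefix is language_model.
weights_group = group_weights_with_prefix([
    ("language_model.lm_head.weight", torch.zeros(2, 2)),
])

try:
    weights_group["vision_tower"]  # no such prefix in this checkpoint
except KeyError as exc:
    # WeightsGroup.__getitem__ re-raises with the available prefixes listed.
    print(exc)
```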
