
Commit 8463754

Isotr0py authored and Alvant committed
[Bugfix] Refactor composite weight loading logic (vllm-project#8656)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent 6e16cea commit 8463754

File tree: 7 files changed, +70 -61 lines

vllm/model_executor/models/internvl.py (+6 -10)

@@ -4,7 +4,6 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-import itertools
 import re
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
@@ -33,8 +32,8 @@
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsMultiModal
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)

 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -518,21 +517,18 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)

         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_model")
-        self.vision_model.load_weights(vit_weights)
+        self.vision_model.load_weights(weights_group["vision_model"])

         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "mlp1")
         mlp_params_dict = dict(self.mlp1.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["mlp1"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)

         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
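The rewrite above repeats, with different component names, in every model file in this commit: instead of tee-ing the full checkpoint iterator once per component and filtering each copy by prefix, load_weights now asks group_weights_with_prefix for one bucket of weights per top-level prefix and hands each bucket to the matching component. A minimal sketch of the resulting shape, with placeholder submodule names (vision_encoder, projector) that do not belong to any particular model:

from typing import Iterable, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import group_weights_with_prefix


class CompositeModelSketch(nn.Module):
    """Illustrative skeleton of the load_weights pattern used in this commit."""

    # self.vision_encoder, self.projector and self.language_model are assumed
    # to be created in __init__ (omitted here); the first and last expose
    # their own load_weights, while the projector is a plain nn.Module.

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # One lazy bucket of (name, tensor) pairs per top-level prefix.
        weights_group = group_weights_with_prefix(weights)

        # Components with their own load_weights consume a bucket wholesale.
        self.vision_encoder.load_weights(weights_group["vision_encoder"])

        # Plain modules are loaded parameter by parameter, falling back to
        # default_weight_loader when a param has no custom weight_loader.
        params = dict(self.projector.named_parameters())
        for name, loaded_weight in weights_group["projector"]:
            param = params[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        self.language_model.load_weights(weights_group["language_model"])

Components that define their own load_weights consume a bucket directly; plain nn.Module components, such as the MLP projectors in these models, are loaded parameter by parameter with default_weight_loader as the fallback.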

vllm/model_executor/models/llava.py (+6 -10)

@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)

@@ -26,8 +25,8 @@
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)


 class LlavaImagePixelInputs(TypedDict):
@@ -393,21 +392,18 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)

         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])

         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)

         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

vllm/model_executor/models/llava_next.py (+7 -13)

@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)

@@ -30,8 +29,8 @@
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)

 logger = init_logger(__name__)

@@ -637,31 +636,26 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
-            weights, 4)
+        weights_group = group_weights_with_prefix(weights)

         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])

         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)

         # load newline
-        newline_weights = filter_weights(newline_weights, "image_newline")
-        for name, loaded_weight in newline_weights:
+        for name, loaded_weight in weights_group["image_newline"]:
             assert name == ""
             param = self.image_newline
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)

         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
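One detail specific to llava_next: image_newline is a bare nn.Parameter rather than a submodule, so there is no nested load_weights to delegate to. filter_weights strips the matched prefix from each name (the existing assert name == "" depends on this), so a checkpoint entry named exactly image_newline comes back with an empty remaining name and is loaded directly onto the parameter. A small sketch of that behavior with a toy tensor:

import torch

from vllm.model_executor.models.utils import filter_weights

# A checkpoint entry whose full name *is* the prefix, like image_newline above.
weights = [("image_newline", torch.zeros(8))]

for name, tensor in filter_weights(weights, "image_newline"):
    # The matched prefix is stripped away, leaving an empty remaining name,
    # which is exactly what the `assert name == ""` in load_weights checks.
    print(repr(name))  # ''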

vllm/model_executor/models/llava_next_video.py (+6 -11)

@@ -1,4 +1,3 @@
-import itertools
 import math
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
@@ -30,7 +29,7 @@
 from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (group_weights_with_prefix, init_vllm_registered_model,
                     merge_multimodal_embeddings)

 logger = init_logger(__name__)
@@ -449,23 +448,19 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators
-        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
-            weights, 4)
+        # prepare weight iterators for components
+        weights_group = group_weights_with_prefix(weights)

         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])

         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)

         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

vllm/model_executor/models/paligemma.py (+5 -9)

@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)

@@ -23,7 +22,7 @@
 from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
-from .utils import filter_weights, merge_multimodal_embeddings
+from .utils import group_weights_with_prefix, merge_multimodal_embeddings

 logger = init_logger(__name__)

@@ -286,21 +285,18 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)

         # load vision tower
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])

         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)

         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

vllm/model_executor/models/ultravox.py (+5 -7)

@@ -1,7 +1,6 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""

-import itertools
 import math
 from array import array
 from functools import lru_cache
@@ -29,7 +28,8 @@
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsMultiModal
-from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
+from vllm.model_executor.models.utils import (flatten_bn,
+                                              group_weights_with_prefix,
                                               init_vllm_registered_model,
                                               merge_multimodal_embeddings)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -467,11 +467,10 @@ def sample(

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        projector_weights, llm_weights = itertools.tee(weights, 2)
+        weights_group = group_weights_with_prefix(weights)

         # load projector weights
-        projector_weights = filter_weights(projector_weights,
-                                           "multi_modal_projector")
+        projector_weights = weights_group["multi_modal_projector"]
         projector_params_dict = dict(
             self.multi_modal_projector.named_parameters())
         for name, loaded_weight in projector_weights:
@@ -481,5 +480,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader(param, loaded_weight)

         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

vllm/model_executor/models/utils.py (+35 -1)

@@ -1,3 +1,5 @@
+import itertools
+from collections import UserDict
 from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
                     Union, overload)

@@ -16,7 +18,23 @@
 from vllm.utils import is_pin_memory_available


-def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
+class WeightsGroup(UserDict):
+    """
+    Wraps grouped weights dictionary for a more informative error message
+    when attempting to access a weight component that does not exist.
+    """
+
+    def __getitem__(self, key: str) -> int:
+        try:
+            return super().__getitem__(key)
+        except KeyError as exc:
+            msg = (f"There is no weights named with the prefix: {key}. "
+                   f"Available prefix: {set(self.keys())}")
+            raise KeyError(msg) from exc
+
+
+def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
+                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
     """
     Helper function to load weights for inner vLLM models.

@@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
         yield name, loaded_weight


+def group_weights_with_prefix(
+    weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
+    """
+    Helper function to group weights with prefix
+    """
+    init_weights, repeated_weights = itertools.tee(weights, 2)
+    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
+    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
+
+    return WeightsGroup({
+        prefix: filter_weights(component, prefix)
+        for component, prefix in zip(repeated_weights, weights_prefix)
+    })
+
+
 def init_vllm_registered_model(
     hf_config: PretrainedConfig,
     cache_config: Optional[CacheConfig],
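Taken together, group_weights_with_prefix walks the checkpoint stream once to discover its top-level prefixes, tees the stream once per prefix, and wraps the resulting buckets in WeightsGroup so that asking for a missing component produces a KeyError that lists the prefixes that do exist. A quick, self-contained sketch with toy tensors (the checkpoint names below are made up for illustration):

import torch

from vllm.model_executor.models.utils import group_weights_with_prefix

# Toy checkpoint stream with two top-level prefixes.
weights = [
    ("vision_tower.encoder.weight", torch.zeros(4, 4)),
    ("vision_tower.encoder.bias", torch.zeros(4)),
    ("language_model.embed_tokens.weight", torch.zeros(4, 4)),
]

weights_group = group_weights_with_prefix(weights)

# Each bucket yields the (name, tensor) pairs under that prefix, with the
# prefix stripped, as the inner load_weights calls expect.
print([name for name, _ in weights_group["vision_tower"]])
# ['encoder.weight', 'encoder.bias']

# A prefix that is not present in the checkpoint raises a KeyError whose
# message lists the available prefixes (via the WeightsGroup wrapper).
try:
    weights_group["multi_modal_projector"]
except KeyError as exc:
    print(exc)

Each bucket is a single-pass generator over its own tee of the stream, so it should be consumed only once, as the load_weights implementations above do.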
