|
1 |
| -import itertools |
2 | 1 | from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
3 | 2 | TypedDict, Union)
|
4 | 3 |
|
|
30 | 29 | from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
31 | 30 | dummy_seq_data_for_siglip, get_siglip_image_feature_size,
|
32 | 31 | get_siglip_patch_grid_length, input_processor_for_siglip)
|
33 |
| -from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, |
34 |
| - merge_multimodal_embeddings) |
| 32 | +from .utils import (flatten_bn, group_weights_with_prefix, |
| 33 | + init_vllm_registered_model, merge_multimodal_embeddings) |
35 | 34 |
|
36 | 35 | logger = init_logger(__name__)
|
37 | 36 |
|
@@ -637,31 +636,26 @@ def sample(
|
637 | 636 |
|
638 | 637 | def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
639 | 638 | # prepare weight iterators for components
|
640 |
| - vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee( |
641 |
| - weights, 4) |
| 639 | + weights_group = group_weights_with_prefix(weights) |
642 | 640 |
|
643 | 641 | # load vision encoder
|
644 |
| - vit_weights = filter_weights(vit_weights, "vision_tower") |
645 |
| - self.vision_tower.load_weights(vit_weights) |
| 642 | + self.vision_tower.load_weights(weights_group["vision_tower"]) |
646 | 643 |
|
647 | 644 | # load mlp projector
|
648 |
| - mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") |
649 | 645 | mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
|
650 |
| - for name, loaded_weight in mlp_weights: |
| 646 | + for name, loaded_weight in weights_group["multi_modal_projector"]: |
651 | 647 | param = mlp_params_dict[name]
|
652 | 648 | weight_loader = getattr(param, "weight_loader",
|
653 | 649 | default_weight_loader)
|
654 | 650 | weight_loader(param, loaded_weight)
|
655 | 651 |
|
656 | 652 | # load newline
|
657 |
| - newline_weights = filter_weights(newline_weights, "image_newline") |
658 |
| - for name, loaded_weight in newline_weights: |
| 653 | + for name, loaded_weight in weights_group["image_newline"]: |
659 | 654 | assert name == ""
|
660 | 655 | param = self.image_newline
|
661 | 656 | weight_loader = getattr(param, "weight_loader",
|
662 | 657 | default_weight_loader)
|
663 | 658 | weight_loader(param, loaded_weight)
|
664 | 659 |
|
665 | 660 | # load llm backbone
|
666 |
| - llm_weights = filter_weights(llm_weights, "language_model") |
667 |
| - self.language_model.load_weights(llm_weights) |
| 661 | + self.language_model.load_weights(weights_group["language_model"]) |
0 commit comments