Skip to content

Commit 3f06bae

Browse files
authored
[Core][Model] Support loading weights by ID within models (#7931)
1 parent b8747e8 commit 3f06bae

File tree

2 files changed

+73
-17
lines changed

2 files changed

+73
-17
lines changed

vllm/model_executor/model_loader/loader.py

+47-13
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
# ruff: noqa: SIM117
22
import collections
33
import copy
4+
import dataclasses
45
import fnmatch
56
import glob
67
import json
78
import math
89
import os
910
from abc import ABC, abstractmethod
1011
from contextlib import contextmanager
11-
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
12+
from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple,
13+
Type, cast)
1214

1315
import gguf
1416
import huggingface_hub
@@ -207,6 +209,22 @@ def load_model(self, *, model_config: ModelConfig,
207209
class DefaultModelLoader(BaseModelLoader):
208210
"""Model loader that can load different file types from disk."""
209211

212+
@dataclasses.dataclass
class Source:
    """A source for weights.

    Bundles the location of a set of weights with the options that
    control how they are fetched and how their names are reported.
    """

    model_or_path: str
    """The model ID or path."""

    revision: Optional[str] = None
    """The optional model revision. Defaults to ``None``."""

    prefix: str = ""
    """A prefix to prepend to all weights."""

    fall_back_to_pt: bool = True
    """Whether .pt weights can be used."""
227+
210228
def __init__(self, load_config: LoadConfig):
211229
super().__init__(load_config)
212230
if load_config.model_loader_extra_config:
@@ -313,17 +331,16 @@ def _prepare_weights(self, model_name_or_path: str,
313331
return hf_folder, hf_weights_files, use_safetensors
314332

315333
def _get_weights_iterator(
316-
self, model_name_or_path: str, revision: Optional[str],
317-
fall_back_to_pt: bool
334+
self, source: "Source"
318335
) -> Generator[Tuple[str, torch.Tensor], None, None]:
319336
"""Get an iterator for the model weights based on the load format."""
320337
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
321-
model_name_or_path, revision, fall_back_to_pt)
338+
source.model_or_path, source.revision, source.fall_back_to_pt)
322339
if self.load_config.load_format == LoadFormat.NPCACHE:
323340
# Currently np_cache only supports *.bin checkpoints
324341
assert use_safetensors is False
325342
weights_iterator = np_cache_weights_iterator(
326-
model_name_or_path, self.load_config.download_dir, hf_folder,
343+
source.model_or_path, self.load_config.download_dir, hf_folder,
327344
hf_weights_files)
328345
elif use_safetensors:
329346
weights_iterator = safetensors_weights_iterator(hf_weights_files)
@@ -341,7 +358,29 @@ def _xla_weights_iterator(iterator: Generator):
341358
xm.mark_step()
342359

343360
weights_iterator = _xla_weights_iterator(weights_iterator)
344-
return weights_iterator
361+
362+
# Apply the prefix.
363+
return ((source.prefix + name, tensor)
364+
for (name, tensor) in weights_iterator)
365+
366+
def _get_all_weights(
    self,
    model_config: ModelConfig,
    model: nn.Module,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Yield every ``(name, tensor)`` pair the model should load.

    The primary weights, described by ``model_config.model`` and
    ``model_config.revision``, are yielded first with an empty prefix.
    After that, each entry of the model's optional ``secondary_weights``
    attribute is iterated in order.

    Args:
        model_config: Supplies the primary model ID/path and revision.
        model: The module being loaded. Its optional
            ``fall_back_to_pt_during_load`` attribute (default ``True``)
            and optional ``secondary_weights`` attribute (default empty)
            are read here.

    Yields:
        ``(name, tensor)`` pairs as produced by
        ``self._get_weights_iterator`` for each source.
    """
    model_id = model_config.model
    model_revision = model_config.revision
    allow_pt = getattr(model, "fall_back_to_pt_during_load", True)

    main_source = DefaultModelLoader.Source(
        model_or_path=model_id,
        revision=model_revision,
        prefix="",
        fall_back_to_pt=allow_pt,
    )
    yield from self._get_weights_iterator(main_source)

    # Looked up only after the primary weights have been exhausted, so a
    # model without the attribute simply contributes nothing further.
    extra_sources = cast(Iterable[DefaultModelLoader.Source],
                         getattr(model, "secondary_weights", ()))
    for extra in extra_sources:
        yield from self._get_weights_iterator(extra)
345384

346385
def download_model(self, model_config: ModelConfig) -> None:
347386
self._prepare_weights(model_config.model,
@@ -360,13 +399,8 @@ def load_model(self, *, model_config: ModelConfig,
360399
model = _initialize_model(model_config, self.load_config,
361400
lora_config, cache_config,
362401
scheduler_config)
363-
model.load_weights(
364-
self._get_weights_iterator(model_config.model,
365-
model_config.revision,
366-
fall_back_to_pt=getattr(
367-
model,
368-
"fall_back_to_pt_during_load",
369-
True)), )
402+
403+
model.load_weights(self._get_all_weights(model_config, model))
370404

371405
for _, module in model.named_modules():
372406
quant_method = getattr(module, "quant_method", None)

vllm/model_executor/models/ultravox.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from vllm.model_executor.layers.quantization.base_config import (
2626
QuantizationConfig)
2727
from vllm.model_executor.layers.sampler import SamplerOutput
28+
from vllm.model_executor.model_loader.loader import DefaultModelLoader
2829
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
2930
from vllm.model_executor.models.interfaces import SupportsMultiModal
3031
from vllm.model_executor.models.utils import (flatten_bn,
@@ -334,14 +335,23 @@ def __init__(self,
334335
self.multi_modal_config = multimodal_config
335336
assert self.multi_modal_config
336337

338+
self.secondary_weights = []
339+
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
337340
if config.audio_model_id is not None:
338-
self.audio_tower = ModifiedWhisperEncoder.from_pretrained(
339-
config.audio_model_id)
340-
else:
341-
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
341+
self.secondary_weights.append(
342+
DefaultModelLoader.Source(
343+
model_or_path=config.audio_model_id,
344+
revision=None,
345+
prefix="audio_tower.",
346+
))
342347
self.multi_modal_projector = UltravoxProjector(config)
343348
self.language_model = init_vllm_registered_model(
344349
config.text_config, cache_config, quant_config)
350+
if config.text_model_id is not None:
351+
self.secondary_weights.append(
352+
DefaultModelLoader.Source(model_or_path=config.text_model_id,
353+
revision=None,
354+
prefix="language_model."))
345355

346356
def _audio_features_to_embeddings(
347357
self, input_features: torch.Tensor) -> torch.Tensor:
@@ -466,6 +476,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
466476
# prepare weight iterators for components
467477
weights_group = group_weights_with_prefix(weights)
468478

479+
# load audio tower weights
480+
audio_tower_weights = weights_group["audio_tower"]
481+
audio_tower_params_dict = dict(
482+
self.audio_tower.named_parameters(
483+
prefix=self.audio_tower.base_model_prefix))
484+
for name, loaded_weight in audio_tower_weights:
485+
if name in audio_tower_params_dict:
486+
param = audio_tower_params_dict[name]
487+
weight_loader = getattr(param, "weight_loader",
488+
default_weight_loader)
489+
weight_loader(param, loaded_weight)
490+
469491
# load projector weights
470492
projector_weights = weights_group["multi_modal_projector"]
471493
projector_params_dict = dict(

0 commit comments

Comments
 (0)