 # ruff: noqa: SIM117
 import collections
 import copy
+import dataclasses
 import fnmatch
 import glob
 import json
 import math
 import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, List, Optional, Tuple, Type
+from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple,
+                    Type, cast)
 
 import gguf
 import huggingface_hub
@@ -207,6 +209,22 @@ def load_model(self, *, model_config: ModelConfig,
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""
 
+    @dataclasses.dataclass
+    class Source:
+        """A source for weights."""
+
+        model_or_path: str
+        """The model ID or path."""
+
+        revision: Optional[str]
+        """The optional model revision."""
+
+        prefix: str = ""
+        """A prefix to prepend to all weights."""
+
+        fall_back_to_pt: bool = True
+        """Whether .pt weights can be used."""
+
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
         if load_config.model_loader_extra_config:
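
The nested `Source` dataclass bundles everything `_prepare_weights` needs for one checkpoint. A minimal construction sketch (the repo ID is hypothetical; `revision` has no default, while the last two fields keep their declared defaults):

```python
# Sketch only: a Source for a hypothetical Hugging Face repo.
source = DefaultModelLoader.Source(
    model_or_path="org/some-model",  # hypothetical repo ID or local path
    revision=None,                   # use the default branch
    prefix="",                       # weight names pass through unchanged
    fall_back_to_pt=True,            # allow *.pt files if no safetensors
)
```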
@@ -313,17 +331,16 @@ def _prepare_weights(self, model_name_or_path: str,
         return hf_folder, hf_weights_files, use_safetensors
 
     def _get_weights_iterator(
-        self, model_name_or_path: str, revision: Optional[str],
-        fall_back_to_pt: bool
+        self, source: "Source"
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
-            model_name_or_path, revision, fall_back_to_pt)
+            source.model_or_path, source.revision, source.fall_back_to_pt)
         if self.load_config.load_format == LoadFormat.NPCACHE:
             # Currently np_cache only supports *.bin checkpoints
             assert use_safetensors is False
             weights_iterator = np_cache_weights_iterator(
-                model_name_or_path, self.load_config.download_dir, hf_folder,
+                source.model_or_path, self.load_config.download_dir, hf_folder,
                 hf_weights_files)
         elif use_safetensors:
             weights_iterator = safetensors_weights_iterator(hf_weights_files)
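
The call-site change is mechanical: the three loose arguments collapse into a single `Source`, which also carries the new `prefix` field. A hypothetical before/after on a loader instance:

```python
# Before: three separate arguments (repo ID below is hypothetical).
loader._get_weights_iterator("org/some-model", None, fall_back_to_pt=True)

# After: one Source value carrying the same information plus a prefix.
loader._get_weights_iterator(
    DefaultModelLoader.Source("org/some-model", None, prefix="",
                              fall_back_to_pt=True))
```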
@@ -341,7 +358,29 @@ def _xla_weights_iterator(iterator: Generator):
                 xm.mark_step()
 
             weights_iterator = _xla_weights_iterator(weights_iterator)
-        return weights_iterator
+
+        # Apply the prefix.
+        return ((source.prefix + name, tensor)
+                for (name, tensor) in weights_iterator)
+
+    def _get_all_weights(
+        self,
+        model_config: ModelConfig,
+        model: nn.Module,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+
+        primary_weights = DefaultModelLoader.Source(
+            model_config.model,
+            model_config.revision,
+            prefix="",
+            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
+                                    True))
+        yield from self._get_weights_iterator(primary_weights)
+
+        secondary_weights = cast(Iterable[DefaultModelLoader.Source],
+                                 getattr(model, "secondary_weights", ()))
+        for source in secondary_weights:
+            yield from self._get_weights_iterator(source)
 
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model,
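
`_get_all_weights` chains the primary checkpoint with any extra sources the model declares, and the prefix remaps each secondary checkpoint's weight names into the composite model's namespace. A standalone sketch of that prefixing step, with plain ints standing in for `torch.Tensor` so the snippet runs anywhere:

```python
# Mirrors the generator expression at the end of _get_weights_iterator.
def apply_prefix(weights, prefix):
    return ((prefix + name, tensor) for name, tensor in weights)

raw = [("encoder.layers.0.weight", 0), ("encoder.layers.0.bias", 1)]
print(list(apply_prefix(raw, "audio_tower.")))
# [('audio_tower.encoder.layers.0.weight', 0),
#  ('audio_tower.encoder.layers.0.bias', 1)]
```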
@@ -360,13 +399,8 @@ def load_model(self, *, model_config: ModelConfig,
                 model = _initialize_model(model_config, self.load_config,
                                           lora_config, cache_config,
                                           scheduler_config)
-            model.load_weights(
-                self._get_weights_iterator(model_config.model,
-                                           model_config.revision,
-                                           fall_back_to_pt=getattr(
-                                               model,
-                                               "fall_back_to_pt_during_load",
-                                               True)), )
+
+            model.load_weights(self._get_all_weights(model_config, model))
 
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
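
Existing models are unaffected: `_get_all_weights` discovers both hooks via `getattr` with safe defaults (`True` and `()`). A hypothetical sketch of a composite model opting in, assuming the surrounding vLLM imports (the repo ID and prefix are made up for illustration):

```python
# Hypothetical: a model that pulls its tower weights from a second
# checkpoint; the loader prepends "audio_tower." to every name.
class MyCompositeModel(nn.Module):
    fall_back_to_pt_during_load = True  # read by _get_all_weights

    def __init__(self):
        super().__init__()
        self.secondary_weights = [
            DefaultModelLoader.Source(
                "org/audio-encoder",    # hypothetical secondary checkpoint
                revision=None,
                prefix="audio_tower.",  # prepended to every weight name
            ),
        ]
```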