 from torchchat.cli.builder import BuilderArgs, _load_checkpoint


-_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
+_DEFAULT_SAFETENSOR_INDEX = "model.safetensors.index.json"
+_DEFAULT_BIN_INDEX = "pytorch_model.bin.index.json"
 _CONFIG_NAME = "config.json"

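Note: both index files share the same JSON layout: a top-level "weight_map" dict that maps each parameter FQN to the shard file containing it. A minimal sketch of reading one, assuming a local index path (`read_index` is a hypothetical helper; this module uses `read_weights_from_json`):

    import json
    from typing import Dict

    def read_index(index_file: str) -> Dict[str, str]:
        # e.g. {"model.embed_tokens.weight": "model-00001-of-00002.safetensors"}
        with open(index_file) as f:
            return json.load(f)["weight_map"]
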
@@ -81,31 +82,6 @@ def get_hf_path_from_model_id(model_id: str) -> str:
     return file_location


-def get_hf_weight_map_and_path(
-    model_id: str,
-) -> Tuple[Dict[str, str], str, Dict[str, str]]:
-    """Get the weight map for a given HF model id and also the cache path for loading the weights"""
-    index_file = cached_file(model_id, _DEFAULT_SAFETENSOR_FILE_NAME)
-    if not os.path.exists(index_file):
-        raise FileNotFoundError(
-            f"Weight index file for {model_id} does not exist in HF cache."
-        )
-    logger.info(
-        f"Loading weight map from: {index_file}"
-    )
-    weight_map = read_weights_from_json(index_file)
-    if weight_map is None:
-        raise ValueError(f"Weight map not found in config file {index_file}")
-    weight_map, new_to_old_keymap = remap_weight_keys(weight_map)
-    weight_path = os.path.dirname(index_file)
-    if not os.path.exists(weight_path):
-        raise FileNotFoundError(f"Weight path {weight_path} does not exist")
-    logger.info(
-        f"Loading weights from: {weight_path}"
-    )
-    return weight_map, weight_path, new_to_old_keymap
-
-
 def remap_weight_keys(dictionary):
     """Remap the keys of a dictionary to match the expected format of the tune model."""
     # hf_key : dist_model_key
@@ -141,12 +117,13 @@ def remap_weight_keys(dictionary):
     return new_dict, key_mapping


-def load_safetensor_weights(
+def load_weights_per_map(
     stage_module: Module,
     weight_map: Dict[str, str],
     file_location: str,
     new_to_old_keymap: Dict[str, str],
-    device: torch.device = "cuda",
+    device: torch.device,
+    is_safetensor: bool,
     purge_model_prefix: bool = True,
     ignore_cache_layers: bool = True,
     model_config: Optional[Dict] = None,
@@ -160,6 +137,7 @@ def load_safetensor_weights(
         file_location (str): Directory containing the weight files.
         new_to_old_keymap (Dict[str, str]): Mapping of new parameter names to old ones.
         device (torch.device): The device to load tensors onto.
+        is_safetensor (bool): Whether the weight files are safetensors.
         purge_model_prefix (bool): Whether to remove 'model.' prefix from keys.
         ignore_cache_layers (bool): Whether to ignore cache layers when reporting missing keys.
         model_config (Optional[Dict]): Model configuration.
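Note: with the new signature, `device` and `is_safetensor` are required arguments. A hypothetical call with illustrative values (names follow the call site later in this diff):

    loaded, missing = load_weights_per_map(
        stage_module,          # the pipeline-stage nn.Module to fill
        weight_map,            # param FQN -> shard file name, from the index JSON
        weight_dir,            # directory containing the shard files
        new_to_old_keymap,     # produced by remap_weight_keys()
        torch.device("cuda"),
        True,                  # is_safetensor
        model_config=model_config,
    )
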
@@ -178,9 +156,13 @@ def load_safetensor_weights(
     for file in needed_files:
         full_path = os.path.join(file_location, file)
         # logger.info(f"Loading checkpoint file: {full_path}")
-        try:
-            checkpoint = load_safetensor_file(full_path, "cpu")  # device)
+        # TODO: directly load to device
+        if is_safetensor:
+            checkpoint = load_safetensor_file(full_path)
+        else:
+            checkpoint = torch.load(full_path, mmap=True, weights_only=True)

+        try:
             update_state_dict(
                 stage_state_dict,
                 checkpoint,
@@ -189,10 +171,9 @@ def load_safetensor_weights(
                 new_to_old_keymap=new_to_old_keymap,
                 updated_states=updated_states,
             )
-        except FileNotFoundError:
-            logger.error(f"File not found: {full_path}")
         except Exception as e:
-            logger.error(f"Error during checkpoint processing of {full_path}: {str(e)}")
+            logger.error(f"Error during checkpoint processing of {full_path}:")
+            raise e

     missing_keys = handle_missing_keys(
         stage_state_dict, updated_states, ignore_cache_layers
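Note: a self-contained sketch of the two per-shard loading paths introduced above, assuming shards are either safetensors files or pickled `pytorch_model-*.bin` files:

    from typing import Dict
    import torch
    from safetensors import safe_open

    def load_shard(full_path: str, is_safetensor: bool) -> Dict[str, torch.Tensor]:
        if is_safetensor:
            # safetensors: read each tensor by key, no pickle involved
            tensors = {}
            with safe_open(full_path, framework="pt", device="cpu") as f:
                for k in f.keys():
                    tensors[k] = f.get_tensor(k)
            return tensors
        # .bin shard: memory-map the file and restrict unpickling to tensors
        return torch.load(full_path, mmap=True, weights_only=True)
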
@@ -244,12 +225,14 @@ def get_needed_files(
     return needed_files


-def load_safetensor_file(full_path: str, device: torch.device) -> Dict[str, torch.Tensor]:
+def load_safetensor_file(
+    full_path: str,
+    device: str = "cpu",
+) -> Dict[str, torch.Tensor]:
     tensors = {}
     with safe_open(full_path, framework="pt", device=device) as f:
         for k in f.keys():
             tensors[k] = f.get_tensor(k)
-    logger.info(f"Loaded {len(tensors)} tensors from {full_path}")
     return tensors
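Note: example usage of the reworked helper; it now loads to CPU by default, and the device string is passed straight through to safetensors' `safe_open` (file name is illustrative):

    shard = load_safetensor_file("model-00001-of-00002.safetensors")
    shard_gpu = load_safetensor_file("model-00001-of-00002.safetensors", device="cuda")
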
@@ -378,15 +361,35 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
     files), and fill into `stage_module`. Model config is needed b/c we permute
     wq and wk weights based on attn heads.
     """
+    # Get the weight map for a given HF model id
+    try:
+        index_file = cached_file(distribution, _DEFAULT_SAFETENSOR_INDEX)
+        is_safetensor = True
+    except Exception:
+        index_file = cached_file(distribution, _DEFAULT_BIN_INDEX)
+        is_safetensor = False
+    logger.info(f"Loading weight map from: {index_file}")
+
+    # Read the weight map from the index file
+    weight_map = read_weights_from_json(index_file)
+    if weight_map is None:
+        raise ValueError(f"Weight map not found in config file {index_file}")
+
+    # Remap the FQNs to the FQNs in HF checkpoints
+    weight_map, new_to_old_keymap = remap_weight_keys(weight_map)

-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
+    # Get the dir containing the weight files
+    weight_dir = os.path.dirname(index_file)
+    logger.info(f"Loading weights from: {weight_dir}")

-    num_loaded_weights, num_missing_weights = load_safetensor_weights(
+    # Load the weights into the stage module
+    num_loaded_weights, num_missing_weights = load_weights_per_map(
         stage_module,
         weight_map,
-        weight_path,
-        key_map,
+        weight_dir,
+        new_to_old_keymap,
         device,
+        is_safetensor,
         model_config=model_config,
     )
     logger.info(
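Note: a minimal standalone sketch of the safetensors-then-bin index resolution above. `cached_file` is `transformers.utils.cached_file`, which raises if the repo does not provide the requested file, so the lookup order doubles as format detection:

    from typing import Tuple
    from transformers.utils import cached_file

    def resolve_index_file(model_id: str) -> Tuple[str, bool]:
        try:
            # Prefer the safetensors index when the repo ships one
            return cached_file(model_id, "model.safetensors.index.json"), True
        except Exception:
            # Fall back to the legacy pickled-checkpoint index
            return cached_file(model_id, "pytorch_model.bin.index.json"), False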