Commit 75f3a35

[Distributed] Support index + multi-bin loading (#1275)
1 parent d75531b commit 75f3a35
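
For context: both index files named in this change follow the Hugging Face sharded-checkpoint convention, where a small JSON index maps every parameter name to the shard file that stores it. A minimal sketch of reading such an index (the file name and repo layout shown are illustrative, not taken from this commit's code):

```python
import json

# Illustrative only: inspect a HF shard index such as
# "model.safetensors.index.json" or "pytorch_model.bin.index.json".
with open("pytorch_model.bin.index.json") as f:
    index = json.load(f)

# "weight_map" maps fully-qualified parameter names to shard files, e.g.
# {"model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin", ...}
for param_name, shard_file in index["weight_map"].items():
    print(param_name, "->", shard_file)
```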

File tree

1 file changed: +42 -39 lines changed


torchchat/distributed/checkpoint_utils.py (+42 -39)
@@ -19,7 +19,8 @@
 from torchchat.cli.builder import BuilderArgs, _load_checkpoint
 
 
-_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
+_DEFAULT_SAFETENSOR_INDEX = "model.safetensors.index.json"
+_DEFAULT_BIN_INDEX = "pytorch_model.bin.index.json"
 _CONFIG_NAME = "config.json"
 
 
@@ -81,31 +82,6 @@ def get_hf_path_from_model_id(model_id: str) -> str:
     return file_location
 
 
-def get_hf_weight_map_and_path(
-    model_id: str,
-) -> Tuple[Dict[str, str], str, Dict[str, str]]:
-    """Get the weight map for a given HF model id and also the cache path for loading the weights"""
-    index_file = cached_file(model_id, _DEFAULT_SAFETENSOR_FILE_NAME)
-    if not os.path.exists(index_file):
-        raise FileNotFoundError(
-            f"Weight index file for {model_id} does not exist in HF cache."
-        )
-    logger.info(
-        f"Loading weight map from: {index_file}"
-    )
-    weight_map = read_weights_from_json(index_file)
-    if weight_map is None:
-        raise ValueError(f"Weight map not found in config file {index_file}")
-    weight_map, new_to_old_keymap = remap_weight_keys(weight_map)
-    weight_path = os.path.dirname(index_file)
-    if not os.path.exists(weight_path):
-        raise FileNotFoundError(f"Weight path {weight_path} does not exist")
-    logger.info(
-        f"Loading weights from: {weight_path}"
-    )
-    return weight_map, weight_path, new_to_old_keymap
-
-
 def remap_weight_keys(dictionary):
     """Remap the keys of a dictionary to match the expected format of the tune model."""
     # hf_key : dist_model_key
@@ -141,12 +117,13 @@ def remap_weight_keys(dictionary):
     return new_dict, key_mapping
 
 
-def load_safetensor_weights(
+def load_weights_per_map(
     stage_module: Module,
     weight_map: Dict[str, str],
     file_location: str,
     new_to_old_keymap: Dict[str, str],
-    device: torch.device = "cuda",
+    device: torch.device,
+    is_safetensor: bool,
     purge_model_prefix: bool = True,
     ignore_cache_layers: bool = True,
     model_config: Optional[Dict] = None,
@@ -160,6 +137,7 @@ def load_safetensor_weights(
         file_location (str): Directory containing the weight files.
         new_to_old_keymap (Dict[str, str]): Mapping of new parameter names to old ones.
         device (torch.device): The device to load tensors onto.
+        is_safetensor (bool): Whether the files are safetensors.
         purge_model_prefix (bool): Whether to remove 'model.' prefix from keys.
         ignore_cache_layers (bool): Whether to ignore cache layers when reporting missing keys.
         model_config (Optional[Dict]): Model configuration.
@@ -178,9 +156,13 @@ def load_safetensor_weights(
     for file in needed_files:
         full_path = os.path.join(file_location, file)
         # logger.info(f"Loading checkpoint file: {full_path}")
-        try:
-            checkpoint = load_safetensor_file(full_path, "cpu") # device)
+        # TODO: directly load to device
+        if is_safetensor:
+            checkpoint = load_safetensor_file(full_path)
+        else:
+            checkpoint = torch.load(full_path, mmap=True, weights_only=True)
 
+        try:
             update_state_dict(
                 stage_state_dict,
                 checkpoint,
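
The branch introduced in the hunk above is the core of the multi-format support: safetensors shards go through `safe_open`, while `.bin` shards go through `torch.load`. A standalone sketch of that dual loader (the helper name `load_shard` is invented for illustration and is not part of the commit):

```python
from typing import Dict

import torch
from safetensors import safe_open


def load_shard(full_path: str, is_safetensor: bool) -> Dict[str, torch.Tensor]:
    # Made-up helper mirroring the branch the diff adds: safetensors shards
    # are read tensor-by-tensor via safe_open, .bin shards via torch.load.
    if is_safetensor:
        tensors = {}
        with safe_open(full_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensors[key] = f.get_tensor(key)
        return tensors
    # mmap=True avoids materializing the whole file in memory at once;
    # weights_only=True restricts unpickling to tensors and simple containers.
    return torch.load(full_path, mmap=True, weights_only=True)
```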
@@ -189,10 +171,9 @@
                 new_to_old_keymap=new_to_old_keymap,
                 updated_states=updated_states,
             )
-        except FileNotFoundError:
-            logger.error(f"File not found: {full_path}")
         except Exception as e:
-            logger.error(f"Error during checkpoint processing of {full_path}: {str(e)}")
+            logger.error(f"Error during checkpoint processing:")
+            raise e
 
     missing_keys = handle_missing_keys(
         stage_state_dict, updated_states, ignore_cache_layers
@@ -244,12 +225,14 @@ def get_needed_files(
     return needed_files
 
 
-def load_safetensor_file(full_path: str, device: torch.device) -> Dict[str, torch.Tensor]:
+def load_safetensor_file(
+    full_path: str,
+    device: str = "cpu",
+) -> Dict[str, torch.Tensor]:
     tensors = {}
     with safe_open(full_path, framework="pt", device=device) as f:
         for k in f.keys():
             tensors[k] = f.get_tensor(k)
-    logger.info(f"Loaded {len(tensors)} tensors from {full_path}")
     return tensors
 
 
@@ -378,15 +361,35 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
     files), and fill into `stage_module`. Model config is needed b/c we permute
     wq and wk weights based on attn heads.
     """
+    # Get the weight map for a given HF model id
+    try:
+        index_file = cached_file(distribution, _DEFAULT_SAFETENSOR_INDEX)
+        is_safetensor = True
+    except:
+        index_file = cached_file(distribution, _DEFAULT_BIN_INDEX)
+        is_safetensor = False
+    logger.info(f"Loading weight map from: {index_file}")
+
+    # Read the weight map from the index file
+    weight_map = read_weights_from_json(index_file)
+    if weight_map is None:
+        raise ValueError(f"Weight map not found in config file {index_file}")
+
+    # Remap the FQNs to the FQNs in HF checkpoints
+    weight_map, new_to_old_keymap = remap_weight_keys(weight_map)
 
-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
+    # Get the dir containing the weight files
+    weight_dir = os.path.dirname(index_file)
+    logger.info(f"Loading weights from: {weight_dir}")
 
-    num_loaded_weights, num_missing_weights = load_safetensor_weights(
+    # Load the weights into the stage module
+    num_loaded_weights, num_missing_weights = load_weights_per_map(
         stage_module,
         weight_map,
-        weight_path,
-        key_map,
+        weight_dir,
+        new_to_old_keymap,
         device,
+        is_safetensor,
         model_config=model_config,
     )
     logger.info(
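
Taken together, the new flow in `load_weights_from_hf_format` first probes for the safetensors index and falls back to the `.bin` index. A rough standalone sketch of that resolution step, using `huggingface_hub.hf_hub_download` as a stand-in for the `cached_file` helper the diff calls (an assumption made only for this example):

```python
import os

from huggingface_hub import hf_hub_download  # stand-in for the repo's cached_file


def resolve_index(model_id: str) -> tuple[str, bool]:
    """Return (index_file, is_safetensor), preferring the safetensors index."""
    try:
        index_file = hf_hub_download(model_id, "model.safetensors.index.json")
        return index_file, True
    except Exception:
        index_file = hf_hub_download(model_id, "pytorch_model.bin.index.json")
        return index_file, False


# Usage sketch: the shard files live next to the index in the local cache.
# index_file, is_safetensor = resolve_index("some-org/some-model")  # illustrative repo id
# weight_dir = os.path.dirname(index_file)
```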
