Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 28 additions & 27 deletions paddleformers/transformers/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2403,36 +2403,37 @@ def _load_pretrained_model(
missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))

# Optimize for skip unused shard files for supper large model
if sharded_metadata is not None:
assert isinstance(resolved_archive_file, list)
new_archive_file = []
skip_archive_file = []
if quantization_linear_list is None:
expected_keys_set = set(expected_keys)
else:
origin_expected_keys = [k.replace("quant_weight", "weight") for k in expected_keys]
expected_keys_set = set(expected_keys + origin_expected_keys)
if not cls._get_fuse_or_split_param_mappings(config, is_fuse=True):
# Optimize for skip unused shard files for supper large model
if sharded_metadata is not None:
assert isinstance(resolved_archive_file, list)
new_archive_file = []
skip_archive_file = []
if quantization_linear_list is None:
expected_keys_set = set(expected_keys)
else:
origin_expected_keys = [k.replace("quant_weight", "weight") for k in expected_keys]
expected_keys_set = set(expected_keys + origin_expected_keys)

if key_mapping is not None:
# Determine the precise set of original checkpoint keys that are actually needed for the current file.
# This set will be used to identify which sharded checkpoint files are relevant and must be loaded.
expected_keys_set = {
reverse_key_renaming_mapping[key]
for key in list(expected_keys_set)
if key not in missing_keys and key not in unexpected_keys
}
if key_mapping is not None:
# Determine the precise set of original checkpoint keys that are actually needed for the current file.
# This set will be used to identify which sharded checkpoint files are relevant and must be loaded.
expected_keys_set = {
reverse_key_renaming_mapping[key]
for key in list(expected_keys_set)
if key not in missing_keys and key not in unexpected_keys
}

for file in resolved_archive_file:
filename = os.path.split(file)[-1]
if not expected_keys_set.isdisjoint(set(sharded_metadata["file_map"][filename])):
new_archive_file.append(file)
else:
skip_archive_file.append(filename)
for file in resolved_archive_file:
filename = os.path.split(file)[-1]
if not expected_keys_set.isdisjoint(set(sharded_metadata["file_map"][filename])):
new_archive_file.append(file)
else:
skip_archive_file.append(filename)

resolved_archive_file = new_archive_file
if len(skip_archive_file) > 0:
logger.info(f"Skip load files for not contains expected key, {skip_archive_file}")
resolved_archive_file = new_archive_file
if len(skip_archive_file) > 0:
logger.info(f"Skip load files for not contains expected key, {skip_archive_file}")

# Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user.
Expand Down
Loading