Skip to content

Commit

Permalink
use original loaded keys to find mismatched keys (#16920)
Browse files Browse the repository at this point in the history
  • Loading branch information
tricktreat authored Apr 26, 2022
1 parent d365f50 commit 2d91e3c
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2022,6 +2022,7 @@ def _fix_key(key):
return key.replace("gamma", "weight")
return key

original_loaded_keys = loaded_keys
loaded_keys = [_fix_key(key) for key in loaded_keys]

if len(prefix) > 0:
Expand Down Expand Up @@ -2114,7 +2115,7 @@ def _find_mismatched_keys(
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
Expand All @@ -2140,7 +2141,7 @@ def _find_mismatched_keys(
mismatched_keys += _find_mismatched_keys(
state_dict,
model_state_dict,
loaded_keys,
original_loaded_keys,
add_prefix_to_model,
remove_prefix_from_model,
ignore_mismatched_sizes,
Expand Down

0 comments on commit 2d91e3c

Please sign in to comment.