Skip to content

Commit fc436d7

Browse files
committed
Fix up
1 parent f4a9cf4 commit fc436d7

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

src/transformers/modeling_utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2133,13 +2133,9 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean
21332133
if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
21342134
old_lm_head = self.get_output_embeddings()
21352135
if isinstance(old_lm_head, torch.nn.Embedding):
2136-
new_lm_head = self._get_resized_embeddings(
2137-
old_lm_head, new_num_tokens, mean_resizing=mean_resizing
2138-
)
2136+
new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
21392137
else:
2140-
new_lm_head = self._get_resized_lm_head(
2141-
old_lm_head, new_num_tokens, mean_resizing=mean_resizing
2142-
)
2138+
new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
21432139
if hasattr(old_lm_head, "_hf_hook"):
21442140
hook = old_lm_head._hf_hook
21452141
add_hook_to_module(new_lm_head, hook)

0 commit comments

Comments (0)