Skip to content

Commit f4a9cf4

Browse files
committed
Use mean_resizing instead of multivariate_resizing
1 parent 5cdce5f commit f4a9cf4

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

src/transformers/modeling_utils.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2051,7 +2051,7 @@ def resize_token_embeddings(
20512051
self,
20522052
new_num_tokens: Optional[int] = None,
20532053
pad_to_multiple_of: Optional[int] = None,
2054-
multivariate_resizing: bool = True,
2054+
mean_resizing: bool = True,
20552055
) -> nn.Embedding:
20562056
"""
20572057
Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
@@ -2071,19 +2071,19 @@ def resize_token_embeddings(
20712071
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
20722072
details about this, or help on choosing the correct value for resizing, refer to this guide:
20732073
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
2074-
multivariate_resizing (`bool`):
2074+
mean_resizing (`bool`):
20752075
Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
20762076
covariance, or to initialize them with a normal distribution that has a mean of zero and a std equal to `initializer_range`.
20772077
2078-
Setting `multivariate_resizing` to `True` is useful when increasing the size of the embedding for language models.
2078+
Setting `mean_resizing` to `True` is useful when increasing the size of the embedding for language models.
20792079
This way, the generated tokens will not be affected by the added embeddings, because this reduces the KL-divergence
20802080
between the next token probability before and after adding the new embeddings.
20812081
Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
20822082
20832083
Return:
20842084
`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
20852085
"""
2086-
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, multivariate_resizing)
2086+
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
20872087
if new_num_tokens is None and pad_to_multiple_of is None:
20882088
return model_embeds
20892089

@@ -2106,10 +2106,10 @@ def resize_token_embeddings(
21062106

21072107
return model_embeds
21082108

2109-
def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, multivariate_resizing=True):
2109+
def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
21102110
old_embeddings = self.get_input_embeddings()
21112111
new_embeddings = self._get_resized_embeddings(
2112-
old_embeddings, new_num_tokens, pad_to_multiple_of, multivariate_resizing
2112+
old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing
21132113
)
21142114
if hasattr(old_embeddings, "_hf_hook"):
21152115
hook = old_embeddings._hf_hook
@@ -2134,11 +2134,11 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mult
21342134
old_lm_head = self.get_output_embeddings()
21352135
if isinstance(old_lm_head, torch.nn.Embedding):
21362136
new_lm_head = self._get_resized_embeddings(
2137-
old_lm_head, new_num_tokens, multivariate_resizing=multivariate_resizing
2137+
old_lm_head, new_num_tokens, mean_resizing=mean_resizing
21382138
)
21392139
else:
21402140
new_lm_head = self._get_resized_lm_head(
2141-
old_lm_head, new_num_tokens, multivariate_resizing=multivariate_resizing
2141+
old_lm_head, new_num_tokens, mean_resizing=mean_resizing
21422142
)
21432143
if hasattr(old_lm_head, "_hf_hook"):
21442144
hook = old_lm_head._hf_hook
@@ -2154,7 +2154,7 @@ def _get_resized_embeddings(
21542154
old_embeddings: nn.Embedding,
21552155
new_num_tokens: Optional[int] = None,
21562156
pad_to_multiple_of: Optional[int] = None,
2157-
multivariate_resizing: bool = True,
2157+
mean_resizing: bool = True,
21582158
) -> nn.Embedding:
21592159
"""
21602160
Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
@@ -2177,11 +2177,11 @@ def _get_resized_embeddings(
21772177
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
21782178
details about this, or help on choosing the correct value for resizing, refer to this guide:
21792179
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
2180-
multivariate_resizing (`bool`):
2180+
mean_resizing (`bool`):
21812181
Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
21822182
covariance, or to initialize them with a normal distribution that has a mean of zero and a std equal to `initializer_range`.
21832183
2184-
Setting `multivariate_resizing` to `True` is useful when increasing the size of the embedding for language models.
2184+
Setting `mean_resizing` to `True` is useful when increasing the size of the embedding for language models.
21852185
This way, the generated tokens will not be affected by the added embeddings, because this reduces the KL-divergence
21862186
between the next token probability before and after adding the new embeddings.
21872187
Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
@@ -2243,18 +2243,18 @@ def _get_resized_embeddings(
22432243
dtype=old_embeddings.weight.dtype,
22442244
)
22452245

2246-
if new_num_tokens > old_num_tokens and not multivariate_resizing:
2246+
if new_num_tokens > old_num_tokens and not mean_resizing:
22472247
self._init_weights(new_embeddings)
22482248

2249-
elif new_num_tokens > old_num_tokens and multivariate_resizing:
2249+
elif new_num_tokens > old_num_tokens and mean_resizing:
22502250
# initialize new embeddings (in particular added tokens) if `new_num_tokens` is larger
22512251
# than `old_num_tokens`. The new embeddings will be sampled from a multivariate normal
22522252
# distribution that has old embeddings' mean and covariance, as described in this article:
22532253
# https://nlp.stanford.edu/~johnhew/vocab-expansion.html
22542254
logger.warning_once(
22552255
"The new embeddings will be sampled from a multivariate normal distribution that has old embeddings' mean and covariance. "
22562256
"As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
2257-
"To disable this, use `multivariate_resizing=False`"
2257+
"To disable this, use `mean_resizing=False`"
22582258
)
22592259

22602260
added_num_tokens = new_num_tokens - old_num_tokens
@@ -2312,7 +2312,7 @@ def _get_resized_lm_head(
23122312
old_lm_head: nn.Linear,
23132313
new_num_tokens: Optional[int] = None,
23142314
transposed: Optional[bool] = False,
2315-
multivariate_resizing: bool = True,
2315+
mean_resizing: bool = True,
23162316
) -> nn.Linear:
23172317
"""
23182318
Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
@@ -2329,11 +2329,11 @@ def _get_resized_lm_head(
23292329
`torch.nn.Linear` module of the model without doing anything. transposed (`bool`, *optional*, defaults
23302330
to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim,
23312331
vocab_size` else `vocab_size, lm_head_dim`.
2332-
multivariate_resizing (`bool`):
2332+
mean_resizing (`bool`):
23332333
Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
23342334
covariance, or to initialize them with a normal distribution that has a mean of zero and a std equal to `initializer_range`.
23352335
2336-
Setting `multivariate_resizing` to `True` is useful when increasing the size of the embedding for language models.
2336+
Setting `mean_resizing` to `True` is useful when increasing the size of the embedding for language models.
23372337
This way, the generated tokens will not be affected by the added embeddings, because this reduces the KL-divergence
23382338
between the next token probability before and after adding the new embeddings.
23392339
Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
@@ -2383,18 +2383,18 @@ def _get_resized_lm_head(
23832383
dtype=old_lm_head.weight.dtype,
23842384
)
23852385

2386-
if new_num_tokens > old_num_tokens and not multivariate_resizing:
2386+
if new_num_tokens > old_num_tokens and not mean_resizing:
23872387
self._init_weights(new_lm_head)
23882388

2389-
elif new_num_tokens > old_num_tokens and multivariate_resizing:
2389+
elif new_num_tokens > old_num_tokens and mean_resizing:
23902390
# initialize new embeddings (in particular added tokens) if `new_num_tokens` is larger
23912391
# than `old_num_tokens`. The new embeddings will be sampled from a multivariate normal
23922392
# distribution that has old embeddings' mean and covariance, as described in this article:
23932393
# https://nlp.stanford.edu/~johnhew/vocab-expansion.html
23942394
logger.warning_once(
23952395
"The new embeddings will be sampled from a multivariate normal distribution that has old embeddings' mean and covariance. "
23962396
"As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
2397-
"To disable this, use `multivariate_resizing=False`"
2397+
"To disable this, use `mean_resizing=False`"
23982398
)
23992399

24002400
added_num_tokens = new_num_tokens - old_num_tokens

0 commit comments

Comments
 (0)