diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 545e87c5fef94e..836cd3489811d2 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1502,24 +1502,40 @@ def _get_resized_embeddings(
                 f" {nn.Embedding}."
             )
 
-        # Build new embeddings
-        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
-        new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)
-
-        # initialize all new embeddings (in particular added tokens)
-        self._init_weights(new_embeddings)
-
-        # Copy token embeddings from the previous weights
-
         # numbers of tokens to copy
         n = min(old_num_tokens, new_num_tokens)
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
-            with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0):
-                if torch.distributed.get_rank() == 0:
-                    new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
+            with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
+                # Build new embeddings
+                new_embeddings = nn.Embedding(
+                    new_num_tokens,
+                    old_embedding_dim,
+                    device=old_embeddings.weight.device,
+                    dtype=old_embeddings.weight.dtype,
+                )
+
+            params = [old_embeddings.weight, new_embeddings.weight]
+            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                # initialize all new embeddings (in particular added tokens)
+                self._init_weights(new_embeddings)
+
+                # Copy token embeddings from the previous weights
+                new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
         else:
+            # Build new embeddings
+            new_embeddings = nn.Embedding(
+                new_num_tokens,
+                old_embedding_dim,
+                device=old_embeddings.weight.device,
+                dtype=old_embeddings.weight.dtype,
+            )
+
+            # initialize all new embeddings (in particular added tokens)
+            self._init_weights(new_embeddings)
+
+            # Copy token embeddings from the previous weights
             new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
 
         return new_embeddings
@@ -1575,11 +1591,6 @@ def _get_resized_lm_head(
         # Build new lm head
         new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
         has_new_lm_head_bias = old_lm_head.bias is not None
-        new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
-        new_lm_head = new_lm_head.to(old_lm_head.weight.device, dtype=old_lm_head.weight.dtype)
-
-        # initialize new lm head (in particular added tokens)
-        self._init_weights(new_lm_head)
 
         num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
 
@@ -1587,23 +1598,33 @@ def _get_resized_lm_head(
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
+            with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
+                new_lm_head = nn.Linear(
+                    *new_lm_head_shape,
+                    bias=has_new_lm_head_bias,
+                    device=old_lm_head.weight.device,
+                    dtype=old_lm_head.weight.dtype,
+                )
             params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
-                if torch.distributed.get_rank() == 0:
-                    # Copy old lm head weights to new lm head
-                    if not transposed:
-                        new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[
-                            :num_tokens_to_copy, :
-                        ]
-                    else:
-                        new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[
-                            :, :num_tokens_to_copy
-                        ]
+                self._init_weights(new_lm_head)
+                # Copy old lm head weights to new lm head
+                if not transposed:
+                    new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
+                else:
+                    new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
 
-                    # Copy bias weights to new lm head
-                    if has_new_lm_head_bias:
-                        new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
+                # Copy bias weights to new lm head
+                if has_new_lm_head_bias:
+                    new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
         else:
+            new_lm_head = nn.Linear(
+                *new_lm_head_shape,
+                bias=has_new_lm_head_bias,
+                device=old_lm_head.weight.device,
+                dtype=old_lm_head.weight.dtype,
+            )
+            self._init_weights(new_lm_head)
             # Copy old lm head weights to new lm head
             if not transposed:
                 new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
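For context, both hunks move to the same ZeRO-3 pattern: construct the replacement module inside `deepspeed.zero.Init` so its weight is created as a partitioned parameter like the rest of the model, then gather the old and new weights with `deepspeed.zero.GatheredParameters(..., modifier_rank=0)` so initialization and the row copy run on full tensors and rank 0's result is re-partitioned and broadcast on exit. The sketch below is only an illustration of that pattern, not the library's `_get_resized_embeddings`: the helper name `resize_embedding_zero3`, the explicit `embedding_dim` argument, and the `normal_` initialization are stand-ins, and it assumes a ZeRO-3 run with `torch.distributed` already initialized and `deepspeed_config` importable from `transformers.integrations.deepspeed`.

```python
# Sketch only: mirrors the diff's ZeRO-3 approach outside PreTrainedModel.
# Helper name, embedding_dim argument, and normal_ init are illustrative.
import deepspeed
import torch.nn as nn

from transformers.integrations.deepspeed import deepspeed_config


def resize_embedding_zero3(old: nn.Embedding, new_num_tokens: int, embedding_dim: int) -> nn.Embedding:
    # Build the new embedding under zero.Init so DeepSpeed registers and shards
    # its weight across ranks from the start.
    with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
        new = nn.Embedding(
            new_num_tokens,
            embedding_dim,
            device=old.weight.device,
            dtype=old.weight.dtype,
        )

    # Gather the full old and new weights, modify them, and let DeepSpeed
    # re-partition and broadcast rank 0's data when the context exits.
    with deepspeed.zero.GatheredParameters([old.weight, new.weight], modifier_rank=0):
        new.weight.data.normal_(mean=0.0, std=0.02)  # stand-in for self._init_weights
        n = min(old.weight.shape[0], new_num_tokens)  # full shape is visible inside the gather
        new.weight.data[:n, :] = old.weight.data[:n, :]

    return new
```

The lm-head hunk follows the same gather-then-copy structure, with the bias sliced and copied alongside the weight.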