Inconsistency in PreTrainedModel.resize_token_embeddings When ZeRO3 Is Enabled (#25394)

* Inconsistency in PreTrainedModel.resize_token_embeddings

This PR addresses #25241.

In the previous implementation, when ZeRO stage 3 was enabled, resize_token_embeddings would create independent PyTorch weights on each device. Here we ensure that new embeddings are created with DeepSpeed init and are properly partitioned across devices.
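
The core of the fix, in isolation: the new embedding has to be constructed under deepspeed.zero.Init so ZeRO-3 partitions it like every other parameter, and both the old and the new weight have to be gathered before the copy so that the modification is scattered back to all partitions on exit. Below is a minimal sketch of that pattern, not the library code itself; the function name and the ds_config dict are assumptions for illustration, and it presumes an already-partitioned old_embeddings module inside a running ZeRO-3 job.

    import deepspeed
    import torch.nn as nn


    def resize_embeddings_zero3(old_embeddings: nn.Embedding, new_num_tokens: int, ds_config: dict) -> nn.Embedding:
        # Weights are partitioned under ZeRO-3, so gather the old weight just to read its full shape.
        with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
            old_num_tokens, embedding_dim = old_embeddings.weight.size()

        # Build the new embedding inside zero.Init so its weight is partitioned across ranks
        # instead of being allocated independently (and inconsistently) on every device.
        with deepspeed.zero.Init(config_dict_or_path=ds_config):
            new_embeddings = nn.Embedding(new_num_tokens, embedding_dim)

        n = min(old_num_tokens, new_num_tokens)
        # Gather both weights; with modifier_rank=0, rank 0's in-place edit is written back to the partitions on exit.
        with deepspeed.zero.GatheredParameters([old_embeddings.weight, new_embeddings.weight], modifier_rank=0):
            new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]

        return new_embeddings

In practice this path is hit simply by calling model.resize_token_embeddings(len(tokenizer)) while training under DeepSpeed ZeRO-3.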

* formatting with black

* adding the removed comments back in

---------

Co-authored-by: Sina Moeini <smoeini@amazon.com>
sinamoeini and sinamoeini-amz authored Aug 17, 2023
1 parent b4d5548 commit 9264fc9
Showing 1 changed file with 51 additions and 30 deletions.
81 changes: 51 additions & 30 deletions src/transformers/modeling_utils.py
@@ -1502,24 +1502,40 @@ def _get_resized_embeddings(
                 f" {nn.Embedding}."
             )
 
-        # Build new embeddings
-        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
-        new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)
-
-        # initialize all new embeddings (in particular added tokens)
-        self._init_weights(new_embeddings)
-
-        # Copy token embeddings from the previous weights
-
         # numbers of tokens to copy
         n = min(old_num_tokens, new_num_tokens)
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
-            with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0):
-                if torch.distributed.get_rank() == 0:
-                    new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
+            with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
+                # Build new embeddings
+                new_embeddings = nn.Embedding(
+                    new_num_tokens,
+                    old_embedding_dim,
+                    device=old_embeddings.weight.device,
+                    dtype=old_embeddings.weight.dtype,
+                )
+
+            params = [old_embeddings.weight, new_embeddings.weight]
+            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                # initialize all new embeddings (in particular added tokens)
+                self._init_weights(new_embeddings)
+
+                # Copy token embeddings from the previous weights
+                new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
         else:
+            # Build new embeddings
+            new_embeddings = nn.Embedding(
+                new_num_tokens,
+                old_embedding_dim,
+                device=old_embeddings.weight.device,
+                dtype=old_embeddings.weight.dtype,
+            )
+
+            # initialize all new embeddings (in particular added tokens)
+            self._init_weights(new_embeddings)
+
+            # Copy token embeddings from the previous weights
             new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
 
         return new_embeddings
@@ -1575,35 +1591,40 @@ def _get_resized_lm_head(
         # Build new lm head
         new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
         has_new_lm_head_bias = old_lm_head.bias is not None
-        new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
-        new_lm_head = new_lm_head.to(old_lm_head.weight.device, dtype=old_lm_head.weight.dtype)
-
-        # initialize new lm head (in particular added tokens)
-        self._init_weights(new_lm_head)
 
         num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
 
         # XXX: put the long block of code in a wrapper
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
+            with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
+                new_lm_head = nn.Linear(
+                    *new_lm_head_shape,
+                    bias=has_new_lm_head_bias,
+                    device=old_lm_head.weight.device,
+                    dtype=old_lm_head.weight.dtype,
+                )
             params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
-                if torch.distributed.get_rank() == 0:
-                    # Copy old lm head weights to new lm head
-                    if not transposed:
-                        new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[
-                            :num_tokens_to_copy, :
-                        ]
-                    else:
-                        new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[
-                            :, :num_tokens_to_copy
-                        ]
+                self._init_weights(new_lm_head)
+                # Copy old lm head weights to new lm head
+                if not transposed:
+                    new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
+                else:
+                    new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
 
-                    # Copy bias weights to new lm head
-                    if has_new_lm_head_bias:
-                        new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
+                # Copy bias weights to new lm head
+                if has_new_lm_head_bias:
+                    new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
         else:
+            new_lm_head = nn.Linear(
+                *new_lm_head_shape,
+                bias=has_new_lm_head_bias,
+                device=old_lm_head.weight.device,
+                dtype=old_lm_head.weight.dtype,
+            )
+            self._init_weights(new_lm_head)
             # Copy old lm head weights to new lm head
             if not transposed:
                 new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
