Skip to content

Commit

Permalink
[ZeRO] Fixes issue with embedding resize (huggingface#16093)
Browse files Browse the repository at this point in the history
* gather z3 params for new_lm_head

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
  • Loading branch information
jeffra and stas00 authored Mar 11, 2022
1 parent ae2dd42 commit c1f209d
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,8 @@ def _get_resized_lm_head(
if is_deepspeed_zero3_enabled():
import deepspeed

with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=0):
params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
if torch.distributed.get_rank() == 0:
# Copy old lm head weights to new lm head
if not transposed:
Expand Down

0 comments on commit c1f209d

Please sign in to comment.