From c1f209dadd3ec595de10f8a3560b29e0225d21ab Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Fri, 11 Mar 2022 15:13:11 -0800
Subject: [PATCH] [ZeRO] Fixes issue with embedding resize (#16093)

* gather z3 params for new_lm_head

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman

Co-authored-by: Stas Bekman
---
 src/transformers/modeling_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 680bc695bd67..5f0ca223667e 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -892,7 +892,8 @@ def _get_resized_lm_head(
         if is_deepspeed_zero3_enabled():
             import deepspeed
 
-            with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=0):
+            params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
+            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                 if torch.distributed.get_rank() == 0:
                     # Copy old lm head weights to new lm head
                     if not transposed: