deepspeed-chat: fix bf16 stage2 accuracy for bloom-560m #772

Merged 1 commit on Oct 17, 2023
deepspeed-chat: fix bf16 stage2 accuracy for bloom-560m
The bloom-560m model has high variance in its final LayerNorm (ln_f) weight.
This causes accuracy issues in bf16 stage 2 training.
Therefore, reset the parameters of the final LN layer before training.
This is good practice whenever we replace the classifier head that
follows the LN.
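
A quick way to observe the issue described above (an illustrative sketch, not part of the patch; it assumes the Hugging Face transformers AutoModel API and that the Bloom model exposes its final LayerNorm as ln_f):

    import torch
    from transformers import AutoModel

    # Load the base model in fp32 and inspect the spread of the final LayerNorm
    # weight. Exact numbers depend on the checkpoint; the point is that a large
    # spread amplifies bf16 rounding error in this layer.
    model = AutoModel.from_pretrained("bigscience/bloom-560m",
                                      torch_dtype=torch.float32)
    w = model.ln_f.weight.detach()
    print(f"ln_f.weight: mean={w.mean().item():.3f} std={w.std().item():.3f} "
          f"min={w.min().item():.3f} max={w.max().item():.3f}")

    # Resetting to the LayerNorm defaults (weight=1, bias=0) removes the outliers.
    # This is safe here because the classifier head that consumes ln_f's output
    # is replaced and retrained anyway.
    torch.nn.init.ones_(model.ln_f.weight)
    torch.nn.init.zeros_(model.ln_f.bias)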

In addition, when only_optimize_lora is enabled, we need to force training
of the LN parameters that were reset.
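
A minimal sketch of how a freezing helper with a force_optimize_params allow-list could work (the repository's actual only_optimize_lora_parameters utility may differ in its details):

    def only_optimize_lora_parameters(model, force_optimize_params=None):
        # Freeze everything except LoRA weights and explicitly forced parameters
        # (e.g. the reset ln_f weight/bias and the new v_head).
        force_optimize_params = force_optimize_params or []
        for name, param in model.named_parameters():
            if "lora_" in name or name in force_optimize_params:
                param.requires_grad = True
            else:
                param.requires_grad = False
        return model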

Note that the current fix uses plain initialization of the final LN.
A separate commit will add support for ZeRO-3 initialization.
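
For reference, one possible shape for a ZeRO-3-aware reset (an illustrative sketch only, not the follow-up commit; it assumes deepspeed.zero.GatheredParameters is used to gather the partitioned ln_f parameters before modifying them in place):

    import torch
    import deepspeed

    def reset_final_ln(rm_model, zero_stage):
        params = [rm_model.rwtransformer.ln_f.weight,
                  rm_model.rwtransformer.ln_f.bias]
        if zero_stage == 3:
            # Under ZeRO-3 the parameters are partitioned across ranks: gather
            # them, modify on rank 0, and let DeepSpeed re-partition and
            # broadcast the update when the context exits.
            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                if torch.distributed.get_rank() == 0:
                    torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight)
                    torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias)
        else:
            torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight)
            torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias)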

Change-Id: I323d8947907eb4a1cc0fa6354bdaf0cbbf33a68d
Signed-off-by: Moshe Island <misland@habana.ai>
mosheisland committed Oct 17, 2023
commit 9176e0e2091139c1402e2a3ab788270517acb7d0
@@ -247,13 +247,25 @@ def main():
                                    zero_stage=args.zero_stage,
                                    compute_fp32_loss=args.compute_fp32_loss)
 
+    # Model bigscience/bloom-560m has large variance at ln_f.weight parameter
+    # This makes bf16 finetuning hard.
+    # In general, since we are replacing the model head, it makes sense to reset
+    # the LN that precedes it.
+    force_optimize_params = []
+    if "bigscience/bloom-" in args.model_name_or_path:
+        torch.nn.init.ones_(rm_model.rwtransformer.ln_f.weight)
+        torch.nn.init.zeros_(rm_model.rwtransformer.ln_f.bias)
+        force_optimize_params.extend(
+            ['rwtransformer.ln_f.weight', 'rwtransformer.ln_f.bias'])
+
     if args.lora_dim > 0:
         rm_model = convert_linear_layer_to_lora(rm_model,
                                                 args.lora_module_name,
                                                 args.lora_dim)
         if args.only_optimize_lora:
-            rm_model = only_optimize_lora_parameters(
-                rm_model, force_optimize_params=['v_head.weight'])
+            force_optimize_params.append('v_head.weight')
+            rm_model = only_optimize_lora_parameters(rm_model,
+                                                     force_optimize_params)
         rm_model = make_model_gradient_checkpointing_compatible(rm_model)
 
     train_phase = 2