Zheweiyao/fixing training acc (microsoft#346)
* add more on benchmark and fix prompt model loading

* fix format

* add all code

* add training logs

* update training script

* add training explanation
yaozhewei authored Apr 21, 2023
1 parent dcf67c0 commit 3a77097
Showing 29 changed files with 5,781 additions and 57 deletions.
2 changes: 1 addition & 1 deletion applications/DeepSpeed-Chat/README.md
@@ -374,7 +374,7 @@ For more APIs, example scripts, and evaluation results, please refer to
- [**Step1: Supervised Fine-Tuning (SFT)**](./training/step1_supervised_finetuning/README.md)
- [**Step2: Reward Model Fine-Tuning**](./training/step2_reward_model_finetuning/README.md)
- [**Step3: Reinforcement Learning Human Feedback (RLHF)**](./training/step3_rlhf_finetuning/README.md)

- [**Training Details Explanation**](./training/README.md)

## 🌱 DeepSpeed Chat's Roadmap 🌱

56 changes: 56 additions & 0 deletions applications/DeepSpeed-Chat/training/README.md
@@ -0,0 +1,56 @@
# Training Instability
It's important to note that training large language models (LLMs) and Reinforcement Learning from Human Feedback (RLHF) are still open problems with many unknowns. DeepSpeed-Chat aims to provide an end-to-end RLHF training pipeline with efficient and fast system support, rather than a comprehensive solution for RLHF training. As this field is relatively new, there are various unknowns for both users and developers.

Users are encouraged to experiment with the pipeline, provide feedback, and make suggestions. Contributions to the project are welcome when users find something useful and tested for their use cases. By working together, we can advance the development of this project and improve our understanding of LLMs and RLHF training.

## Three Training Steps Discussion
### Step 1: Supervised Finetuning
Supervised fine-tuning (SFT) has indeed made significant progress in the field of large language models (LLMs). However, unexpected behaviors, such as repetitive content generation and inconsistency between perplexity (PPL) scores and generation quality, can still occur.

Based on our testing, several settings affect the generation behavior:
* ```weight decay```: OPT models are pretrained with weight decay, and finetuning normally inherits this setting. However, it may not produce the desired model. In particular, for our OPT-1.3B example, we disabled weight decay.
* ```dropout```: Similarly, dropout is used in OPT pretraining, but SFT may not necessarily need it. In particular, for our OPT-1.3B example, we enabled dropout.
* ```dataset```: Using more data usually provides better model quality, but if the dataset sources are too different, it may hurt performance. For our OPT-1.3B example, we use the following four datasets: ```Dahoas/rm-static Dahoas/full-hh-rlhf Dahoas/synthetic-instruct-gptj-pairwise yitingxie/rlhf-reward-datasets```.
* ```training epochs```: Normally, to avoid overfitting, we choose fewer training epochs when they achieve similar model quality (here we use PPL as the indicator; see the sketch below). However, similar to what InstructGPT pointed out, we found that even though longer training leads to overfitting, it is still recommended in order to get better generation quality. In particular, for our OPT-1.3B example, we use 16 epochs even though 1 or 2 epochs of training can reach the same PPL score.
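For reference, here is a minimal sketch of how PPL can be monitored for a Hugging Face causal LM; `model`, `eval_dataloader`, and `device` are placeholders, not the pipeline's exact code:

```python
import math
import torch

def evaluate_perplexity(model, eval_dataloader, device):
    """Approximate PPL as exp(mean batch loss) over the evaluation set."""
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for batch in eval_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # HF causal-LM models return the mean token-level cross-entropy
            # when labels are provided (labels are shifted internally).
            outputs = model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=input_ids)
            total_loss += outputs.loss.float().item()
            num_batches += 1
    # Note: exp of the mean batch loss is only an approximation when batches
    # contain different numbers of tokens.
    return math.exp(total_loss / max(num_batches, 1))
```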

### Step 2: Reward Model Finetuning
Reward model (RM) fine-tuning is indeed similar to SFT, with the main differences being: (1) the training datasets are different - RM requires both a good response and a bad response to the same query; (2) the training loss is different - RM uses a pairwise ranking loss as the training objective.
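A common form of this pairwise ranking loss (following InstructGPT) pushes the score of the accepted response above that of the rejected one. Below is a minimal sketch assuming one scalar score per response; the pipeline's actual loss may aggregate scores over tokens differently:

```python
import torch
import torch.nn.functional as F

def pairwise_ranking_loss(chosen_scores: torch.Tensor,
                          rejected_scores: torch.Tensor) -> torch.Tensor:
    """-log(sigmoid(r_chosen - r_rejected)), averaged over the batch.

    chosen_scores / rejected_scores: shape (batch,), the scalar rewards the
    RM assigns to the accepted and rejected responses for the same prompt.
    """
    return -F.logsigmoid(chosen_scores - rejected_scores).mean()
```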

We provide two metrics for the reward model: (1) the reward scores for accepted (and rejected) responses, and (2) the accuracy, i.e., how often accepted responses receive higher scores than rejected responses. Sometimes we observe that the accuracy is very high, but the average reward score for accepted answers is negative, or the rejected answers' scores are similar to the accepted answers'. Does this affect the step-3 model quality? If we use the reward score gain as the metric for step 3, this probably is not an issue. However, this machine-learning metric (reward score gain) cannot really reflect the step-3 model's generation quality, so we do not have a definitive answer yet.
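Concretely, both metrics can be computed from per-pair scalar scores as in the sketch below (tensor names are illustrative):

```python
import torch

def reward_model_metrics(chosen_scores: torch.Tensor,
                         rejected_scores: torch.Tensor) -> dict:
    """chosen_scores / rejected_scores: shape (batch,) scalar rewards."""
    return {
        "chosen_mean_reward": chosen_scores.mean().item(),
        "rejected_mean_reward": rejected_scores.mean().item(),
        # Fraction of pairs where the accepted answer outscores the rejected one.
        "accuracy": (chosen_scores > rejected_scores).float().mean().item(),
    }
```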

Here, we share more about what we observed during our exploration:
* ```weight decay```: For our OPT-350m example, we enabled weight decay with 0.1.
* ```dropout```: For our OPT-350m example, we disabled dropout.
* ```dataset```: For our OPT-350m example, we use the following four datasets: ```Dahoas/rm-static Dahoas/full-hh-rlhf Dahoas/synthetic-instruct-gptj-pairwise yitingxie/rlhf-reward-datasets```.
* ```training epochs```: InstructGPT suggests finetuning the model for 1 epoch, since overfitting hurts step-3 performance. During our exploration, we did not see overfitting when we increased the number of training epochs. Nevertheless, to follow the authors' instruction, we set the number of training epochs to 1.

We also share a few more explorations here, even though we have not exposed them as options or included them in the current pipeline:
* ```multiple answers for one prompt```: The InstructGPT authors specifically mention that using only a single pair of rejected and accepted answers per prompt is not good for reward model training; therefore, InstructGPT constructs its dataset with 4-9 answers per prompt. However, we did not find good datasets with this feature.
* ```initialize RM with SFT or Pretrained checkpoint```: We tested this internally but did not see a big difference in either accuracy or reward score; the InstructGPT authors report the same finding. Nevertheless, we encourage users to try it for their own usage.
* ```Reward score calculation```: We use the final token (or the first padding token) to obtain the reward score (see the sketch after this list). However, this might not be the optimal choice; for instance, users can try the average score over the entire answer instead.
* ```Reward loss objective```: We simply use the pairwise ranking loss as the objective. However, other losses, such as MSE, are also an option.
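As a concrete illustration of the final-token / first-padding-token scheme mentioned above, here is a minimal sketch; it ignores prompt-side padding and other details the real pipeline has to handle:

```python
import torch

def sequence_reward(values: torch.Tensor,
                    input_ids: torch.Tensor,
                    pad_token_id: int) -> torch.Tensor:
    """Pick one scalar reward per sequence.

    values:    (batch, seq_len) per-token scores from the reward head.
    input_ids: (batch, seq_len) token ids, assumed right-padded.
    Returns the score at the first padding token if one exists,
    otherwise the score at the final token.
    """
    batch_size, seq_len = input_ids.shape
    rewards = []
    for i in range(batch_size):
        pad_positions = (input_ids[i] == pad_token_id).nonzero(as_tuple=True)[0]
        end = pad_positions[0].item() if pad_positions.numel() > 0 else seq_len - 1
        rewards.append(values[i, end])
    return torch.stack(rewards)
```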


### Step 3: RLHF Finetuning
RLHF finetuning is the most complicated of the three training steps. Similar to SFT, the reward score cannot really reflect the model's generation quality. We also sometimes observed the reward score dropping back to its initial level at some point and then quickly recovering. To make things worse, the training can also easily diverge. We share our settings and observations here.

* ```weight decay```: For our OPT-1.3B/350m (actor/critic) example, we disabled weight decay for both models.
* ```dropout```: We disabled dropout for OPT-1.3B and enabled it for OPT-350m.
* ```dataset```: We use the following single dataset: ```Dahoas/rm-static```.
* ```training epochs```: The reward score quickly plateaus. Therefore, we set the number of training epochs to 1 for our OPT-1.3B/350m (actor/critic) example. However, as with SFT, longer training may bring better model quality.
* ```ema checkpoint```: We observe that the EMA checkpoint can generally bring better model generation quality, as stated in InstructGPT (a minimal EMA sketch appears after this list).
* ```PPO related hyperparameters```: PPO training has a lot of hyperparameters, see [here](https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py#L61-L66). For now, we hard-code them, but you may want to adjust them for your own usage.
* ```mix unsupervised training```: InstructGPT suggests mixing PPO with unsupervised training to prevent the loss of the model's benchmark quality. However, when we directly applied the hyperparameters from InstructGPT, the model could not converge, so we stopped exploring this. Users are nonetheless encouraged to test it and tune the hyperparameters for their own usage.
* ```diverging issue```: We have found it very unstable to use different generation training batch sizes (`--per_device_train_batch_size`) and PPO training batch sizes (`--per_device_mini_batch_size`), more than one PPO training epoch (`--ppo_epochs`), or more than one generation batch (`--generation_batch_numbers`). These all point to the same problem: we are not able to update the actor model multiple times after generating the experience data. Therefore, in all of our successful runs, we set `per_device_train_batch_size=per_device_mini_batch_size` and `ppo_epochs=generation_batch_numbers=1`. This is unexpected for a standard RL training pipeline, and we have tried different methods to overcome it, but all have failed. One of the most likely reasons for this instability is that the `log_probs` and `old_log_probs` used in the `actor_loss_fn` function can quickly diverge even within two consecutive iterations, which causes the corresponding `ratio` to become huge (a simplified sketch of this loss is given after the list). Setting a strict upper bound on the ratio can alleviate the problem, but it cannot fully resolve the convergence issue.
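For reference, here is a simplified sketch of the clipped-ratio actor loss discussed above; the real implementation lives in `ppo_trainer.py` (linked earlier), and names such as `advantages`, `mask`, and the clip range value are illustrative:

```python
import torch

def actor_loss_fn(log_probs: torch.Tensor,
                  old_log_probs: torch.Tensor,
                  advantages: torch.Tensor,
                  mask: torch.Tensor,
                  clip_range: float = 0.2) -> torch.Tensor:
    """PPO-style clipped surrogate loss over response tokens.

    log_probs / old_log_probs: (batch, seq) log-probabilities of the taken
    actions under the current actor and the rollout-time actor, respectively.
    advantages: (batch, seq) estimated advantages; mask selects response tokens.
    If log_probs drifts far from old_log_probs between consecutive updates,
    `ratio` explodes -- the instability described above.
    """
    ratio = torch.exp(log_probs - old_log_probs)
    clipped_ratio = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # Pessimistic objective: minimum of the unclipped and clipped surrogates.
    loss = -torch.min(ratio * advantages, clipped_ratio * advantages)
    return (loss * mask).sum() / mask.sum()
```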

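And a minimal sketch of the EMA checkpoint idea from the list above: keep a shadow copy of the actor and update it as a running average after each optimizer step (the decay value is illustrative; the pipeline's own EMA utility may differ):

```python
import torch

@torch.no_grad()
def update_ema(ema_model: torch.nn.Module,
               model: torch.nn.Module,
               decay: float = 0.992) -> None:
    """ema_param <- decay * ema_param + (1 - decay) * param."""
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.mul_(decay).add_(param.detach(), alpha=1.0 - decay)
```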
### About our testing
We did most of our accuracy/quality testing on OPT-1.3B (SFT and actor model) and OPT-350m (reward and critic model). In particular, we ran our experiments on 16 V100-32G GPUs (a DGX-2 node).

The hyperparameters included in our scripts are based on our own testing. They may therefore not work for your case, including (but not limited to): (1) a different number of GPUs, (2) different model sizes, (3) different model families, etc.

Also note that you may find even better training configurations/recipes than the ones we provide; we did not extensively test all hyperparameter combinations due to resource constraints.

### Others
RLHF (Reinforcement Learning from Human Feedback) training is still an open problem, and DeepSpeed-Chat is designed to be a starting point for researchers and practitioners to work on it with an efficient and fast training experience. The Hybrid Engine and other efficient components, like LoRA, can be inherited from DeepSpeed-Chat, allowing you to develop your own RLHF training pipeline for exploration, research, and other purposes.

Contributions from users are highly appreciated; together, we can build a more successful, easier-to-use, and more stable RLHF training pipeline.
@@ -95,7 +95,7 @@ def parse_args():
)
parser.add_argument("--weight_decay",
type=float,
default=0.1,
default=0.,
help="Weight decay to use.")
parser.add_argument("--num_train_epochs",
type=int,
@@ -138,6 +138,9 @@ def parse_args():
parser.add_argument('--gradient_checkpointing',
action='store_true',
help='Enable HF gradient checkpointing for model.')
parser.add_argument('--disable_dropout',
action='store_true',
help='Disable the dropout of the model.')
# deepspeed features
parser.add_argument('--offload',
action='store_true',
@@ -204,8 +207,11 @@ def main():
fast_tokenizer=True)
tokenizer.pad_token = tokenizer.eos_token

model = create_hf_model(AutoModelForCausalLM, args.model_name_or_path,
tokenizer, ds_config)
model = create_hf_model(AutoModelForCausalLM,
args.model_name_or_path,
tokenizer,
ds_config,
disable_dropout=args.disable_dropout)

if args.lora_dim > 0:
model = convert_linear_layer_to_lora(model, args.lora_module_name,
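For context, here is a hedged sketch of what a `disable_dropout` option like the one passed above can amount to: zeroing the dropout probabilities in the Hugging Face config before instantiating the model. This is an illustrative reimplementation, not the repository's actual `create_hf_model`:

```python
from transformers import AutoConfig, AutoModelForCausalLM

def create_model_with_optional_dropout(model_name_or_path: str,
                                       disable_dropout: bool = False):
    """Illustrative helper: optionally zero all dropout probabilities."""
    config = AutoConfig.from_pretrained(model_name_or_path)
    if disable_dropout:
        # OPT-style configs expose `dropout`; zero any dropout field we find.
        for attr in ("dropout", "attention_dropout", "activation_dropout",
                     "hidden_dropout_prob", "attention_probs_dropout_prob"):
            if hasattr(config, attr):
                setattr(config, attr, 0.0)
    return AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config)
```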