This repo is my implementation of DPO (the RL phase). Direct Preference Optimization is an alternative to RLHF that optimizes the same objective but doesn't require a reward model or online RL, which makes it much cleaner to implement than, say, PPO (Proximal Policy Optimization).
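For intuition, here is a minimal sketch of the DPO loss. It assumes the summed log-probabilities of the response tokens (one value per sequence) have already been computed under the policy and the frozen reference model; `beta` is the usual KL-penalty coefficient. This is illustrative, not the exact code in this repo.

```python
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """DPO loss sketch.

    Each argument is a 1-D tensor of per-sequence log-probabilities of the
    response tokens, under either the policy or the frozen reference model.
    """
    policy_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    # -log sigmoid(beta * (policy log-ratio - reference log-ratio))
    losses = -F.logsigmoid(beta * (policy_logratios - ref_logratios))
    return losses.mean()
```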
The dataset is a `.pt` file with dicts, where each dict has the keys:

- `prompt_chosen_tokens`: tensor of prompt tokens followed by the chosen response tokens
- `prompt_rejected_tokens`: prompt tokens followed by the rejected response tokens
- `chosen_loss_mask`: the loss mask for `prompt_chosen_tokens` (we only compute loss for the response tokens)
- `rejected_loss_mask`: the loss mask for `prompt_rejected_tokens`
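A quick way to inspect one entry of the generated dataset, assuming the `.pt` file holds a list of such dicts (the filename below is just a placeholder):

```python
import torch

# Load the preprocessed dataset (placeholder filename).
data = torch.load("dpo_dataset.pt")

example = data[0]
print(example["prompt_chosen_tokens"].shape)    # prompt + chosen response token ids
print(example["prompt_rejected_tokens"].shape)  # prompt + rejected response token ids
print(example["chosen_loss_mask"].shape)        # masks out prompt tokens; loss only on response tokens
print(example["rejected_loss_mask"].shape)      # same, for the rejected sequence
```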
The dataset is generated by `dataset.py` using Anthropic's HH-RLHF `jsonl` files here. For non-Ampere GPUs, change `{param, reduce, buffer}_dtype` in `mixed_precision` in `train.py` to something other than `bfloat16`.
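For example, the FSDP mixed-precision policy might look like the sketch below; keeping everything in `float32` (or using `float16` with a gradient scaler) is a reasonable fallback on GPUs without `bfloat16` support. This is an illustration, not the exact config in `train.py`.

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# Default policy (Ampere or newer): everything in bfloat16.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Fallback for older GPUs: keep everything in float32
# (or use float16 together with a gradient scaler).
fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
```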
Check out my post for a more in-depth explanation of DPO.
To train, get the `jsonl` files. Then:

```bash
python3 dataset.py --model <path to HF model> --dataset <path to jsonl file>
python3 train.py --nodes <number of nodes> --gpus <gpus per node> --model <path to HF model> ...
```
For `<path to HF model>`, use EleutherAI's Pythia models or the GPT-2 models. Training is sped up with FSDP (Fully Sharded Data Parallel) and activation recomputation.
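Roughly, the FSDP wrapping plus activation recomputation looks like the sketch below. The model name and layer class are illustrative (Pythia uses the GPT-NeoX architecture), and this is not the exact code in `train.py`.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers import AutoModelForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer

# Assumes torch.distributed.init_process_group has been called (e.g. via torchrun).
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")

# Shard parameters, gradients, and optimizer state across ranks.
model = FSDP(
    model,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
)

# Recompute each transformer block's activations in the backward pass
# instead of storing them, trading compute for memory.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda module: isinstance(module, GPTNeoXLayer),
)
```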
Further reading:

- Training Deep Nets with Sublinear Memory Cost
- Gradient-Checkpointing in PyTorch
- PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel
- Tensor Parallelism in NumPy (not really relevant to this project, but useful reading)
- DPO
- RLHF