DPO

This repo is my implementation of DPO (the RL phase).

Direct Preference Optimization is an alternative to RLHF that optimizes the same objective but doesn't require a reward model or online RL, which makes it much cleaner to implement than, say, PPO (Proximal Policy Optimization). The dataset is a .pt file of dicts, where each dict has the keys:

  - prompt_chosen_tokens: tensor of prompt tokens followed by the chosen response tokens
  - prompt_rejected_tokens: tensor of prompt tokens followed by the rejected response tokens
  - chosen_loss_mask: loss mask for prompt_chosen_tokens (loss is only computed over the response tokens)
  - rejected_loss_mask: loss mask for prompt_rejected_tokens
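Below is a minimal sketch of the DPO loss under this dataset layout. The function and variable names (sequence_logprob, policy_chosen_logits, beta, etc.) are illustrative and not necessarily the ones used in this repo.

import torch
import torch.nn.functional as F

def sequence_logprob(logits, tokens, loss_mask):
    # sum of next-token log-probs over positions where loss_mask == 1
    # logits: (B, T, V), tokens: (B, T), loss_mask: (B, T)
    logprobs = F.log_softmax(logits[:, :-1], dim=-1)
    token_logprobs = logprobs.gather(-1, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return (token_logprobs * loss_mask[:, 1:]).sum(-1)

def dpo_loss(policy_chosen_logits, policy_rejected_logits,
             ref_chosen_logits, ref_rejected_logits,
             chosen_tokens, rejected_tokens,
             chosen_loss_mask, rejected_loss_mask, beta=0.1):
    # log pi(y | x) under the policy and the frozen reference model
    pi_chosen = sequence_logprob(policy_chosen_logits, chosen_tokens, chosen_loss_mask)
    pi_rejected = sequence_logprob(policy_rejected_logits, rejected_tokens, rejected_loss_mask)
    ref_chosen = sequence_logprob(ref_chosen_logits, chosen_tokens, chosen_loss_mask)
    ref_rejected = sequence_logprob(ref_rejected_logits, rejected_tokens, rejected_loss_mask)
    # DPO objective: -log sigmoid(beta * (policy log-ratio - reference log-ratio))
    margin = (pi_chosen - ref_chosen) - (pi_rejected - ref_rejected)
    return -F.logsigmoid(beta * margin).mean()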

The dataset is generated by dataset.py from Anthropic's HH-RLHF jsonl files. For non-Ampere GPUs, change {param, reduce, buffer}_dtype in mixed_precision in train.py to something other than bfloat16.
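For reference, a sketch of that change (float16 here is just one possible fallback; float32 also works, and the variable name mixed_precision may differ from train.py):

import torch
from torch.distributed.fsdp import MixedPrecision

# non-Ampere GPUs lack native bfloat16 support, so fall back to float16 (or float32)
mixed_precision = MixedPrecision(
    param_dtype=torch.float16,   # was torch.bfloat16
    reduce_dtype=torch.float16,  # was torch.bfloat16
    buffer_dtype=torch.float16,  # was torch.bfloat16
)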

Check out my post for a more in-depth explanation of DPO.

To train, first get the jsonl files, then run

python3 dataset.py --model <path to HF model> --dataset <path to jsonl file>
python3 train.py --nodes <number of nodes> --gpus <gpus per node> --model <path to HF model> ...

For <path to HF model>, use EleutherAI's Pythia models or the GPT-2 models. Training is sped up with FSDP and activation recomputation; a rough sketch of how the two are combined follows.
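This is not the repo's exact train.py, just an outline of the typical FSDP + activation checkpointing setup; the Pythia checkpoint name and block class are assumptions for illustration.

import functools
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper, CheckpointImpl, apply_activation_checkpointing,
)
from transformers import AutoModelForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer  # Pythia's transformer block

# assumes a distributed launch (e.g. torchrun), which sets the env vars init_process_group needs
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")  # any Pythia / GPT-2 checkpoint
model = FSDP(
    model,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16,
                                   reduce_dtype=torch.bfloat16,
                                   buffer_dtype=torch.bfloat16),
    device_id=torch.cuda.current_device(),
)

# wrap each transformer block so its activations are recomputed during the backward pass
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    check_fn=lambda m: isinstance(m, GPTNeoXLayer),
)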

Relevant reading

  1. Training Deep Nets with Sublinear Memory Cost
  2. Gradient-Checkpointing in PyTorch
  3. PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel
  4. Tensor Parallelism in NumPy (not really relevant to this project, but useful reading)
  5. DPO
  6. RLHF
