🤫 TR-DPO implementation (#1593)
* 🤫 TR-DPO implementation baseline

* fix comments

* docs

* fix linters

* test added

* move configs to DPOConfig

* fix typo

* add docs

* fix import

* use state.global_step

* fix order of arguments

* make sure plugins are not none

* Update trl/trainer/utils.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update trl/trainer/utils.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* checking that reference model weights have changed

* sync_target_model as staticmethod

* set reference model

---------

Co-authored-by: Nikita Surnachev <n.surnachev@tinkoff.ru>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
4 people authored May 23, 2024
1 parent b344bce commit 9a7efbd
Showing 5 changed files with 106 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/source/dpo_trainer.mdx
@@ -117,6 +117,8 @@ The [SPPO](https://arxiv.org/abs/2405.00675) authors claim that SPPO is capable

The [NCA](https://arxiv.org/abs/2402.05369) authors show that NCA optimizes the absolute likelihood for each response rather than the relative likelihood.

The [TR-DPO](https://arxiv.org/pdf/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To enable this callback, set the `sync_ref_model` flag to `True` in the `DPOConfig`.
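As an illustration (not part of this commit), here is a minimal sketch of enabling TR-DPO through `DPOConfig`; the `gpt2` checkpoint and the tiny inline preference dataset are placeholder choices:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")  # TR-DPO needs an explicit reference model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Tiny placeholder preference dataset with the columns DPOTrainer expects
train_dataset = Dataset.from_dict(
    {
        "prompt": ["Hello,", "How are you?"],
        "chosen": [" nice to meet you.", " I am fine, thanks."],
        "rejected": [" go away.", " none of your business."],
    }
)

training_args = DPOConfig(
    output_dir="tr_dpo_model",
    remove_unused_columns=False,
    sync_ref_model=True,             # enable the TR-DPO sync callback
    ref_model_mixup_alpha=0.9,       # weight of the policy in each sync
    ref_model_sync_steps=64,         # sync every 64 optimizer steps
    precompute_ref_log_probs=False,  # must stay False when sync_ref_model=True
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```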

## Logging

While training and evaluating we record the following reward metrics:
42 changes: 42 additions & 0 deletions tests/test_dpo_trainer.py
@@ -308,6 +308,48 @@ def test_dpo_trainer_w_dataset_num_proc(self):

            trainer.train()

    def test_tr_dpo_trainer(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=3,
                remove_unused_columns=False,
                gradient_accumulation_steps=4,
                learning_rate=9e-1,
                evaluation_strategy="steps",
                precompute_ref_log_probs=False,
                sync_ref_model=True,
                ref_model_mixup_alpha=0.5,
                ref_model_sync_steps=1,
            )

            dummy_dataset = self._init_dummy_dataset()

            trainer = DPOTrainer(
                model=self.model,
                ref_model=self.model,
                beta=0.1,
                args=training_args,
                tokenizer=self.tokenizer,
                train_dataset=dummy_dataset,
                eval_dataset=dummy_dataset,
            )

            # params of the ref model, as it is the same as the model
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            assert trainer.state.log_history[-1]["train_loss"] is not None

            # check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.ref_model.get_parameter(n)
                # check the ref model's params have changed - ignore 0 biases
                if param.sum() != 0:
                    assert not torch.equal(param, new_param)

    @require_no_wandb
    def test_dpo_trainer_generate_during_eval_no_wandb(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
9 changes: 9 additions & 0 deletions trl/trainer/dpo_config.py
@@ -65,6 +65,12 @@ class DPOConfig(TrainingArguments):
            If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
        force_use_ref_model (`bool`, defaults to `False`):
            In case one passes a PEFT model for the active model and you want to use a different model for the ref_model, set this flag to `True`.
        sync_ref_model (`bool`, defaults to `False`):
            Whether to periodically sync the reference model with the policy during training, as proposed in the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
        ref_model_mixup_alpha (`float`, defaults to `0.9`):
            The alpha parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper, which controls the mix between the current policy and the previous reference policy during each sync.
        ref_model_sync_steps (`int`, defaults to `64`):
            The tau parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper, which determines how frequently the current policy is synchronized with the reference policy.
    """

    beta: float = 0.1
@@ -89,3 +95,6 @@ class DPOConfig(TrainingArguments):
    ref_adapter_name: Optional[str] = None
    reference_free: bool = False
    force_use_ref_model: bool = False
    sync_ref_model: bool = False
    ref_model_mixup_alpha: float = 0.9
    ref_model_sync_steps: int = 64
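For context (restating the TR-DPO paper rather than this diff), the two new parameters control a soft update of the reference policy, applied once every `ref_model_sync_steps` optimizer steps with mixing weight `ref_model_mixup_alpha`:

$$\pi_{\mathrm{ref}} \leftarrow \alpha \, \pi_{\theta} + (1 - \alpha) \, \pi_{\mathrm{ref}}$$

This is the update performed by `SyncRefModelCallback` in `trl/trainer/utils.py` below.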
12 changes: 12 additions & 0 deletions trl/trainer/dpo_trainer.py
@@ -46,6 +46,7 @@
from .utils import (
    DPODataCollatorWithPadding,
    RunningMoments,
    SyncRefModelCallback,
    disable_dropout_in_model,
    pad_to_length,
    peft_module_casting_to_bf16,
@@ -528,12 +529,23 @@ def make_inputs_require_grad(module, input, output):
                raise ValueError(
                    "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
                )
            if args.sync_ref_model:
                raise ValueError(
                    "You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`."
                )
        else:
            if self.is_deepspeed_enabled:
                self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        if args.sync_ref_model:
            if precompute_ref_log_probs:
                raise ValueError(
                    "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
                )

            self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
        if self.loss_type == "bco_pair":
            self.running = RunningMoments(self.accelerator)

43 changes: 41 additions & 2 deletions trl/trainer/utils.py
@@ -20,8 +20,9 @@
import numpy as np
import pandas as pd
import torch
from accelerate import PartialState
from accelerate.state import AcceleratorState
from accelerate import Accelerator
from accelerate.state import AcceleratorState, PartialState
from accelerate.utils import is_deepspeed_available
from rich.console import Console, Group
from rich.live import Live
from rich.panel import Panel
@@ -32,6 +33,7 @@
from transformers import (
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
from transformers.trainer import TrainerCallback
@@ -45,6 +47,10 @@
    from peft import LoraConfig, PeftConfig


if is_deepspeed_available():
    import deepspeed


class AdaptiveKLController:
    """
    Adaptive KL controller described in the paper:
@@ -63,6 +69,39 @@ def update(self, current, n_steps):
        self.value *= mult


class SyncRefModelCallback(TrainerCallback):
    """Callback that periodically mixes the policy weights into the reference model, as in the TR-DPO paper."""

    def __init__(
        self,
        ref_model: Union[PreTrainedModel, torch.nn.Module],
        accelerator: Optional[Accelerator],
    ):
        self.accelerator = accelerator
        self.ref_model = ref_model

    @staticmethod
    def _sync_target_model(model, target_model, alpha):
        # target <- (1 - alpha) * target + alpha * model, parameter by parameter
        for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
            target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha)

    @staticmethod
    def sync_target_model(model, target_model, alpha):
        deepspeed_plugin = AcceleratorState().deepspeed_plugin
        if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
            # Under ZeRO-3 the parameters are sharded, so gather them before copying
            with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=0):
                if deepspeed.comm.get_rank() == 0:
                    SyncRefModelCallback._sync_target_model(model, target_model, alpha)
        else:
            SyncRefModelCallback._sync_target_model(model, target_model, alpha)

    def on_step_end(self, args, state, control, **kwargs):
        model: PreTrainedModel = kwargs["model"]

        # Sync every `ref_model_sync_steps` global steps with mixing weight `ref_model_mixup_alpha`
        if self.ref_model is not None and state.global_step % args.ref_model_sync_steps == 0:
            if self.accelerator:
                model = self.accelerator.unwrap_model(model)
            self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha)


class FixedKLController:
    """Fixed KL controller."""

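As a quick sanity check of the sync semantics (a standalone sketch with toy `nn.Linear` modules, not part of this commit), the same in-place update used by `SyncRefModelCallback._sync_target_model` moves the reference weights toward the policy weights by a factor of `alpha`:

```python
import torch
import torch.nn as nn

# Toy "policy" and "reference" modules with identical shapes (stand-ins for the real models).
policy = nn.Linear(4, 4, bias=False)
ref = nn.Linear(4, 4, bias=False)

alpha = 0.5  # plays the role of ref_model_mixup_alpha
before = ref.weight.detach().clone()

# Same update as SyncRefModelCallback._sync_target_model: ref <- (1 - alpha) * ref + alpha * policy
for ref_p, pol_p in zip(ref.parameters(), policy.parameters()):
    ref_p.data.mul_(1.0 - alpha).add_(pol_p.data, alpha=alpha)

# With alpha = 0.5 the reference weights end up exactly halfway between the old reference and the policy.
assert torch.allclose(ref.weight, 0.5 * before + 0.5 * policy.weight)
```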
