[DPO] Adding weighted preference optimization (WPO) #2141

Merged · 13 commits · Oct 8, 2024
tests/test_dpo_trainer.py: 41 additions & 0 deletions
@@ -281,6 +281,47 @@ def test_dpo_trainer(self, name, loss_type, pre_compute):
                 if param.sum() != 0:
                     assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)

+    def test_dpo_trainer_with_weighting(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = DPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                max_steps=3,
+                remove_unused_columns=False,
+                gradient_accumulation_steps=1,
+                learning_rate=9e-1,
+                eval_strategy="steps",
+                beta=0.1,
+                loss_type="sigmoid",
+                precompute_ref_log_probs=False,
+                use_weighting=True,
+                report_to="none",
+            )
+
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
+
+            trainer = DPOTrainer(
+                model=self.model,
+                ref_model=self.ref_model,
+                args=training_args,
+                tokenizer=self.tokenizer,
+                train_dataset=dummy_dataset["train"],
+                eval_dataset=dummy_dataset["test"],
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            assert trainer.state.log_history[-1]["train_loss"] is not None
+
+            # check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                # check the params have changed - ignore 0 biases
+                if param.sum() != 0:
+                    assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)
+
     @parameterized.expand(
         [
             [None, "Test when rpo_alpha is set to None"],
trl/trainer/dpo_config.py: 3 additions & 1 deletion
@@ -62,7 +62,8 @@ class DPOConfig(TrainingArguments):
                 - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper.
                 - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
                 - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper.
-
+        use_weighting (`bool`, *optional*, defaults to `False`):
+            Whether or not to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
         label_pad_token_id (`int`, *optional*, defaults to `-100`):
             Label pad token id. This argument is required if you want to use the default data collator.
         padding_value (`Optional[int]`, *optional*, defaults to `None`):
@@ -146,6 +147,7 @@ class DPOConfig(TrainingArguments):
         "apo_zero",
         "apo_down",
     ] = "sigmoid"
+    use_weighting: bool = False
     label_pad_token_id: int = -100
     padding_value: Optional[int] = None
     truncation_mode: str = "keep_end"
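For reference, a minimal usage sketch of the new option. This is not part of the PR; it assumes `model`, `ref_model`, `tokenizer`, and a preference `dataset` are already prepared (much like the test above), and only `use_weighting=True` is specific to this change.

```python
# Minimal sketch: enable WPO-style weighting via the new DPOConfig flag.
# Assumes `model`, `ref_model`, `tokenizer`, and `dataset` exist (hypothetical setup).
from trl import DPOConfig, DPOTrainer

training_args = DPOConfig(
    output_dir="dpo-wpo-output",  # hypothetical output directory
    beta=0.1,
    loss_type="sigmoid",
    use_weighting=True,  # weight the DPO loss as in the WPO paper
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
)
trainer.train()
```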
trl/trainer/dpo_trainer.py: 47 additions & 11 deletions
@@ -1316,17 +1316,22 @@ def get_batch_logps(
         labels: torch.LongTensor,
         label_pad_token_id: int = -100,
         is_encoder_decoder: bool = False,
-    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        use_weighting: bool = False,
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor, Optional[torch.FloatTensor]]:
         """Compute the log probabilities of the given labels under the given logits.

         Args:
             logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
             labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
             label_pad_token_id: The label pad token id.
             is_encoder_decoder: Whether the model is an encoder-decoder model.
+            use_weighting: Whether to apply weighting as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.

-        Returns:
-            A Tuple of two tensor of shape ((batch_size,), (batch_size,)) containing the sum of log probabilities of the given labels under the given logits in the first tensor and the number of non-masked tokens in the second tensor.
+        Returns
+            A Tuple of three tensors of shape ((batch_size,), (batch_size,), Optional[(batch_size,)]) containing:
+                - The sum of log probabilities of the given labels under the given logits.
+                - The number of non-masked tokens.
+                - The wpo weighting (if use_weighting is True, otherwise None).
         """
         if logits.shape[:-1] != labels.shape:
             raise ValueError(
@@ -1343,7 +1348,17 @@ def get_batch_logps(

         per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

-        return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
+        all_logps = (per_token_logps * loss_mask).sum(-1)
+
+        all_weights = None
+        if use_weighting:
+            # eqn (2) of the WPO paper: https://huggingface.co/papers/2406.11827
+            probs = F.softmax(logits, dim=-1)
+            weights_adjustment_factor = torch.log((probs**2).sum(-1))
+            per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
+            all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1)
+
+        return all_logps, loss_mask.sum(-1), all_weights.detach()

     def concatenated_forward(
         self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
@@ -1389,12 +1404,13 @@ def concatenated_forward(
             seq_len = concatenated_batch["concatenated_labels"].shape[1]
             all_logits = all_logits[:, -seq_len:]

-        all_logps, size_completion = self.get_batch_logps(
+        all_logps, size_completion, all_weights = self.get_batch_logps(
             all_logits,
             concatenated_batch["concatenated_labels"],
             # average_log_prob=self.loss_type == "ipo",
             is_encoder_decoder=self.is_encoder_decoder,
             label_pad_token_id=self.label_pad_token_id,
+            use_weighting=self.args.use_weighting,
         )

         def cross_entropy_loss(logits, labels):
@@ -1417,16 +1433,30 @@ def cross_entropy_loss(logits, labels):
         if self.loss_type == "ipo":
             all_logps = all_logps / size_completion

+        policy_weights = None
+        if self.args.use_weighting:
+            chosen_weights = all_weights[:len_chosen]
+            rejected_weights = all_weights[len_chosen:]
+            policy_weights = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1)
+
         chosen_logps = all_logps[:len_chosen]
         rejected_logps = all_logps[len_chosen:]

         chosen_logits = all_logits[:len_chosen]
         rejected_logits = all_logits[len_chosen:]

         if self.aux_loss_enabled:
-            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
+            return (
+                chosen_logps,
+                rejected_logps,
+                chosen_logits,
+                rejected_logits,
+                nll_loss,
+                policy_weights,
+                outputs.aux_loss,
+            )

-        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
+        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, policy_weights)

     def get_batch_loss_metrics(
         self,
@@ -1444,9 +1474,10 @@
             policy_chosen_logits,
             policy_rejected_logits,
             policy_nll_loss,
-        ) = forward_output[:5]
+            policy_weights,
+        ) = forward_output[:6]
         if self.aux_loss_enabled:
-            aux_loss = forward_output[5]
+            aux_loss = forward_output[6]

         # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
         if (
@@ -1480,6 +1511,9 @@
             # RPO loss from V3 of the paper:
             losses = losses + policy_nll_loss * self.args.rpo_alpha

+        if self.args.use_weighting:
+            losses = losses * policy_weights
+
         prefix = "eval_" if train_eval == "eval" else ""
         metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
         metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
@@ -1703,15 +1737,17 @@ def create_model_card(
         else:
             base_model = None

-        citation = textwrap.dedent("""\
+        citation = textwrap.dedent(
+            """\
             @inproceedings{rafailov2023direct,
                 title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
                 author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
                 year = 2023,
                 booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
                 url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
                 editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
-            }""")
+            }"""
+        )

         model_card = generate_model_card(
             base_model=base_model,
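To illustrate the weighting these hunks add, here is a standalone sketch of eqn (2) of the WPO paper on toy tensors. It is not part of the PR; the shapes and variable names are illustrative only, mirroring the logic in `get_batch_logps` and `concatenated_forward` above.

```python
# Standalone sketch of the WPO weighting (eqn 2, https://huggingface.co/papers/2406.11827)
# on toy tensors; names and shapes are illustrative, not the trainer's internals.
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 4, 6, 32  # first half = chosen, second half = rejected
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
loss_mask = torch.ones(batch_size, seq_len)  # 1 for completion tokens, 0 for prompt/padding

# Per-token log-probabilities of the labels under the policy.
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

# Eqn (2): subtract log(sum_v p(v)^2) per token, then average over unmasked tokens.
probs = F.softmax(logits, dim=-1)
weights_adjustment_factor = torch.log((probs**2).sum(-1))
sequence_weights = ((per_token_logps - weights_adjustment_factor) * loss_mask).sum(-1) / loss_mask.sum(-1)

# Each pair's DPO loss is scaled by exp(w_chosen + w_rejected), clamped to at most 1.
chosen_w = sequence_weights[: batch_size // 2]
rejected_w = sequence_weights[batch_size // 2 :]
policy_weights = torch.clamp(torch.exp(chosen_w + rejected_w), max=1)
print(policy_weights.shape)  # torch.Size([2]), one weight per preference pair
```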