Conversational dataset support for ORPOTrainer (#2184)
* default learning rate

* update trainer

* update test

* update script

* update dataset format

* add line in dpo doc

* update orpo doc

* refine implicit/explicit

* update demo chat
qgallouedec authored Oct 11, 2024
1 parent 5375d71 commit d0aa421
Showing 7 changed files with 138 additions and 96 deletions.
7 changes: 4 additions & 3 deletions docs/source/dataset_formats.mdx
@@ -60,7 +60,7 @@ The *format* of a dataset refers to how the data is structured, typically catego
or, with implicit prompt:
<pre><code>{"chosen": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "user", "content": "What color is the sky?"},
"rejected": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is green."}]}</code></pre>
</td>
</tr>
@@ -175,8 +175,9 @@ A preference dataset is used for tasks where the model is trained to choose betw
Some datasets may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible.

```python
# explicit prompt
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."} # recommended
# or, implicit prompt
preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
```

@@ -203,7 +204,7 @@ Choosing the right dataset format depends on the task you are working on and the
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`PPOv2Trainer`] | Tokenized language modeling |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
2 changes: 2 additions & 0 deletions docs/source/dpo_trainer.mdx
@@ -83,6 +83,8 @@ The best programming language based on these factors is subjective and depends o

DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
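
As a rough illustration of what this extraction does (a sketch, not the trainer's internal implementation), an implicit-prompt example can be split on the longest shared prefix of the two completions:

```python
# Hypothetical helper: split an implicit-prompt example on the longest
# shared prefix of the "chosen" and "rejected" completions.
def split_implicit_prompt(example):
    chosen, rejected = example["chosen"], example["rejected"]
    idx = 0
    while idx < min(len(chosen), len(rejected)) and chosen[idx] == rejected[idx]:
        idx += 1
    return {"prompt": chosen[:idx], "chosen": chosen[idx:], "rejected": rejected[idx:]}

example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
print(split_implicit_prompt(example))
# {'prompt': 'The sky is ', 'chosen': 'blue.', 'rejected': 'green.'}
```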

### Special considerations for vision-language models

The [`DPOTrainer`] supports fine-tuning vision-language models (VLMs). For these models, a vision dataset is required. To learn more about the specific format for vision datasets, refer to the [Vision dataset format](dataset_formats#vision-datasets) section.
171 changes: 96 additions & 75 deletions docs/source/orpo_trainer.md
@@ -2,107 +2,128 @@

[![](https://img.shields.io/badge/All_models-ORPO-blue)](https://huggingface.co/models?other=orpo,trl)

## Overview

Odds Ratio Preference Optimization (ORPO) was introduced in [ORPO: Monolithic Preference Optimization without Reference Model](https://huggingface.co/papers/2403.07691) by [Jiwoo Hong](https://huggingface.co/JW17), [Noah Lee](https://huggingface.co/nlee-208), and [James Thorne](https://huggingface.co/j6mes).

The abstract from the paper is the following:

> While recent preference alignment algorithms for language models have demonstrated promising results, supervised fine-tuning (SFT) remains imperative for achieving successful convergence. In this paper, we study the crucial role of SFT within the context of preference alignment, emphasizing that a minor penalty for the disfavored generation style is sufficient for preference-aligned SFT. Building on this foundation, we introduce a straightforward and innovative reference model-free monolithic odds ratio preference optimization algorithm, ORPO, eliminating the necessity for an additional preference alignment phase. We demonstrate, both empirically and theoretically, that the odds ratio is a sensible choice for contrasting favored and disfavored styles during SFT across the diverse sizes from 125M to 7B. Specifically, fine-tuning Phi-2 (2.7B), Llama-2 (7B), and Mistral (7B) with ORPO on the UltraFeedback alone surpasses the performance of state-of-the-art language models with more than 7B and 13B parameters: achieving up to 12.20% on AlpacaEval_{2.0} (Figure 1), 66.19% on IFEval (instruction-level loose, Table 6), and 7.32 in MT-Bench (Figure 12). We release code and model checkpoints for Mistral-ORPO-alpha (7B) and Mistral-ORPO-beta (7B).

It studies the crucial role of SFT within the context of preference alignment. Using preference data, the method posits that a minor penalty for the disfavored generation, together with a strong adaptation signal to the chosen response via a simple log odds ratio term appended to the NLL loss, is sufficient for preference-aligned SFT.

ORPO is thus a reference model-free preference optimization algorithm that eliminates the need for an additional preference alignment phase, saving compute and memory.

The official code can be found in [xfactlab/orpo](https://github.com/xfactlab/orpo).

This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Lewis Tunstall](https://huggingface.co/lewtun) and [Alvaro Bartolome](https://huggingface.co/alvarobartt).

## Expected dataset format

The ORPO trainer expects a format identical to the DPO trainer, which should include three entries named as follows:

- `prompt`
- `chosen`
- `rejected`

for example:

```py
orpo_dataset_dict = {
    "prompt": [
        "hello",
        "how are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "chosen": [
        "hi nice to meet you",
        "I am fine",
        "My name is Mary",
        "My name is Mary",
        "Python",
        "Python",
        "Java",
    ],
    "rejected": [
        "leave me alone",
        "I am not fine",
        "Whats it to you?",
        "I dont have a name",
        "Javascript",
        "C++",
        "C++",
    ],
}
```
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses, and `rejected` contains the corresponding negative (rejected) responses. Note that a prompt can have multiple responses, and this is reflected in the entries being repeated in the dictionary's value arrays.
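
As a small usage sketch (assuming the `orpo_dataset_dict` defined above), such a dictionary can be wrapped in a `datasets.Dataset` before being passed to the trainer:

```python
from datasets import Dataset

# Wrap the raw dictionary above in a `datasets.Dataset` so it can be passed
# to the trainer as `train_dataset`.
train_dataset = Dataset.from_dict(orpo_dataset_dict)
print(train_dataset.column_names)  # ['prompt', 'chosen', 'rejected']
```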
## Quick start

This example demonstrates how to train a model using the ORPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:

<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
## Expected model format

The ORPO trainer expects a model of type `AutoModelForCausalLM`, compared to PPO, which expects `AutoModelForCausalLMWithValueHead` for the value function.

## Using the `ORPOTrainer`

For a detailed example, have a look at the `examples/scripts/orpo.py` script. At a high level, we need to initialize the `ORPOTrainer` with the `model` we wish to train. **Note that `ORPOTrainer` eliminates the need to use a reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter `lambda` in eq. (6) of the paper and controls the weighting of the relative odds ratio loss in the standard cross-entropy loss used for SFT.

Below is the script to train the model:

```python
# train_orpo.py
from datasets import load_dataset
from trl import ORPOConfig, ORPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = ORPOConfig(
    output_dir="Qwen2-0.5B-ORPO",
    beta=0.1,  # the lambda/alpha hyperparameter in the paper/code
    logging_steps=10,
)
trainer = ORPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```

Execute the script using the following command:

```bash
accelerate launch train_orpo.py
```

Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.

![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/orpo-qwen2-reward-margin.png)

To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).

<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-ORPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?

<strong><span style="color: blue;">&lt;trl-lib/Qwen2-0.5B-ORPO&gt;:</span></strong>
It's challenging to determine the best programming language as no one language is perfect, as the complexity of a task and the type of project are significant factors. Some popular languages include Java, Python, JavaScript, and
C++. If you have specific needs or requirements for a specific project, it's important to choose the language that best suits those needs.

Here are some other factors to consider when choosing a programming language for a project:

<strong><span style="color: green;">• Language proficiency:</span></strong> A good programming language is more likely to be easy to understand and use, and will allow developers to collaborate on projects more efficiently.
<strong><span style="color: green;">• Ease of use:</span></strong> There are tools and libraries available to make programming more accessible, so developers should choose a language that can help them get started easier.
<strong><span style="color: green;">• Code readability:</span></strong> A clear and concise codebase should be easy to read and understand, especially when working with large projects.
<strong><span style="color: green;">• Tool and framework support:</span></strong> There are numerous libraries available for Python, Java, and JavaScript, along with tools like IDEs and static code analysis tools.
<strong><span style="color: green;">• Accessibility:</span></strong> Some languages and tools have features that make them more accessible to developers with disabilities, such as support for screen readers.
<strong><span style="color: green;">• Version control:</span></strong> As your projects grow and complexity increases, version control tools can be beneficial for tracking changes.

</code></pre>

## Expected dataset format

ORPO requires a [preference dataset](dataset_formats#preference). The [`ORPOTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

Although the [`ORPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
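
For reference, a minimal conversational preference record with an explicit prompt might look like the following; this is an illustrative example (not drawn from a specific dataset), and the trainer applies the chat template to such records automatically:

```python
# Illustrative conversational preference record with an explicit prompt.
conversational_example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "chosen": [{"role": "assistant", "content": "It is blue."}],
    "rejected": [{"role": "assistant", "content": "It is green."}],
}
```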

## Example script

We provide an example script to train a model using the ORPO method. The script is available in [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py).

To test the ORPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:

```bash
accelerate launch examples/scripts/orpo.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --num_train_epochs 1 \
    --logging_steps 25 \
    --output_dir Qwen2-0.5B-ORPO
```

## Usage tips

### For Mixture of Experts Models: Enabling the auxiliary loss

MOEs are most efficient when the load is roughly evenly distributed across experts.
To ensure that MOEs are trained similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.

This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
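
As a sketch of how these options might be set when loading a Mixtral-style MoE model (the checkpoint name here is only an example), both values can be forwarded to the model config via `from_pretrained`:

```python
from transformers import AutoModelForCausalLM

# Example only: any Mixtral-style MoE checkpoint can be configured the same way.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    output_router_logits=True,   # add the router's load-balancing auxiliary loss to the LM loss
    router_aux_loss_coef=0.001,  # weight of the auxiliary loss in the total loss
)
```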

## Logged metrics

While training and evaluating, we record the following reward metrics:

- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
- `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
- `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses
- `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))`
- `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses
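
As a rough sketch of how these quantities relate (using made-up mean log-probabilities rather than real model outputs), the reward and odds-ratio terms can be reproduced as follows:

```python
import torch
import torch.nn.functional as F

beta = 0.1
# Made-up mean log-probabilities of the policy model for two preference pairs.
chosen_logps = torch.tensor([-0.8, -1.1])
rejected_logps = torch.tensor([-1.5, -1.3])

rewards_chosen = beta * chosen_logps
rewards_rejected = beta * rejected_logps
margins = (rewards_chosen - rewards_rejected).mean()
accuracies = (rewards_chosen > rewards_rejected).float().mean()

# log odds(p) = log(p / (1 - p)); the log odds ratio contrasts chosen and rejected.
log_odds_chosen = (chosen_logps - rejected_logps) - (
    torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
)
log_odds_ratio = F.logsigmoid(log_odds_chosen).mean()
print(margins.item(), accuracies.item(), log_odds_ratio.item())
```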

## ORPOTrainer

[[autodoc]] ORPOTrainer


## ORPOConfig

[[autodoc]] ORPOConfig
12 changes: 0 additions & 12 deletions examples/scripts/orpo.py
@@ -54,7 +54,6 @@

from dataclasses import dataclass, field

from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

@@ -93,17 +92,6 @@ class ScriptArguments:
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

def process(row):
row["prompt"] = tokenizer.apply_chat_template(row["chosen"][:-1], tokenize=False)
row["chosen"] = tokenizer.apply_chat_template([row["chosen"][-1]], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False)
return row

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
dataset = dataset.map(process, num_proc=training_args.dataset_num_proc)

################
# Training
################
25 changes: 20 additions & 5 deletions tests/test_orpo_trainer.py
@@ -35,8 +35,15 @@ def setUp(self):
self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)

@parameterized.expand([["gpt2"], ["t5"]])
def test_orpo_trainer(self, name):
@parameterized.expand(
[
("gpt2", "standard_preference"),
("t5", "standard_implicit_prompt_preference"),
("gpt2", "conversational_preference"),
("t5", "conversational_implicit_prompt_preference"),
]
)
def test_orpo_trainer(self, name, config_name):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = ORPOConfig(
output_dir=tmp_dir,
@@ -50,7 +57,7 @@ def test_orpo_trainer(self, name):
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)

if name == "gpt2":
model = self.model
@@ -82,7 +89,15 @@ def test_orpo_trainer(self, name):
assert not torch.equal(param, new_param)

@require_peft
def test_orpo_trainer_with_lora(self):
@parameterized.expand(
[
("standard_preference",),
("standard_implicit_prompt_preference",),
("conversational_preference",),
("conversational_implicit_prompt_preference",),
]
)
def test_orpo_trainer_with_lora(self, config_name):
from peft import LoraConfig

lora_config = LoraConfig(
@@ -106,7 +121,7 @@ def test_orpo_trainer_with_lora(self):
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
dummy_dataset = load_dataset("trl-internal-testing/zen", config_name)

trainer = ORPOTrainer(
model=self.model,
4 changes: 4 additions & 0 deletions trl/trainer/orpo_config.py
@@ -27,6 +27,9 @@ class ORPOConfig(TrainingArguments):
command line.
Parameters:
learning_rate (`float`, *optional*, defaults to `1e-6`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
max_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.
@@ -59,6 +62,7 @@ class ORPOConfig(TrainingArguments):
Number of processes to use for processing the dataset.
"""

learning_rate: float = 1e-6
max_length: Optional[int] = None
max_prompt_length: Optional[int] = None
max_completion_length: Optional[int] = None
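
As a small usage sketch tied to the change above (values shown are just the documented defaults), the new ORPO-specific `learning_rate` default applies unless you override it explicitly:

```python
from trl import ORPOConfig

# The ORPO-specific default learning rate (1e-6) replaces the TrainingArguments default;
# pass learning_rate explicitly if you want a different value.
config = ORPOConfig(output_dir="orpo-model", max_length=1024)
print(config.learning_rate)  # 1e-06
```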