
Commit 07f3c95

Move OnlineDPOTrainer to experimental module (#4473)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent 4cb1a25 commit 07f3c95

20 files changed: +1952 / -1878 lines changed
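At a glance, the commit only relocates the public import path; the class and config names are unchanged. A minimal before/after sketch of the import change applied throughout the docs and example scripts below:

```python
# Before this commit: Online DPO was exposed at the top level of trl.
# from trl import OnlineDPOConfig, OnlineDPOTrainer

# After this commit: Online DPO lives in the experimental module.
from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer
```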

docs/source/_toctree.yml

Lines changed: 2 additions & 2 deletions
@@ -58,8 +58,6 @@
 - sections: # Sorted alphabetically
   - local: dpo_trainer
     title: DPO
-  - local: online_dpo_trainer
-    title: Online DPO
   - local: grpo_trainer
     title: GRPO
   - local: kto_trainer
@@ -111,6 +109,8 @@
     title: MiniLLM
   - local: nash_md_trainer
     title: Nash-MD
+  - local: online_dpo_trainer
+    title: Online DPO
   - local: orpo_trainer
     title: ORPO
   - local: papo_trainer

docs/source/dataset_formats.md

Lines changed: 1 addition & 1 deletion
@@ -390,14 +390,14 @@ Choosing the right dataset type depends on the task you are working on and the s
 | [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
 | [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
 | [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
-| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
 | [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
 | [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
 | [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
 | [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
 | [`experimental.cpo.CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
 | [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) |
 | [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) |
+| [`experimental.online_dpo.OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
 | [`experimental.orpo.ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
 | [`experimental.ppo.PPOTrainer`] | Tokenized language modeling |
 | [`experimental.prm.PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
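As the table continues to state, the relocated trainer still takes a prompt-only dataset. A minimal sketch of prompt-only records, using the field names from TRL's dataset format documentation; the example values are illustrative:

```python
# Prompt-only examples: only a "prompt" field, no completions or preference pairs.
standard_example = {"prompt": "The sky is"}  # standard (plain text) format
conversational_example = {  # conversational format; the trainer applies the chat template
    "prompt": [{"role": "user", "content": "What color is the sky?"}]
}
```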

docs/source/example_overview.md

Lines changed: 2 additions & 2 deletions
@@ -53,8 +53,8 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl
 | [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
 | [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights. |
 | [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`experimental.nash_md.NashMDTrainer`] to fine-tune a model. |
-| [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
-| [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a Vision Language Model. |
+| [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a model. |
+| [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a Vision Language Model. |
 | [`examples/scripts/openenv/browsergym.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's BrowserGym environment and vLLM. |
 | [`examples/scripts/openenv/catch.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/catch.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Catch environment (OpenSpiel) and vLLM. |
 | [`examples/scripts/openenv/echo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py) | Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Echo environment and vLLM. |

docs/source/index.md

Lines changed: 1 addition & 1 deletion
@@ -24,8 +24,8 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL
 
 - [`GRPOTrainer`] ⚡️
 - [`RLOOTrainer`] ⚡️
-- [`OnlineDPOTrainer`] ⚡️
 - [`experimental.nash_md.NashMDTrainer`] 🧪 ⚡️
+- [`experimental.online_dpo.OnlineDPOTrainer`] 🧪 ⚡️
 - [`experimental.ppo.PPOTrainer`] 🧪
 - [`experimental.xpo.XPOTrainer`] 🧪 ⚡️
 
docs/source/online_dpo_trainer.md

Lines changed: 6 additions & 6 deletions
@@ -28,8 +28,8 @@ Below is the script to train the model:
 ```python
 # train_online_dpo.py
 from datasets import load_dataset
-from trl import OnlineDPOConfig, OnlineDPOTrainer
 from trl.experimental.judges import PairRMJudge
+from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
@@ -66,7 +66,7 @@ The best programming language depends on your specific needs and priorities. Som
 
 ## Expected dataset type
 
-Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, which expects a [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
+Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, which expects a [preference dataset](dataset_formats#preference)). The [`experimental.online_dpo.OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
 
 ## Usage tips
 
@@ -93,7 +93,7 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht
 
 ### Encourage EOS token generation
 
-When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]:
+When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`experimental.online_dpo.OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`experimental.online_dpo.OnlineDPOConfig`]:
 
 ```python
 training_args = OnlineDPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
@@ -147,7 +147,7 @@ While training and evaluating, we record the following reward metrics. Here is a
 * `logps/chosen`: The mean log probabilities of the chosen completions.
 * `logps/rejected`: The mean log probabilities of the rejected completions.
 * `val/contain_eos_token`: The fraction of completions which contain an EOS token.
-* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`OnlineDPOConfig`].
+* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`experimental.online_dpo.OnlineDPOConfig`].
 
 ## Benchmark experiments
 
@@ -261,11 +261,11 @@ The online DPO checkpoint gets increasingly more win rate as we scale up the mod
 
 ## OnlineDPOTrainer
 
-[[autodoc]] OnlineDPOTrainer
+[[autodoc]] experimental.online_dpo.OnlineDPOTrainer
     - train
     - save_model
     - push_to_hub
 
 ## OnlineDPOConfig
 
-[[autodoc]] OnlineDPOConfig
+[[autodoc]] experimental.online_dpo.OnlineDPOConfig
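The first hunk above shows only the top of `train_online_dpo.py`. For orientation, here is that snippet extended into a minimal end-to-end sketch using the new import path; the dataset name, `output_dir`, and trainer keyword arguments follow TRL's Online DPO quickstart and should be read as illustrative assumptions, not as part of this commit:

```python
# train_online_dpo.py (sketch using the relocated experimental import path)
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.experimental.judges import PairRMJudge
from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()  # pairwise judge that ranks the sampled completions
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")  # prompt-only dataset

training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO")
trainer = OnlineDPOTrainer(
    model=model,
    judge=judge,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```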

docs/source/reducing_memory_usage.md

Lines changed: 1 addition & 1 deletion
@@ -265,7 +265,7 @@ training_args = GRPOConfig(..., ds3_gather_for_generation=False)
 <hfoption id="Online DPO">
 
 ```python
-from trl import OnlineDPOConfig
+from trl.experimental.online_dpo import OnlineDPOConfig
 
 training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False)
 ```

docs/source/speeding_up_training.md

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ pip install trl[vllm]
 Then, enable it by passing `use_vllm=True` in the training arguments.
 
 ```python
-from trl import OnlineDPOConfig
+from trl.experimental.online_dpo import OnlineDPOConfig
 
 training_args = OnlineDPOConfig(..., use_vllm=True)
 ```

docs/source/vllm_integration.md

Lines changed: 4 additions & 4 deletions
@@ -9,9 +9,9 @@ This document will guide you through the process of using vLLM with TRL for fast
 > The following trainers currently support generation with vLLM:
 >
 > - [`GRPOTrainer`]
-> - [`OnlineDPOTrainer`]
 > - [`RLOOTrainer`]
 > - [`experimental.nash_md.NashMDTrainer`]
+> - [`experimental.online_dpo.OnlineDPOTrainer`]
 > - [`experimental.xpo.XPOTrainer`]
 
 ## 🚀 How can I use vLLM with TRL to speed up training?
@@ -65,7 +65,7 @@ trainer.train()
 
 ```python
 from datasets import load_dataset
-from trl import OnlineDPOTrainer, OnlineDPOConfig
+from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer
 from trl.rewards import accuracy_reward
 
 dataset = load_dataset("trl-lib/DeepMath-103K", split="train")
@@ -316,7 +316,7 @@ training_args = GRPOConfig(
 <hfoption id="OnlineDPO">
 
 ```python
-from trl import OnlineDPOConfig
+from trl.experimental.online_dpo import OnlineDPOConfig
 
 training_args = OnlineDPOConfig(
     ...,
@@ -391,7 +391,7 @@ training_args = GRPOConfig(
 <hfoption id="OnlineDPO">
 
 ```python
-from trl import OnlineDPOConfig
+from trl.experimental.online_dpo import OnlineDPOConfig
 
 training_args = OnlineDPOConfig(
     ...,

examples/scripts/online_dpo.py

Lines changed: 1 addition & 2 deletions
@@ -58,15 +58,14 @@
 from trl import (
     LogCompletionsCallback,
     ModelConfig,
-    OnlineDPOConfig,
-    OnlineDPOTrainer,
     ScriptArguments,
     TrlParser,
     get_kbit_device_map,
     get_peft_config,
     get_quantization_config,
 )
 from trl.experimental.judges import HfPairwiseJudge, OpenAIPairwiseJudge, PairRMJudge
+from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer
 
 
 # Enable logging in a Hugging Face Space

examples/scripts/online_dpo_vlm.py

Lines changed: 1 addition & 2 deletions
@@ -92,14 +92,13 @@
 from trl import (
     LogCompletionsCallback,
     ModelConfig,
-    OnlineDPOConfig,
-    OnlineDPOTrainer,
     ScriptArguments,
     TrlParser,
     get_kbit_device_map,
     get_peft_config,
     get_quantization_config,
 )
+from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer
 from trl.rewards import accuracy_reward, think_format_reward
 
 
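For downstream code that must run on TRL versions from both before and after this commit, one possible compatibility shim (an assumption about user-side migration, not something this commit adds):

```python
# Prefer the new experimental location; fall back to the old top-level import
# for TRL versions released before this commit.
try:
    from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer
except ImportError:
    from trl import OnlineDPOConfig, OnlineDPOTrainer
```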