`docs/source/example_overview.md` (+2 −2)
```diff
@@ -53,8 +53,8 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl
 |[`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py)| This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
 |[`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py)| This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights. |
 |[`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py)| This script shows how to use the [`experimental.nash_md.NashMDTrainer`] to fine-tune a model. |
-|[`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py)| This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
-|[`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py)| This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a Vision Language Model. |
+|[`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py)| This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a model. |
+|[`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py)| This script shows how to use the [`experimental.online_dpo.OnlineDPOTrainer`] to fine-tune a Vision Language Model. |
 |[`examples/scripts/openenv/browsergym.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/browsergym.py)| Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's BrowserGym environment and vLLM |
 |[`examples/scripts/openenv/catch.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/catch.py)| Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Catch environment (OpenSpiel) and vLLM |
 |[`examples/scripts/openenv/echo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/openenv/echo.py)| Simple script to run GRPO training via the [`GRPOTrainer`] with OpenEnv's Echo environment and vLLM. |
```
`docs/source/online_dpo_trainer.md` (+6 −6)
````diff
@@ -28,8 +28,8 @@ Below is the script to train the model:
 ```python
 # train_online_dpo.py
 from datasets import load_dataset
-from trl import OnlineDPOConfig, OnlineDPOTrainer
 from trl.experimental.judges import PairRMJudge
+from trl.experimental.online_dpo import OnlineDPOConfig, OnlineDPOTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
````
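For orientation while reading this hunk: under the new `trl.experimental.online_dpo` import path, the rest of the example script would continue roughly as sketched below. This is a minimal sketch, not part of the diff; the dataset name, output directory, and keyword arguments are assumptions based on the surrounding documentation.

```python
# Continuation sketch of train_online_dpo.py (illustrative; not shown in the hunk above).
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()  # pairwise judge that ranks the two sampled completions

# Online DPO expects a prompt-only dataset; this one is assumed for illustration.
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO")
trainer = OnlineDPOTrainer(
    model=model,
    judge=judge,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```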
```diff
@@ -66,7 +66,7 @@ The best programming language depends on your specific needs and priorities. Som
 
 ## Expected dataset type
 
-Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, which expects a [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
+Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, which expects a [preference dataset](dataset_formats#preference)). The [`experimental.online_dpo.OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
 
 ## Usage tips
 
```
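To make the dataset requirement above concrete, here is what a single prompt-only example looks like in the standard and conversational layouts (the example values are illustrative, not taken from the diff):

```python
# Prompt-only samples (illustrative values).
# Standard format: the prompt is a plain string.
standard_example = {"prompt": "The sky is"}

# Conversational format: the prompt is a list of chat messages;
# the trainer applies the tokenizer's chat template automatically.
conversational_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
```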
```diff
@@ -93,7 +93,7 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht
 
 ### Encourage EOS token generation
 
-When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]:
+When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`experimental.online_dpo.OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`experimental.online_dpo.OnlineDPOConfig`]:
```
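As a sketch of how the two arguments named in this hunk fit together (the values and output directory are assumptions for illustration):

```python
from trl.experimental.online_dpo import OnlineDPOConfig

# Cap generation length and penalize completions that never emit an EOS token.
training_args = OnlineDPOConfig(
    output_dir="Qwen2-0.5B-OnlineDPO",  # assumed output directory
    max_new_tokens=64,                  # maximum completion length during training
    missing_eos_penalty=1.0,            # penalty applied when no EOS token is generated
)
```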
```diff
@@ -147,7 +147,7 @@ While training and evaluating, we record the following reward metrics. Here is a
 * `logps/chosen`: The mean log probabilities of the chosen completions.
 * `logps/rejected`: The mean log probabilities of the rejected completions.
 * `val/contain_eos_token`: The fraction of completions which contain an EOS token.
-* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`OnlineDPOConfig`].
+* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`experimental.online_dpo.OnlineDPOConfig`].
 
 ## Benchmark experiments
 
```
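Since the `beta` entry above notes that the value can be made dynamic by passing a list, a minimal sketch of that usage (the schedule values are illustrative):

```python
from trl.experimental.online_dpo import OnlineDPOConfig

# A single float keeps beta fixed; a list makes it vary over the course of training.
training_args = OnlineDPOConfig(
    output_dir="Qwen2-0.5B-OnlineDPO",  # assumed output directory
    beta=[0.1, 0.05, 0.03],             # illustrative schedule
)
```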
```diff
@@ -261,11 +261,11 @@ The online DPO checkpoint gets increasingly more win rate as we scale up the mod
```