
Commit e7b37d4

shirinyamani, qgallouedec, and edbeeching authored
🔥 [Refactor] RLOOTrainer (#3801)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>
1 parent b7676d1 commit e7b37d4

21 files changed: +4771 additions, -2103 deletions

docs/source/clis.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -9,6 +9,7 @@ Currently supported commands are:
 - `trl dpo`: fine-tune a LLM with DPO
 - `trl grpo`: fine-tune a LLM with GRPO
 - `trl kto`: fine-tune a LLM with KTO
+- `trl rloo`: fine-tune a LLM with RLOO
 - `trl sft`: fine-tune a LLM with SFT
 
 #### Other Commands
```

docs/source/dataset_formats.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -405,6 +405,7 @@ Choosing the right dataset type depends on the task you are working on and the s
 | [`PPOTrainer`] | Tokenized language modeling |
 | [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
 | [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
+| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
 | [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
 | [`XPOTrainer`] | [Prompt-only](#prompt-only) |
 
```
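For reference, the prompt-only type this new row points to looks like the following. This is a minimal sketch based on TRL's dataset formats guide; the `prompt` field name follows the documented convention.

```python
# Prompt-only rows carry only a "prompt" field; completions are generated during training.
# Standard (plain-text) variant
prompt_only_standard = {"prompt": "The sky is"}

# Conversational variant: the prompt is a list of chat messages
prompt_only_conversational = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}]
}
```
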

docs/source/example_overview.md

Lines changed: 1 addition & 2 deletions
```diff
@@ -56,8 +56,7 @@ Scripts can be used as examples of how to use TRL trainers. They are located in
 | [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
 | [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
 | [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a Outcome Reward Model (ORM) on your own dataset. |
-| [`examples/scripts/rloo/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
-| [`examples/scripts/rloo/rloo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo_tldr.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
+| [`examples/scripts/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to solve math questions. |
 | [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model. |
 | [`examples/scripts/sft_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model. |
 | [`examples/scripts/sft_video_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_video_llm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Video Language Model. |
```

docs/source/grpo_trainer.md

Lines changed: 10 additions & 8 deletions
```diff
@@ -14,10 +14,10 @@ This post-training method was contributed by [Quentin Gallouédec](https://huggi
 
 ## Quick start
 
-This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (completion column is ignored!). You can view the data in the dataset here:
+This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [UltraFeedback prompts dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt). You can view the data in the dataset here:
 
 <iframe
-  src="https://huggingface.co/datasets/trl-lib/tldr/embed/viewer/default/train?row=0"
+  src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
   frameborder="0"
   width="100%"
   height="560px"
@@ -30,16 +30,18 @@ Below is the script to train the model.
 from datasets import load_dataset
 from trl import GRPOConfig, GRPOTrainer
 
-dataset = load_dataset("trl-lib/tldr", split="train")
+dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
 
-# Define the reward function, which rewards completions that are close to 20 characters
-def reward_len(completions, **kwargs):
-    return [-abs(20 - len(completion)) for completion in completions]
+# Dummy reward function for demonstration purposes
+def reward_num_unique_letters(completions, **kwargs):
+    """Reward function that rewards completions with more unique letters."""
+    completion_contents = [completion[0]["content"] for completion in completions]
+    return [float(len(set(content))) for content in completion_contents]
 
 training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO")
 trainer = GRPOTrainer(
     model="Qwen/Qwen2-0.5B-Instruct",
-    reward_funcs=reward_len,
+    reward_funcs=reward_num_unique_letters,
     args=training_args,
     train_dataset=dataset,
 )
```
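To see what the new reward function receives and returns, here is a toy call outside the trainer. It is illustrative only; the conversational completion structure shown is an assumption inferred from how the function indexes `completion[0]["content"]`.

```python
def reward_num_unique_letters(completions, **kwargs):
    """Reward function that rewards completions with more unique letters."""
    completion_contents = [completion[0]["content"] for completion in completions]
    return [float(len(set(content))) for content in completion_contents]

# Each completion is a list of chat messages, as produced for conversational datasets.
completions = [
    [{"role": "assistant", "content": "aaaa"}],
    [{"role": "assistant", "content": "abcd"}],
]
print(reward_num_unique_letters(completions))  # [1.0, 4.0]
```
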
```diff
@@ -68,7 +70,7 @@ At each training step, we sample a batch of prompts and generate a set of \\( G
 
 ### Computing the advantage
 
-For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:
+For each of the \\( G \\) sequences, we compute the reward using a reward model or reward function. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:
 
 $$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$
 
```
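The normalization in that formula is easy to check by hand; here is a small standalone NumPy sketch (not trainer code) for a group of G = 4 rewards:

```python
import numpy as np

# Rewards for G = 4 completions sampled from the same prompt
rewards = np.array([1.0, 0.5, 2.0, 0.5])

# Group-normalized advantage: (r_i - mean(r)) / std(r)
advantages = (rewards - rewards.mean()) / rewards.std()
print(advantages.round(3))  # [ 0.    -0.816  1.633 -0.816]
```
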

docs/source/paper_index.md

Lines changed: 18 additions & 0 deletions
````diff
@@ -103,6 +103,24 @@ training_args = DPOConfig(
 )
 ```
 
+## Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs
+
+**📜 Paper**: https://huggingface.co/papers/2402.14740
+
+RLOO is a variant of REINFORCE that reduces variance with a leave-one-out baseline: the advantage of each sample is its reward minus the average reward of the other samples generated for the same prompt, which yields more stable gradients than standard REINFORCE. To reproduce the paper's setting, use this configuration:
+
+```python
+from trl import RLOOConfig
+
+training_args = RLOOConfig(
+    per_device_train_batch_size=512,  # Section C (Training Details) of the paper
+    steps_per_generation=2,  # Section C (Training Details) of the paper
+    beta=0.03,  # Section C (Training Details) of the paper
+    num_generations=2,  # the paper's experiments use num_generations in {2, 4}
+    learning_rate=1e-6,  # Section C (Training Details) of the paper
+)
+```
+
 ## AlphaPO -- Reward shape matters for LLM alignment
 
 **📜 Paper**: https://huggingface.co/papers/2501.03884
````
**📜 Paper**: https://huggingface.co/papers/2501.03884

docs/source/rewards.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 # Reward Functions
 
-This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`].
+This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`].
 
 ## Format rewards
 
```
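A hedged sketch of how one of these reward functions could be wired into the refactored trainer. It assumes the `think_format_reward` helper is importable from `trl.rewards` and that `RLOOTrainer` now accepts a GRPO-style `reward_funcs` argument, which this refactor implies but which is not shown in this diff.

```python
from datasets import load_dataset
from trl import RLOOConfig, RLOOTrainer
from trl.rewards import think_format_reward  # assumed import path for the built-in format reward

dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = RLOOConfig(output_dir="Qwen2-0.5B-RLOO")
trainer = RLOOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=think_format_reward,  # rewards completions that follow the <think>...</think> format
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```
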
