
Commit 1bcfc50

behroozazarkhalili, Invidia19, and qgallouedec authored
Move XPOTrainer to trl.experimental.xpo (#4485)
Co-authored-by: Invidia19 <54266187+Invidia19@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent 37942bc commit 1bcfc50

15 files changed: +667 -590 lines changed
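For downstream users, the practical effect of this commit is a changed import path; the class and config names stay the same. A minimal migration sketch based on the import changes in the diffs below (the `PairRMJudge` line is only there to show that other top-level imports are unaffected):

```python
# Old import path (removed by this commit):
# from trl import XPOConfig, XPOTrainer

# New import path, as used throughout the updated docs, example script, and tests:
from trl.experimental.xpo import XPOConfig, XPOTrainer

# Other symbols keep their top-level import, e.g. the judge used in the XPO docs:
from trl import PairRMJudge
```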


docs/source/_toctree.yml

Lines changed: 2 additions & 2 deletions
@@ -82,8 +82,6 @@
   title: RLOO
 - local: sft_trainer
   title: SFT
-- local: xpo_trainer
-  title: XPO
   title: Trainers
 - local: models
   title: Model Classes
@@ -119,6 +117,8 @@
   title: GSPO-token
 - local: papo_trainer
   title: PAPO
+- local: xpo_trainer
+  title: XPO
 - local: openenv
   title: OpenEnv Integration
 title: Experimental

docs/source/dataset_formats.md

Lines changed: 1 addition & 1 deletion
@@ -401,7 +401,7 @@ Choosing the right dataset type depends on the task you are working on and the s
 | [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
 | [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
 | [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
-| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
+| [`experimental.xpo.XPOTrainer`] | [Prompt-only](#prompt-only) |

 ## Using any dataset with TRL: preprocessing and conversion

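For reference, a prompt-only record of the kind this table points the XPO trainer at looks roughly like the following; the example strings are illustrative and not taken from this diff:

```python
# Standard (plain-text) prompt-only example
standard_example = {"prompt": "The sky is"}

# Conversational prompt-only example; the trainer applies the chat template itself
conversational_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
```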

docs/source/example_overview.md

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl
 | [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models, so users may see unexpected behaviour in other model architectures. |
 | [`examples/scripts/sft_vlm_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model on vision to text tasks. |
 | [`examples/scripts/sft_vlm_smol_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a SmolVLM model. |
-| [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`XPOTrainer`] to fine-tune a model. |
+| [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`experimental.xpo.XPOTrainer`] to fine-tune a model. |

 ## Distributed Training (for scripts)


docs/source/index.md

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL
 - [`RLOOTrainer`] ⚡️
 - [`OnlineDPOTrainer`] ⚡️
 - [`NashMDTrainer`] ⚡️
-- [`XPOTrainer`] ⚡️
+- [`experimental.xpo.XPOTrainer`] 🧪 ⚡️
 - [`PPOTrainer`]

 ### Reward modeling

docs/source/vllm_integration.md

Lines changed: 4 additions & 4 deletions
@@ -11,7 +11,7 @@ This document will guide you through the process of using vLLM with TRL for fast
 > - [`GRPOTrainer`]
 > - [`OnlineDPOTrainer`]
 > - [`NashMDTrainer`]
-> - [`XPOTrainer`]
+> - [`experimental.xpo.XPOTrainer`]
 > - [`RLOOTrainer`]

 ## 🚀 How can I use vLLM with TRL to speed up training?
@@ -135,7 +135,7 @@ trainer.train()

 ```python
 from datasets import load_dataset
-from trl import XPOTrainer, XPOConfig
+from trl.experimental.xpo import XPOTrainer, XPOConfig

 dataset = load_dataset("trl-lib/tldr", split="train")

@@ -392,7 +392,7 @@ training_args = NashMDConfig(
 <hfoption id="XPO">

 ```python
-from trl import XPOConfig
+from trl.experimental.xpo import XPOConfig

 training_args = XPOConfig(
     ...,
@@ -467,7 +467,7 @@ training_args = NashMDConfig(
 <hfoption id="XPO">

 ```python
-from trl import XPOConfig
+from trl.experimental.xpo import XPOConfig

 training_args = XPOConfig(
     ...,
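The two `XPOConfig(...,` blocks above are truncated in this diff. As a rough sketch of what such a vLLM-enabled config can look like under the new import path; the `use_vllm` and `vllm_mode` arguments are assumptions taken from the surrounding vLLM integration guide, not from this commit, so verify them against the `XPOConfig` signature of your TRL version:

```python
# Sketch only: use_vllm / vllm_mode mirror the surrounding vLLM guide and are not
# part of this diff; check your TRL version's XPOConfig signature before relying on them.
from trl.experimental.xpo import XPOConfig

training_args = XPOConfig(
    output_dir="Qwen2-0.5B-XPO",
    use_vllm=True,       # generate completions with vLLM instead of transformers
    vllm_mode="server",  # or "colocate", matching the two <hfoption> blocks above
)
```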

docs/source/xpo_trainer.md

Lines changed: 11 additions & 7 deletions
@@ -12,6 +12,9 @@ The abstract from the paper is the following:

 This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Lewis Tunstall](https://huggingface.co/lewtun).

+> [!NOTE]
+> XPO is currently experimental. The API may change without notice while the feature is iterated on.
+
 ## Quick start

 This example demonstrates how to train a model using the XPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
@@ -27,7 +30,8 @@ Below is the script to train the model:
 ```python
 # train_xpo.py
 from datasets import load_dataset
-from trl import PairRMJudge, XPOConfig, XPOTrainer
+from trl import PairRMJudge
+from trl.experimental.xpo import XPOConfig, XPOTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer

 model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
@@ -62,7 +66,7 @@ The best programming language depends on individual preferences and familiarity

 ## Expected dataset type

-XPO requires a [prompt-only dataset](dataset_formats#prompt-only). The [`XPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
+XPO requires a [prompt-only dataset](dataset_formats#prompt-only). The [`experimental.xpo.XPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

 ## Usage tips

@@ -89,7 +93,7 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht

 ### Encourage EOS token generation

-When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`XPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`XPOConfig`]:
+When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`experimental.xpo.XPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`experimental.xpo.XPOConfig`]:

 ```python
 training_args = XPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
@@ -145,16 +149,16 @@ While training and evaluating we record the following reward metrics:
 * `logps/rejected`: The mean log probabilities of the rejected completions.
 * `val/model_contain_eos_token`: The amount of times the model's output contains the eos token.
 * `val/ref_contain_eos_token`: The amount of times the reference's output contains the eos token.
-* `alpha`: The weight of the XPO loss term. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`].
-* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`].
+* `alpha`: The weight of the XPO loss term. Typically fixed, but can be made dynamic by passing a list to [`experimental.xpo.XPOConfig`].
+* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`experimental.xpo.XPOConfig`].

 ## XPOTrainer

-[[autodoc]] XPOTrainer
+[[autodoc]] experimental.xpo.XPOTrainer
     - train
     - save_model
     - push_to_hub

 ## XPOConfig

-[[autodoc]] XPOConfig
+[[autodoc]] experimental.xpo.XPOConfig
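Putting the doc changes above together, a complete quick-start script under the new import path would look roughly like this. The imports and the model line come from the hunk above; the judge setup, dataset name, and trainer arguments are reconstructed from the upstream XPO docs rather than from this diff, so treat them as a sketch:

```python
# train_xpo.py -- sketch; everything past the model line is reconstructed from the
# upstream XPO docs, not from this diff.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import PairRMJudge
from trl.experimental.xpo import XPOConfig, XPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = XPOConfig(output_dir="Qwen2-0.5B-XPO")
trainer = XPOTrainer(
    model=model,
    judge=judge,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```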

examples/scripts/xpo.py

Lines changed: 1 addition & 2 deletions
@@ -52,11 +52,10 @@
     PairRMJudge,
     ScriptArguments,
     TrlParser,
-    XPOConfig,
-    XPOTrainer,
     get_kbit_device_map,
     get_quantization_config,
 )
+from trl.experimental.xpo import XPOConfig, XPOTrainer


 # Enable logging in a Hugging Face Space

tests/experimental/test_trainers_args.py

Lines changed: 25 additions & 1 deletion
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import pytest
 from datasets import load_dataset
-from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

 from trl.experimental.bco import BCOConfig, BCOTrainer
+from trl.experimental.xpo import XPOConfig, XPOTrainer

 from ..testing_utils import TrlTestCase, require_sklearn

@@ -68,3 +70,25 @@ def test_bco(self):
         assert trainer.args.prompt_sample_size == 512
         assert trainer.args.min_density_ratio == 0.2
         assert trainer.args.max_density_ratio == 20.0
+
+    @pytest.mark.parametrize("alpha_list", [False, True])
+    def test_xpo(self, alpha_list):
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        ref_model = AutoModelForCausalLM.from_pretrained(model_id)
+        reward_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+        training_args = XPOConfig(
+            self.tmp_dir,
+            alpha=0.5 if not alpha_list else [0.5, 0.6],
+        )
+        trainer = XPOTrainer(
+            args=training_args,
+            processing_class=tokenizer,
+            model=model,
+            ref_model=ref_model,
+            reward_funcs=reward_model,
+            train_dataset=dataset,
+        )
+        assert trainer.args.alpha == (0.5 if not alpha_list else [0.5, 0.6])
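The `alpha_list` parametrization above exercises the behaviour noted in the XPO docs, where `alpha` (and `beta`) can be either a single float or a list acting as a schedule. A tiny sketch of the two call styles; the `output_dir` value here is just a placeholder:

```python
from trl.experimental.xpo import XPOConfig

# A single float keeps alpha fixed for the whole run
fixed = XPOConfig(output_dir="xpo-out", alpha=0.5)

# A list is accepted as well (described upstream as a dynamic schedule), which is
# exactly what the parametrized test asserts round-trips through trainer.args.alpha
scheduled = XPOConfig(output_dir="xpo-out", alpha=[0.5, 0.6])
```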

tests/test_xpo_trainer.py renamed to tests/experimental/test_xpo_trainer.py

Lines changed: 3 additions & 2 deletions
@@ -17,15 +17,16 @@
 from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
 from transformers.utils import is_peft_available

-from trl import XPOConfig, XPOTrainer
+from trl.experimental.xpo import XPOConfig, XPOTrainer

-from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft
+from ..testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft


 if is_peft_available():
     from peft import LoraConfig, get_peft_model


+@pytest.mark.low_priority
 class TestXPOTrainer(TrlTestCase):
     def setup_method(self):
         self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"

tests/test_trainers_args.py

Lines changed: 0 additions & 24 deletions
@@ -34,8 +34,6 @@
     RewardTrainer,
     SFTConfig,
     SFTTrainer,
-    XPOConfig,
-    XPOTrainer,
 )

 from .testing_utils import TrlTestCase
@@ -320,25 +318,3 @@ def test_sft(self):
         assert "append_concat_token" in trainer.args.dataset_kwargs
         assert trainer.args.dataset_kwargs["append_concat_token"]
         assert trainer.args.eval_packing
-
-    @pytest.mark.parametrize("alpha_list", [False, True])
-    def test_xpo(self, alpha_list):
-        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        model = AutoModelForCausalLM.from_pretrained(model_id)
-        ref_model = AutoModelForCausalLM.from_pretrained(model_id)
-        reward_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)
-        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
-        training_args = XPOConfig(
-            self.tmp_dir,
-            alpha=0.5 if not alpha_list else [0.5, 0.6],
-        )
-        trainer = XPOTrainer(
-            args=training_args,
-            processing_class=tokenizer,
-            model=model,
-            ref_model=ref_model,
-            reward_funcs=reward_model,
-            train_dataset=dataset,
-        )
-        assert trainer.args.alpha == (0.5 if not alpha_list else [0.5, 0.6])
