docs/source/example_overview.md (1 addition, 1 deletion)
@@ -66,7 +66,7 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl
 |[`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py)| This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models, so users may see unexpected behaviour with other model architectures. |
 |[`examples/scripts/sft_vlm_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py)| This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model on vision-to-text tasks. |
 |[`examples/scripts/sft_vlm_smol_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py)| This script shows how to use the [`SFTTrainer`] to fine-tune a SmolVLM model. |
-|[`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py)| This script shows how to use the [`XPOTrainer`] to fine-tune a model. |
+|[`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py)| This script shows how to use the [`experimental.xpo.XPOTrainer`] to fine-tune a model. |
docs/source/xpo_trainer.md (11 additions, 7 deletions)
@@ -12,6 +12,9 @@ The abstract from the paper is the following:
 This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Lewis Tunstall](https://huggingface.co/lewtun).
 
+> [!NOTE]
+> XPO is currently experimental. The API may change without notice while the feature is iterated on.
+
 ## Quick start
 
 This example demonstrates how to train a model using the XPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
@@ -27,7 +30,8 @@ Below is the script to train the model:
 ```python
 # train_xpo.py
 from datasets import load_dataset
-from trl import PairRMJudge, XPOConfig, XPOTrainer
+from trl import PairRMJudge
+from trl.experimental.xpo import XPOConfig, XPOTrainer
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
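
Note: the hunk above only shows the imports and the model line. For readers following along, here is a minimal sketch of how the whole quick-start script reads after this change, assuming the judge, dataset, config, and trainer setup below the imports are unchanged from the existing docs (including the `trl-lib/ultrafeedback-prompt` dataset they load):

```python
# train_xpo.py (sketch; everything below the imports is assumed unchanged from the docs)
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import PairRMJudge
from trl.experimental.xpo import XPOConfig, XPOTrainer  # new experimental import path

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

judge = PairRMJudge()  # pairwise judge used as the preference signal
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = XPOConfig(output_dir="Qwen2-0.5B-XPO", logging_steps=10)
trainer = XPOTrainer(
    model=model,
    judge=judge,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```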
@@ -62,7 +66,7 @@ The best programming language depends on individual preferences and familiarity
 ## Expected dataset type
 
-XPO requires a [prompt-only dataset](dataset_formats#prompt-only). The [`XPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
+XPO requires a [prompt-only dataset](dataset_formats#prompt-only). The [`experimental.xpo.XPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
 
 ## Usage tips
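
As a quick illustration of the two supported formats, a prompt-only record looks roughly like this (examples follow the linked dataset_formats docs, not this diff):

```python
# Standard format: a bare prompt string.
standard_example = {"prompt": "The sky is"}

# Conversational format: a list of chat messages; the trainer applies
# the chat template automatically before generation.
conversational_example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}]
}
```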
@@ -89,7 +93,7 @@ Instead of a judge, you can choose to use a reward model -- see [Reward Bench](ht
 ### Encourage EOS token generation
 
-When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`XPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`XPOConfig`]:
+When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`experimental.xpo.XPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`experimental.xpo.XPOConfig`]:
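
For reference, a sketch of the config this sentence leads into, using the new import path; both arguments are named in the text above, and the values are illustrative only:

```python
from trl.experimental.xpo import XPOConfig

# Cap the generation length and penalize completions that never emit EOS.
training_args = XPOConfig(
    output_dir="Qwen2-0.5B-XPO",
    max_new_tokens=64,        # generation budget per completion
    missing_eos_penalty=1.0,  # applied when no EOS appears before the cap
)
```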
@@ -145,16 +149,16 @@ While training and evaluating we record the following reward metrics:
 * `logps/rejected`: The mean log probabilities of the rejected completions.
 * `val/model_contain_eos_token`: The number of times the model's output contains the EOS token.
 * `val/ref_contain_eos_token`: The number of times the reference's output contains the EOS token.
-* `alpha`: The weight of the XPO loss term. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`].
-* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`].
+* `alpha`: The weight of the XPO loss term. Typically fixed, but can be made dynamic by passing a list to [`experimental.xpo.XPOConfig`].
+* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`experimental.xpo.XPOConfig`].
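
To make the "passing a list" remark concrete, a sketch (assuming, as in the current trainer, that list values are consumed per epoch, with the last value reused once the list is exhausted; the numbers are illustrative only):

```python
from trl.experimental.xpo import XPOConfig

# Dynamic schedules: one value per epoch, last value reused afterwards.
training_args = XPOConfig(
    output_dir="Qwen2-0.5B-XPO",
    alpha=[1e-5, 1e-6, 1e-7],
    beta=[0.1, 0.05, 0.03],
)
```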