4 changes: 2 additions & 2 deletions docs/source/_toctree.yml
@@ -60,8 +60,6 @@
title: DPO
- local: online_dpo_trainer
title: Online DPO
- local: gkd_trainer
title: GKD
- local: grpo_trainer
title: GRPO
- local: kto_trainer
@@ -107,6 +105,8 @@
title: CPO
- local: gfpo
title: GFPO
- local: gkd_trainer
title: GKD
- local: gold_trainer
title: GOLD
- local: grpo_with_replay_buffer
2 changes: 1 addition & 1 deletion docs/source/dataset_formats.md
@@ -390,7 +390,7 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`experimental.bco.BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`experimental.cpo.CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
| [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) |
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
4 changes: 2 additions & 2 deletions docs/source/example_overview.md
@@ -43,8 +43,8 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`experimental.cpo.CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a model. |
| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
| [`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py) | This script shows how to use [`HfPairwiseJudge`] or [`experimental.judges.OpenAIPairwiseJudge`] to judge model generations. |
| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
| [`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py) | This script shows how to use [`experimental.judges.HfPairwiseJudge`] or [`experimental.judges.OpenAIPairwiseJudge`] to judge model generations. |
| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`experimental.gkd.GKDTrainer`] to fine-tune a model. |
| [`trl/scripts/grpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a model. |
| [`examples/scripts/grpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
| [`examples/scripts/gspo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune model for reasoning using the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. |
15 changes: 6 additions & 9 deletions docs/source/gkd_trainer.md
@@ -19,10 +19,10 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.

## Usage tips

The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`] namely:
The [`experimental.gkd.GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`experimental.gkd.GKDConfig`] namely:

* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD, where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and receives token-specific feedback on these sequences from the teacher. For values in between `[0, 1]`, each batch is randomly assigned to one of the two modes according to the `lmbda` value.
* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated output). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.

The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method.
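
As a rough illustration of how these knobs fit together, the sketch below builds a mostly on-policy, reverse-KL-leaning configuration. The parameter names come from the `GKDConfig` added in this PR; the teacher checkpoint and the specific values are placeholders, not recommendations.

```python
from trl.experimental.gkd import GKDConfig

training_args = GKDConfig(
    output_dir="gkd-model",
    teacher_model_name_or_path="teacher-model",  # placeholder teacher checkpoint
    lmbda=0.75,    # ~75% of batches are trained on student-generated (on-policy) rollouts
    beta=0.7,      # generalized JSD interpolation: 0.0 ~ forward KL, 1.0 ~ reverse KL
    seq_kd=False,  # set True to perform sequence-level KD on teacher-generated completions
    max_new_tokens=128,
)
```
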
@@ -34,11 +34,8 @@ The basic API is as follows:

```python
from datasets import Dataset
from trl import GKDConfig, GKDTrainer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.experimental.gkd import GKDConfig, GKDTrainer

NUM_DUMMY_SAMPLES = 100

@@ -92,11 +89,11 @@ The dataset should be formatted as a list of "messages" where each message is a

## GKDTrainer

[[autodoc]] GKDTrainer
[[autodoc]] experimental.gkd.GKDTrainer
- train
- save_model
- push_to_hub

## GKDConfig

[[autodoc]] GKDConfig
[[autodoc]] experimental.gkd.GKDConfig
4 changes: 2 additions & 2 deletions docs/source/gold_trainer.md
@@ -13,7 +13,7 @@ Key capabilities:

1. **Cross-tokenizer alignment** – GOLD incrementally decodes the student and teacher tokens, groups passages with the same visible text, and merges probabilities inside each group. This guarantees loss terms are computed over the full completion even when token boundaries differ.
2. **Hybrid ULD loss** – when `uld_use_hybrid_loss` is enabled, GOLD compares exact vocabulary matches directly and falls back to the original sorted-probability ULD loss for unmatched tokens. This improves stability for students whose vocabularies only partially overlap with the teacher.
3. **Seamless integration with GKD** – GOLD inherits the on-policy vs. off-policy scheduling from the [`GKDTrainer`](./gkd_trainer.md), so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run.
3. **Seamless integration with GKD** – GOLD inherits the on-policy vs. off-policy scheduling from the [`experimental.gkd.GKDTrainer`], so you can combine sequence-level KD, generalized JSD, and cross-tokenizer distillation in a single training run.

> [!NOTE]
> GOLD is currently part of the `trl.experimental` namespace. APIs may change without notice while the feature is iterated on.
@@ -27,7 +27,7 @@ messages). Important configuration flags on [`GOLDConfig`] include:
* `teacher_tokenizer_name_or_path` – required when `use_uld_loss=True`; GOLD uses the teacher tokenizer to align tokens.
* `uld_use_hybrid_loss`, `uld_hybrid_matched_weight`, `uld_hybrid_unmatched_weight` – enables and weights the hybrid
matched/unmatched loss.
* `beta`, `lmbda`, `seq_kd` – inherited from `GKDConfig`, controlling the generalized JSD interpolation and on-policy
* `beta`, `lmbda`, `seq_kd` – inherited from [`experimental.gkd.GKDConfig`], controlling the generalized JSD interpolation and on-policy
sampling ratio.
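
For orientation, here is a configuration-only sketch of these flags. Only the flag names are taken from the list above; the import path, the teacher checkpoint, and the weight values are assumptions for illustration.

```python
# Assumed module path for the experimental GOLD trainer; adjust if the actual package layout differs.
from trl.experimental.gold import GOLDConfig

training_args = GOLDConfig(
    output_dir="gold-model",
    teacher_model_name_or_path="teacher-model",      # placeholder teacher checkpoint
    use_uld_loss=True,                               # enable cross-tokenizer distillation
    teacher_tokenizer_name_or_path="teacher-model",  # required when use_uld_loss=True
    uld_use_hybrid_loss=True,                        # compare exact vocabulary matches directly
    uld_hybrid_matched_weight=1.0,                   # weight for matched tokens (placeholder value)
    uld_hybrid_unmatched_weight=1.0,                 # weight for unmatched tokens (placeholder value)
    lmbda=0.5,                                       # inherited from GKDConfig: on-policy sampling ratio
    beta=0.5,                                        # inherited from GKDConfig: generalized JSD interpolation
)
```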

A minimal end-to-end example:
2 changes: 1 addition & 1 deletion docs/source/index.md
@@ -48,7 +48,7 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL

### Knowledge distillation

- [`GKDTrainer`]
- [`experimental.gkd.GKDTrainer`] 🧪
- [`experimental.minillm.MiniLLMTrainer`] 🧪

</div>
2 changes: 1 addition & 1 deletion docs/source/liger_kernel_integration.md
@@ -67,7 +67,7 @@ training_args = KTOConfig(..., use_liger_kernel=True)
<hfoption id="GKD">

```python
from trl import GKDConfig
from trl.experimental.gkd import GKDConfig

training_args = GKDConfig(..., use_liger_kernel=True)
```
6 changes: 3 additions & 3 deletions docs/source/paper_index.md
@@ -646,12 +646,12 @@ On-Policy Distillation has been shown to outperform SFT, GRPO and can be used to

Additionally, on-policy distillation is more compute-efficient and less prone to overfitting when trained with limited data.

To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`GKDTrainer`] and [`GKDConfig`]:
To train a model with on-policy distillation using TRL, you can use the following configuration, with the [`experimental.gkd.GKDTrainer`] and [`experimental.gkd.GKDConfig`]:

```python
from trl import GKDConfig
from trl.experimental.gkd import GKDConfig

config = GKDConfig(
training_args = GKDConfig(
    lmbda=1.0, # student produces rollouts for all batches
    beta=1.0, # to ensure reverse-kl as the loss function
    teacher_model_name_or_path="teacher-model", # specify the teacher model
2 changes: 1 addition & 1 deletion docs/source/reducing_memory_usage.md
@@ -165,7 +165,7 @@ training_args = KTOConfig(..., use_liger_kernel=True)
<hfoption id="GKD">

```python
from trl import GKDConfig
from trl.experimental.gkd import GKDConfig

training_args = GKDConfig(..., use_liger_kernel=True)
```
3 changes: 1 addition & 2 deletions examples/scripts/gkd.py
@@ -58,8 +58,6 @@
from transformers import AutoTokenizer, GenerationConfig

from trl import (
    GKDConfig,
    GKDTrainer,
    LogCompletionsCallback,
    ModelConfig,
    ScriptArguments,
@@ -68,6 +66,7 @@
    get_peft_config,
    get_quantization_config,
)
from trl.experimental.gkd import GKDConfig, GKDTrainer


# Enable logging in a Hugging Face Space
@@ -20,9 +20,9 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from trl import GKDConfig, GKDTrainer
from trl.experimental.gkd import GKDConfig, GKDTrainer

from .testing_utils import TrlTestCase, require_liger_kernel
from ..testing_utils import TrlTestCase, require_liger_kernel


class TestGKDTrainerGenerateOnPolicy(TrlTestCase):
19 changes: 19 additions & 0 deletions trl/experimental/gkd/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .gkd_config import GKDConfig
from .gkd_trainer import GKDTrainer


__all__ = ["GKDConfig", "GKDTrainer"]
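
For downstream code, the practical effect of this new module is an import-path change. A minimal before/after sketch (whether the old top-level import keeps working behind a deprecation shim is not shown in this diff):

```python
# Before: GKD classes were imported from the top-level trl namespace.
# from trl import GKDConfig, GKDTrainer

# After this change: import from the experimental namespace instead.
from trl.experimental.gkd import GKDConfig, GKDTrainer
```
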
112 changes: 112 additions & 0 deletions trl/experimental/gkd/gkd_config.py
@@ -0,0 +1,112 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any

from transformers import TrainingArguments

from ...trainer.sft_config import SFTConfig


@dataclass
class GKDConfig(SFTConfig):
"""
Configuration class for [`experimental.gkd.GKDTrainer`].

This class includes only the parameters that are specific to GKD training. For a full list of training arguments,
please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation.

Args:
temperature (`float`, *optional*, defaults to `0.9`):
Temperature for sampling. The higher the temperature, the more random the completions.
lmbda (`float`, *optional*, defaults to `0.5`):
Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
student-generated outputs).
beta (`float`, *optional*, defaults to `0.5`):
Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
max_new_tokens (`int`, *optional*, defaults to `128`):
Maximum number of tokens to generate per completion.
teacher_model_name_or_path (`str`, *optional*):
Model name or path of the teacher model. If `None`, the teacher model will be the same as the model being
trained.
teacher_model_init_kwargs (`dict[str, Any]]`, *optional*):
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
from a string.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model.
seq_kd (`bool`, *optional*, defaults to `False`):
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT on
teacher-generated output).
"""

_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"]

temperature: float = field(
default=0.9,
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
)
lmbda: float = field(
default=0.5,
metadata={
"help": "Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy "
"student-generated outputs)."
},
)
beta: float = field(
default=0.5,
metadata={
"help": "Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence "
"loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL "
"Divergence."
},
)
max_new_tokens: int = field(
default=128,
metadata={"help": "Maximum number of tokens to generate per completion."},
)
teacher_model_name_or_path: str | None = field(
default=None,
metadata={
"help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the "
"model being trained."
},
)
teacher_model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
"teacher model from a string."
},
)
disable_dropout: bool = field(
default=True,
metadata={"help": "Whether to disable dropouts in `model`."},
)
seq_kd: bool = field(
default=False,
metadata={
"help": "Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised "
"FT on teacher-generated output)."
},
)

def __post_init__(self):
super().__post_init__()
# check lmbda and beta are in the range [0, 1]
if self.lmbda < 0.0 or self.lmbda > 1.0:
raise ValueError("lmbda must be in the range [0.0, 1.0].")
if self.beta < 0.0 or self.beta > 1.0:
raise ValueError("beta must be in the range [0.0, 1.0].")