2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -115,6 +115,8 @@
    title: GSPO-token
  - local: judges
    title: Judges
  - local: minillm
    title: MiniLLM
  - local: papo_trainer
    title: PAPO
  - local: xpo_trainer
67 changes: 67 additions & 0 deletions docs/source/minillm.md
@@ -0,0 +1,67 @@
# MiniLLM Trainer

[![All_models-MiniLLM-blue](https://img.shields.io/badge/All_models-MiniLLM-blue)](https://huggingface.co/models?other=minillm,trl)

## Overview

TRL supports the MiniLLM Trainer for distilling large language models into smaller ones with a reverse-KLD objective, which yields more precise, higher-quality student generations, as described in the paper [Knowledge Distillation of Large Language Models](https://huggingface.co/papers/2306.08543) by [Yuxian Gu](https://huggingface.co/t1101675), [Li Dong](https://huggingface.co/unilm), [Furu Wei](https://huggingface.co/thegenerality), and Minlie Huang.
The abstract from the paper is the following:

> Knowledge Distillation (KD) is a promising technique for reducing the high computational demand of large language models (LLMs). However, previous KD methods are primarily applied to white-box classification models or training small models to imitate black-box model APIs like ChatGPT. How to effectively distill the knowledge from white-box generative LLMs is still under-explored, which becomes more and more important with the prosperity of LLMs. In this work, we propose MiniLLM that distills smaller language models from generative larger language models. We first replace the forward Kullback-Leibler divergence (KLD) objective in the standard KD approaches with reverse KLD, which is more suitable for KD on generative language models, to prevent the student model from overestimating the low-probability regions of the teacher distribution. Then, we derive an effective optimization approach to learn this objective. Extensive experiments in the instruction-following setting show that the MiniLLM models generate more precise responses with the higher overall quality, lower exposure bias, better calibration, and higher long-text generation performance. Our method is also scalable for different model families with 120M to 13B parameters. We will release our code and model checkpoints at https://aka.ms/MiniLLM.

This post-training method was contributed by [Yuxian Gu](https://huggingface.co/t1101675).

It is a generalized version of [Thinking Machines Lab's On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/), with the option to add distribution-level single-step distillation signals (like GKD when `beta=1`) and long-context reverse KLD signals. The MiniLLM loss is

$$
\begin{align}
L_{\text{MiniLLM}}&=\alpha_1\mathbb{E}_{x\sim \pi_{\theta}}\sum_{t'=t}^{|x|}\frac{\gamma^{t'-t}}{\sum_{t'}\gamma^{t'-t}}\left[\log \frac{\pi_{\theta}(x_{t'+1}|x_{1..t'})}{\pi_{\text{teacher}}(x_{t'+1}|x_{1..t'})}\right] \\
&+ \alpha_2\mathbb{E}_{x\sim \pi_{\theta}} \text{KL}\left[\pi_\theta(\cdot|x_{1..t})||\pi_{\text{teacher}}(\cdot | x_{1..t})\right].
\end{align}
$$

When \\( \alpha_1=1 \\), \\( \alpha_2=0 \\), and \\( \gamma=0 \\), which corresponds to

```python
from trl.experimental.minillm import MiniLLMConfig

training_args = MiniLLMConfig(
    rkl_advantage=True,
    single_step_decomposition=False,
    gamma=0.0,
)
```

\\( L_{\text{MiniLLM}} \\) becomes the on-policy KD implemented in [Tinker](https://github.com/thinking-machines-lab/tinker-cookbook/blob/5d08be6d130596b7bedd02197861c41fa81ea436/tinker_cookbook/distillation/train_on_policy.py#L88):

$$
L_{\text{tinker}}=\mathbb{E}_{x\sim \pi_{\theta}}\left[\log \frac{\pi_{\theta}(x_{t'+1}|x_{1..t'})}{\pi_{\text{teacher}}(x_{t'+1}|x_{1..t'})}\right].
$$

When \\( \alpha_1=0 \\) and \\( \alpha_2=1 \\), which corresponds to

```python
from trl.experimental.minillm import MiniLLMConfig

training_args = MiniLLMConfig(
    rkl_advantage=False,
    single_step_decomposition=True,
)
```

\\( L_{\text{MiniLLM}} \\) becomes the reverse KLD version of the GKD loss as in [GKD Trainer](./gkd.md):

$$
L_{\text{GKD-RKL}}=\mathbb{E}_{x\sim \pi_{\theta}} \text{KL}\left[\pi_\theta(\cdot|x_{1..t})||\pi_{\text{teacher}}(\cdot | x_{1..t})\right].
$$
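
The two signals can also be combined. The following is a minimal end-to-end sketch rather than a tuned recipe: the models, dataset, `gamma=0.99` discount, and `output_dir` are illustrative choices.

```python
from datasets import load_dataset
from trl.experimental.minillm import MiniLLMConfig, MiniLLMTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Enable both terms of the MiniLLM loss: the discounted reverse-KL advantage
# (the alpha_1 term) and the distribution-level single-step KL (the alpha_2 term).
training_args = MiniLLMConfig(
    output_dir="Qwen3-0.6B-MiniLLM",  # illustrative output directory
    rkl_advantage=True,
    single_step_decomposition=True,
    gamma=0.99,  # illustrative discount; gamma=0.0 keeps only the current-step term
)

trainer = MiniLLMTrainer(
    model="Qwen/Qwen3-0.6B",
    teacher_model="Qwen/Qwen3-1.7B",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```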

## MiniLLMTrainer

[[autodoc]] experimental.minillm.MiniLLMTrainer
- train
- save_model
- push_to_hub

## MiniLLMConfig

[[autodoc]] experimental.minillm.MiniLLMConfig
26 changes: 26 additions & 0 deletions docs/source/paper_index.md
@@ -671,3 +671,29 @@ config = GOLDConfig(

)
```

### Knowledge Distillation of Large Language Models

**📜 Paper**: https://huggingface.co/papers/2306.08543

MiniLLM is the first on-policy knowledge distillation method: it minimizes the sequence-level reverse KLD between the student and the teacher model and optimizes this objective with reinforcement learning.

It is a generalized version of [Thinking Machines Lab's On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/), with the option to add distribution-level single-step distillation signals (like GKD when `beta=1`) and long-context reverse KLD signals.

You can use the [`experimental.MiniLLMTrainer`] and [`experimental.MiniLLMConfig`] to perform MiniLLM distillation as follows:

```python
from datasets import load_dataset
from trl.experimental.minillm import MiniLLMTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

trainer = MiniLLMTrainer(
model="Qwen/Qwen3-0.6B",
teacher_model="Qwen/Qwen3-1.7B",
train_dataset=dataset,
)
trainer.train()
```
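
Additional behavior can be controlled through [`experimental.MiniLLMConfig`], passed to the trainer via `args`. The snippet below is a sketch of the main knobs; the specific values are illustrative, not recommended defaults.

```python
from trl.experimental.minillm import MiniLLMConfig

training_args = MiniLLMConfig(
    output_dir="Qwen3-0.6B-MiniLLM",  # illustrative output directory
    rkl_advantage=True,               # long-context reverse-KLD advantage signal
    single_step_decomposition=True,   # distribution-level single-step distillation signal
    gamma=0.0,                        # discount factor for future rewards
    kd_temperature=1.0,               # higher values give softer distillation distributions
    length_normalization=True,        # normalize rewards by length
    num_generations=4,                # completions sampled per prompt (inherited from GRPOConfig)
    max_completion_length=256,        # cap on generated tokens (inherited from GRPOConfig)
)
```

Pass this object to the trainer through the `args` argument, as in `MiniLLMTrainer(..., args=training_args, ...)`.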

For more details, see the [MiniLLM Trainer documentation](minillm).
57 changes: 57 additions & 0 deletions tests/experimental/test_minillm_trainer.py
@@ -0,0 +1,57 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from datasets import load_dataset

from trl.experimental.minillm import MiniLLMConfig, MiniLLMTrainer

from ..testing_utils import TrlTestCase


@pytest.mark.low_priority
class TestMiniLLMTrainer(TrlTestCase):
    def test_train(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        # Initialize the trainer
        training_args = MiniLLMConfig(
            output_dir=self.tmp_dir,
            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
            num_generations=3,  # reduce the number of generations to reduce memory usage
            max_completion_length=32,  # reduce the completion length to reduce memory usage
            report_to="none",
        )
        trainer = MiniLLMTrainer(
            model="trl-internal-testing/small-Qwen3ForCausalLM",
            teacher_model="trl-internal-testing/tiny-Qwen3ForCausalLM",
            args=training_args,
            train_dataset=dataset,
        )

        # Save the initial parameters to compare them later
        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        # Train the model
        trainer.train()

        # Check that the training loss is not None
        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Check the params have changed
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
19 changes: 19 additions & 0 deletions trl/experimental/minillm/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .minillm_config import MiniLLMConfig
from .minillm_trainer import MiniLLMTrainer


__all__ = ["MiniLLMConfig", "MiniLLMTrainer"]
150 changes: 150 additions & 0 deletions trl/experimental/minillm/minillm_config.py
@@ -0,0 +1,150 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from dataclasses import dataclass, field
from typing import Any

from transformers import TrainingArguments

from ...trainer.grpo_config import GRPOConfig


@dataclass
class MiniLLMConfig(GRPOConfig):
    """
    Configuration class for [`MiniLLMTrainer`].

    This class includes only the parameters that are specific to MiniLLM training. For a full list of training
    arguments, please refer to the [`~transformers.TrainingArguments`] and [`GRPOConfig`] documentation.

    Args:
        teacher_model_init_kwargs (`dict[str, Any]`, *optional*):
            Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
            from a string.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model.
        rkl_advantage (`bool`, *optional*, defaults to `True`):
            Whether to add the reverse KL advantage to the reward advantage.
        single_step_decomposition (`bool`, *optional*, defaults to `True`):
            Whether to use single-step decomposition for the KL divergence computation.
        kd_temperature (`float`, *optional*, defaults to `1.0`):
            Temperature for knowledge distillation. Higher temperatures produce softer probability distributions.
        gamma (`float`, *optional*, defaults to `0.0`):
            Discount factor for future rewards in reinforcement learning.
        length_normalization (`bool`, *optional*, defaults to `True`):
            Whether to apply length normalization to the rewards.
    """

    teacher_model_init_kwargs: dict[str, Any] | None = field(
        default=None,
        metadata={
            "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
            "teacher model from a string."
        },
    )
    disable_dropout: bool = field(
        default=True,
        metadata={"help": "Whether to disable dropouts in `model`."},
    )
    rkl_advantage: bool = field(
        default=True,
        metadata={"help": "Whether to add the reverse KL advantage to the reward advantage."},
    )
    single_step_decomposition: bool = field(
        default=True,
        metadata={"help": "Whether to use single-step decomposition for the KL divergence computation."},
    )
    kd_temperature: float = field(
        default=1.0,
        metadata={
            "help": "Temperature for knowledge distillation. Higher temperatures produce softer probability "
            "distributions over classes."
        },
    )
    gamma: float = field(
        default=0.0,
        metadata={"help": "Discount factor for future rewards in reinforcement learning."},
    )
    length_normalization: bool = field(
        default=True,
        metadata={"help": "Whether to apply length normalization to the rewards."},
    )

    def __post_init__(self):
        # We do not use the post_init of GRPOConfig because:
        # 1. num_generations can be < 2 in MiniLLMConfig. scale_rewards must be set to "none" to avoid nan.
        self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

        TrainingArguments.__post_init__(self)

        self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards)
        if self.num_generations == 1:
            self.scale_rewards = "none"

        num_processes = self.world_size
        # The current default effective batch size
        if self.generation_batch_size is None and self.steps_per_generation is None:
            self.steps_per_generation = self.gradient_accumulation_steps
            self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
        elif self.generation_batch_size is not None and self.steps_per_generation is None:
            # Just ensure the value is divisible by the global batch size
            if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
                raise ValueError(
                    f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch "
                    f"size ({self.per_device_train_batch_size * num_processes})."
                )
            self.steps_per_generation = self.generation_batch_size // (
                self.per_device_train_batch_size * num_processes
            )
        elif self.generation_batch_size is None and self.steps_per_generation is not None:
            self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
        else:
            raise ValueError(
                "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time"
            )

        if self.do_eval and self.eval_strategy != "no":
            # Just ensure the value is divisible by the global batch size
            if (self.per_device_eval_batch_size * num_processes) % self.num_generations != 0:
                raise ValueError(
                    f"The global eval batch size ({self.per_device_eval_batch_size} * {num_processes}) must be "
                    f"divisible by num_generations ({self.num_generations})."
                )

        # The generation batch must contain full prompt groups (no partials), so it must be divisible by
        # num_generations.
        if self.generation_batch_size % self.num_generations != 0:
            raise ValueError(
                f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations "
                f"({self.num_generations})."
            )

        if self.use_liger_loss is not None:
            warnings.warn(
                "The `use_liger_loss` argument is deprecated and will be removed in version 0.28.0. Please use "
                "`use_liger_kernel` instead.",
                FutureWarning,
                stacklevel=2,
            )
            self.use_liger_kernel = self.use_liger_loss

        if self.delta is not None and self.use_liger_kernel:
            raise ValueError("Liger kernel does not support two-sided GRPO loss yet.")