Add Sequence-Level KD #2220

Merged · 8 commits · Oct 11, 2024
3 changes: 2 additions & 1 deletion docs/source/gkd_trainer.md
@@ -19,8 +19,9 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.

## Usage tips

-The GKD Trainer is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs two parameters to be set via the [`GKDConfig`] namely:
+The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`], namely:
* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD, where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and receives token-specific feedback on these sequences from the teacher. For values in between [0, 1], the choice is made at random for each batch, with probability `lmbda` of using on-policy student generations.
+* `seq_kd`: controls whether to perform Sequence-Level KD (which can be viewed as supervised fine-tuning on teacher-generated output). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.

The authors find that on-policy data (high `lmbda`) performs better, and that the optimal `beta` varies with the task and evaluation method.
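
For orientation, here is a minimal sketch of how these settings fit together. It is not taken from this PR: the checkpoints and dataset are placeholders, and the exact `GKDTrainer` signature should be checked against the TRL docs.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GKDConfig, GKDTrainer

# Placeholder checkpoints; any pair of causal LMs with a shared tokenizer works.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

training_args = GKDConfig(
    output_dir="gkd-output",
    lmbda=0.5,   # roughly half the batches train on on-policy student generations
    beta=0.5,    # symmetric generalized JSD
    seq_kd=True, # remaining batches train on teacher-generated sequences
)

trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,  # `tokenizer=` in older TRL versions
    # Hypothetical dataset in TRL's conversational ("messages") format.
    train_dataset=load_dataset("my-org/my-chat-dataset", split="train"),
)
trainer.train()
```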
4 changes: 4 additions & 0 deletions trl/trainer/gkd_config.py
@@ -41,6 +41,9 @@ class GKDConfig(SFTConfig):
            from a string.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether or not to disable dropouts in `model`.
+        seq_kd (`bool`, *optional*, defaults to `False`):
+            Whether to perform Sequence-Level KD (can be viewed as supervised fine-tuning on
+            teacher-generated output).
    """

    temperature: float = 0.9
@@ -50,6 +53,7 @@ class GKDConfig(SFTConfig):
    teacher_model_name_or_path: Optional[str] = None
    teacher_model_init_kwargs: Optional[Dict[str, Any]] = None
    disable_dropout: bool = True
+    seq_kd: bool = False

    def __post_init__(self):
        super().__post_init__()
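
As a quick illustration of the new flag (a hedged sketch, not part of the PR; `output_dir` is a placeholder), the two regimes described in the docs above map onto the config like this:

```python
from trl import GKDConfig

# Sequence-Level KD: the teacher generates the target sequences and the
# student is trained on them in a supervised fashion.
seq_kd_args = GKDConfig(output_dir="out", seq_kd=True, lmbda=0.0, beta=0.0)

# Default behaviour: seq_kd=False leaves the original GKD data flow intact,
# here fully on-policy (the student generates every batch).
on_policy_args = GKDConfig(output_dir="out", seq_kd=False, lmbda=1.0, beta=0.0)
```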
9 changes: 9 additions & 0 deletions trl/trainer/gkd_trainer.py
@@ -136,6 +136,7 @@ def __init__(
        self.lmbda = args.lmbda
        self.beta = args.beta
        self.temperature = args.temperature
+        self.seq_kd = args.seq_kd

        self.generation_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
@@ -280,6 +281,14 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
        With probability `self.lmbda`, it generates new responses using the student model,
        which are then used for training instead of the original inputs.
        """
+        if self.seq_kd:
+            with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
+                new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
+                    unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
+                )
+            inputs["input_ids"] = new_input_ids
+            inputs["attention_mask"] = new_attention_mask
+            inputs["labels"] = new_labels
        if random.random() <= self.lmbda:
            with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
                new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
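
For readers tracing the control flow above, here is a hedged sketch of what a helper like `generate_on_policy_outputs` is expected to return; the input key names are assumptions for illustration, not the actual TRL implementation:

```python
import torch


def generate_on_policy_outputs_sketch(model, inputs, generation_config, pad_token_id):
    # Generate completions from the prompt tokens. The keys "prompt_input_ids"
    # and "prompt_attention_mask" are assumed for this sketch; the real TRL
    # helper may name them differently.
    generated_ids = model.generate(
        input_ids=inputs["prompt_input_ids"],
        attention_mask=inputs["prompt_attention_mask"],
        generation_config=generation_config,
    )

    # Labels are the generated tokens with padding masked out of the loss
    # (-100 is the ignore index), and the attention mask covers everything
    # that is not padding.
    new_labels = generated_ids.clone()
    new_attention_mask = torch.ones_like(generated_ids)
    if pad_token_id is not None:
        new_labels[generated_ids == pad_token_id] = -100
        new_attention_mask[generated_ids == pad_token_id] = 0

    # Same return order as the call sites in training_step above.
    return generated_ids, new_attention_mask, new_labels
```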