[RewardTrainer] Tokenize inputs within trainer (huggingface#2102)
* Pretokenize in reward modelling

* Fix README example

* Apply suggestions from code review

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Move chat template formatting inside trainer

* Refactor tests

* Fix README

* Disable wandb

* Update readme

* add comment `remove_unused_columns`

* Update trl/trainer/reward_config.py

* doc

* implicit*

* explicit

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
3 people authored Sep 24, 2024
1 parent 2cad48d commit cc23b51
Showing 6 changed files with 187 additions and 392 deletions.
48 changes: 2 additions & 46 deletions README.md
@@ -103,15 +103,12 @@ from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

# configure trainer
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
args=training_args,
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)

# train
trainer.train()
```

@@ -121,7 +118,6 @@ Here is a basic example on how to use the `RewardTrainer`:

```python
from trl import RewardConfig, RewardTrainer
from trl.extras.dataset_formatting import conversations_formatting_function
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

@@ -131,48 +127,15 @@ model = AutoModelForSequenceClassification.from_pretrained(
)
model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")

def preprocess_function(examples):
new_examples = {
"input_ids_chosen": [],
"attention_mask_chosen": [],
"input_ids_rejected": [],
"attention_mask_rejected": [],
}
for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
tokenized_chosen = tokenizer(chosen)
tokenized_rejected = tokenizer(rejected)
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
new_examples["attention_mask_rejected"].append(
tokenized_rejected["attention_mask"]
)

return new_examples

chosen_fn = conversations_formatting_function(tokenizer, "chosen")
rejected_fn = conversations_formatting_function(tokenizer, "rejected")
dataset = dataset.map(lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)})
dataset = dataset.map(
preprocess_function,
batched=True,
)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(
per_device_train_batch_size=2,
remove_unused_columns=False,
output_dir="Qwen2.5-0.5B-Reward",
)
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
args=training_args,
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
)

# train
trainer.train()
```

@@ -210,7 +173,6 @@ trainer = RLOOTrainer(
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
)
# train
trainer.train()
```

@@ -219,30 +181,24 @@ trainer.train()
`DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example on how to use the `DPOTrainer`:

```python
# imports
from trl import DPOConfig, DPOTrainer, maybe_extract_prompt, maybe_apply_chat_template
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# load preference dataset - needs to be in a specific format
dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")
dataset = dataset.map(maybe_extract_prompt)
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

# load trainer
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
args=training_args,
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
)

# train
trainer.train()
```

14 changes: 3 additions & 11 deletions docs/source/reward_trainer.mdx
@@ -6,18 +6,10 @@ Check out a complete flexible example at [`examples/scripts/reward_modeling.py`]

## Expected dataset format

The [`RewardTrainer`] expects a very specific format for the dataset since the model will be trained on pairs of examples to predict which of the two is preferred. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
The [`RewardTrainer`] requires an [*implicit prompt* preference dataset](dataset_formats#preference). This means that the dataset should contain only the columns `chosen` and `rejected` (and not `prompt`).
The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational-dataset-format) and [standard](dataset_formats#standard-dataset-format) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/rlhf-antropic-example.png", width="50%">
</div>

Therefore the final dataset object should contain at least four entries if you use the default [`RewardDataCollatorWithPadding`] data collator. The entries should be named:

- `input_ids_chosen`
- `attention_mask_chosen`
- `input_ids_rejected`
- `attention_mask_rejected`
You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`.
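
For illustration, here is a minimal sketch of what a single conversational, implicit-prompt example could look like (the field structure follows the preference format linked above; the actual messages are made up):

```python
# Hypothetical implicit-prompt preference example in the conversational format.
# The prompt turns are repeated inside both "chosen" and "rejected";
# there is no separate "prompt" column.
example = {
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is green."},
    ],
}
```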

## Using the `RewardTrainer`

61 changes: 7 additions & 54 deletions examples/scripts/reward_modeling.py
@@ -19,8 +19,6 @@
--output_dir Qwen2-0.5B-Reward \
--per_device_train_batch_size 8 \
--num_train_epochs 1 \
--gradient_accumulation_steps 1 \
--remove_unused_columns False \
--gradient_checkpointing True \
--learning_rate 1.0e-5 \
--logging_steps 25 \
@@ -32,17 +30,15 @@
python examples/scripts/reward_modeling.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--output_dir Qwen2-0.5B-Reward \
--output_dir Qwen2-0.5B-Reward-LoRA \
--per_device_train_batch_size 8 \
--num_train_epochs 1 \
--gradient_accumulation_steps 1 \
--remove_unused_columns False \
--gradient_checkpointing True \
--learning_rate 1.0e-5 \
--learning_rate 1.0e-4 \
--logging_steps 25 \
--eval_strategy steps \
--eval_steps 50 \
--max_length 2048 /
--max_length 2048 \
--use_peft \
--lora_r 32 \
--lora_alpha 16
@@ -51,9 +47,7 @@
import warnings

import torch
from accelerate import PartialState
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser

from trl import (
@@ -66,10 +60,6 @@
setup_chat_format,
)
from trl.commands.cli_utils import RewardScriptArguments
from trl.extras.dataset_formatting import conversations_formatting_function


tqdm.pandas()


if __name__ == "__main__":
@@ -90,6 +80,7 @@
revision=model_config.model_revision,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
use_cache=False if training_args.gradient_checkpointing else True,
)
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
@@ -110,49 +101,11 @@
" Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT."
)

#############################
# Load and preprocess dataset
#############################
##############
# Load dataset
##############
dataset = load_dataset(args.dataset_name)

def preprocess_function(examples):
new_examples = {
"input_ids_chosen": [],
"attention_mask_chosen": [],
"input_ids_rejected": [],
"attention_mask_rejected": [],
}
for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
tokenized_chosen = tokenizer(chosen)
tokenized_rejected = tokenizer(rejected)
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])

return new_examples

with PartialState().local_main_process_first():
# Wrap inputs with chat template.
# This assumes the chosen/rejected columns are in the OpenAI messages format.
chosen_fn = conversations_formatting_function(tokenizer, "chosen")
rejected_fn = conversations_formatting_function(tokenizer, "rejected")
dataset = dataset.map(
lambda x: {"chosen": chosen_fn(x), "rejected": rejected_fn(x)}, num_proc=training_args.dataset_num_proc
)
# Tokenize inputs
dataset = dataset.map(
preprocess_function,
batched=True,
num_proc=training_args.dataset_num_proc,
)
# Filter out examples that are too long
dataset = dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= training_args.max_length
and len(x["input_ids_rejected"]) <= training_args.max_length,
num_proc=training_args.dataset_num_proc,
)

##########
# Training
##########
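
For context, the preprocessing removed above is roughly what the trainer now handles internally on a raw `chosen`/`rejected` dataset (per the commit title, chat-template formatting and tokenization move inside `RewardTrainer`). A simplified, illustrative sketch, not the trainer's actual internals, assuming a loaded `tokenizer`, a `dataset` with conversational `chosen`/`rejected` columns, and a `max_length` limit:

```python
def tokenize_pair(example, tokenizer):
    # Render each conversation to text with the chat template, then tokenize it.
    chosen_text = tokenizer.apply_chat_template(example["chosen"], tokenize=False)
    rejected_text = tokenizer.apply_chat_template(example["rejected"], tokenize=False)
    chosen = tokenizer(chosen_text)
    rejected = tokenizer(rejected_text)
    return {
        "input_ids_chosen": chosen["input_ids"],
        "attention_mask_chosen": chosen["attention_mask"],
        "input_ids_rejected": rejected["input_ids"],
        "attention_mask_rejected": rejected["attention_mask"],
    }

dataset = dataset.map(tokenize_pair, fn_kwargs={"tokenizer": tokenizer})
# Drop pairs where either side exceeds the maximum sequence length.
dataset = dataset.filter(
    lambda x: len(x["input_ids_chosen"]) <= max_length
    and len(x["input_ids_rejected"]) <= max_length
)
```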