Move chat template formatting inside trainer
lewtun committed Sep 23, 2024
1 parent fc9a2f9 commit 13b5ed0
Showing 2 changed files with 6 additions and 10 deletions.
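
For context, maybe_apply_chat_template is the TRL data helper this commit relocates from the example script into the trainer itself. A rough sketch of what it does to a preference example, assuming the conversational format in which "chosen" and "rejected" are message lists (simplified; the real implementation lives in trl/data_utils.py):

# Simplified sketch, not the actual trl.data_utils implementation:
# conversational examples (lists of {"role": ..., "content": ...} dicts)
# are rendered to plain strings via the tokenizer's chat template, while
# already-formatted string examples pass through unchanged.
def maybe_apply_chat_template(example, tokenizer):
    if isinstance(example.get("chosen"), list):
        example["chosen"] = tokenizer.apply_chat_template(example["chosen"], tokenize=False)
        example["rejected"] = tokenizer.apply_chat_template(example["rejected"], tokenize=False)
    return example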
13 changes: 3 additions & 10 deletions examples/scripts/reward_modeling.py
@@ -51,7 +51,6 @@
 import warnings
 
 import torch
-from accelerate import PartialState
 from datasets import load_dataset
 from tqdm import tqdm
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
@@ -63,7 +62,6 @@
     get_kbit_device_map,
     get_peft_config,
     get_quantization_config,
-    maybe_apply_chat_template,
     setup_chat_format,
 )
 from trl.commands.cli_utils import RewardScriptArguments
@@ -110,16 +108,11 @@
             " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT."
         )
 
-    #############################
-    # Load and preprocess dataset
-    #############################
+    ##############
+    # Load dataset
+    ##############
     dataset = load_dataset(args.dataset_name)
 
-    with PartialState().local_main_process_first():
-        dataset = dataset.map(
-            maybe_apply_chat_template, num_proc=training_args.dataset_num_proc, fn_kwargs={"tokenizer": tokenizer}
-        )
-
     ##########
     # Training
     ##########
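
With the preprocessing block removed, the script hands the freshly loaded dataset straight to the trainer. A sketch of how the script's training section presumably consumes it after this change (the trainer call itself is not shown in this diff, so the argument names here are assumed from the script's conventions):

# The raw dataset splits go in untouched; chat-template formatting now
# happens inside RewardTrainer.__init__ (see the second changed file).
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset[args.dataset_train_split],
    eval_dataset=dataset[args.dataset_test_split],
    peft_config=get_peft_config(model_config),
)
trainer.train()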
3 changes: 3 additions & 0 deletions trl/trainer/reward_trainer.py
@@ -30,6 +30,7 @@
 from transformers.trainer_utils import EvalPrediction
 from transformers.utils import is_peft_available
 
+from ..data_utils import maybe_apply_chat_template
 from .reward_config import RewardConfig
 from .utils import (
     RewardDataCollatorWithPadding,
@@ -229,6 +230,7 @@ def __init__(
         if "input_ids" not in train_dataset.column_names:
             with PartialState().local_main_process_first():
                 fn_kwargs = {"tokenizer": tokenizer}
+                train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
                 train_dataset = train_dataset.map(
                     _tokenize,
                     batched=True,
@@ -243,6 +245,7 @@
                     num_proc=args.dataset_num_proc,
                 )
                 if eval_dataset is not None:
+                    eval_dataset = eval_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
                     eval_dataset = eval_dataset.map(
                         _tokenize,
                         fn_kwargs=fn_kwargs,
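
The net effect of these two hunks: a conversational preference dataset can now be passed to RewardTrainer with no manual preprocessing. A minimal end-to-end sketch under that assumption (model and dataset names are placeholders; the trainer still takes tokenizer= at this commit):

from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

# Placeholder model: any sequence-classification checkpoint whose tokenizer
# defines a chat template should behave the same way.
model_id = "Qwen/Qwen2-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)

# Placeholder dataset in conversational form: "chosen" and "rejected" are
# lists of {"role", "content"} messages rather than pre-rendered strings.
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# No dataset.map(maybe_apply_chat_template, ...) step is needed any more:
# the trainer applies the chat template itself before tokenizing.
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    args=RewardConfig(output_dir="Qwen2-0.5B-Reward"),
    train_dataset=dataset,
)
trainer.train()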
