diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py
index 5f7c459632..5a038d210d 100644
--- a/examples/scripts/reward_modeling.py
+++ b/examples/scripts/reward_modeling.py
@@ -15,9 +15,9 @@
 python examples/scripts/reward_modeling.py \
     --model_name_or_path=facebook/opt-350m \
     --output_dir="reward_modeling_anthropic_hh" \
-    --per_device_train_batch_size=64 \
+    --per_device_train_batch_size=16 \
     --num_train_epochs=1 \
-    --gradient_accumulation_steps=16 \
+    --gradient_accumulation_steps=2 \
     --gradient_checkpointing=True \
     --learning_rate=1.41e-5 \
     --report_to="wandb" \
@@ -25,6 +25,7 @@
     --optim="adamw_torch" \
     --logging_steps=10 \
     --evaluation_strategy="steps" \
+    --eval_steps=500 \
     --max_length=512 \
 """
 import warnings
@@ -42,8 +43,8 @@ if __name__ == "__main__":
     parser = HfArgumentParser((RewardConfig, ModelConfig))
-    reward_config, model_config = parser.parse_args_into_dataclasses()
-    reward_config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+    config, model_config = parser.parse_args_into_dataclasses()
+    config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
 
     ################
     # Model & Tokenizer
@@ -103,8 +104,7 @@ def preprocess_function(examples):
         num_proc=4,
     )
     raw_datasets = raw_datasets.filter(
-        lambda x: len(x["input_ids_chosen"]) <= reward_config.max_length
-        and len(x["input_ids_rejected"]) <= reward_config.max_length
+        lambda x: len(x["input_ids_chosen"]) <= config.max_length and len(x["input_ids_rejected"]) <= config.max_length
     )
     train_dataset = raw_datasets["train"]
     eval_dataset = raw_datasets["test"]
@@ -115,10 +115,14 @@ def preprocess_function(examples):
     trainer = RewardTrainer(
         model=model,
         tokenizer=tokenizer,
-        args=reward_config,
+        args=config,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         peft_config=get_peft_config(model_config),
     )
     trainer.train()
-    trainer.save_model(reward_config.output_dir)
+    trainer.save_model(config.output_dir)
+    trainer.push_to_hub()
+    metrics = trainer.evaluate()
+    trainer.log_metrics("eval", metrics)
+    print(metrics)
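Not part of the patch: a minimal sketch of how the renamed `config` object is produced from the flags in the docstring above, assuming `trl` and `transformers` are installed. The explicit `args` list is illustrative only.

# Sketch: building the (RewardConfig, ModelConfig) pair the example script uses.
from transformers import HfArgumentParser
from trl import ModelConfig, RewardConfig

parser = HfArgumentParser((RewardConfig, ModelConfig))
config, model_config = parser.parse_args_into_dataclasses(
    args=[
        "--model_name_or_path", "facebook/opt-350m",
        "--output_dir", "reward_modeling_anthropic_hh",
        "--per_device_train_batch_size", "16",
        "--gradient_accumulation_steps", "2",
    ]
)
config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
# 16 samples per device, accumulated over 2 steps -> 32 samples per optimizer update.
print(config.per_device_train_batch_size * config.gradient_accumulation_steps)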
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index bbee5e705e..3b98d68d90 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 import inspect
 import warnings
+from collections import defaultdict
 from dataclasses import FrozenInstanceError, replace
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+import pandas as pd
 import torch
 import torch.nn as nn
+from accelerate.utils import gather_object
 from datasets import Dataset
 from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
 from transformers.trainer_callback import TrainerCallback
@@ -26,7 +29,7 @@
 
 from ..import_utils import is_peft_available
 from .reward_config import RewardConfig
-from .utils import RewardDataCollatorWithPadding, compute_accuracy
+from .utils import RewardDataCollatorWithPadding, compute_accuracy, print_rich_table
 
 
 if is_peft_available():
@@ -279,3 +282,38 @@ def prediction_step(
         labels = self._prepare_inputs(labels)
 
         return loss, logits, labels
+
+    def evaluate(self, *args, **kwargs):
+        num_print_samples = kwargs.pop("num_print_samples", 4)
+        self.visualize_samples(num_print_samples)
+        return super().evaluate(*args, **kwargs)
+
+    def visualize_samples(self, num_print_samples: int):
+        """
+        Visualize the reward model logits prediction
+
+        Args:
+            num_print_samples (`int`, defaults to `4`):
+                The number of samples to print. Set to `-1` to print all samples.
+        """
+        eval_dataloader = self.get_eval_dataloader()
+        table = defaultdict(list)
+        for _, inputs in enumerate(eval_dataloader):
+            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
+            chosen_text = self.tokenizer.batch_decode(inputs["input_ids_chosen"], skip_special_tokens=True)
+            rejected_text = self.tokenizer.batch_decode(inputs["input_ids_rejected"], skip_special_tokens=True)
+            table["chosen_text"].extend(gather_object(chosen_text))
+            table["rejected_text"].extend(gather_object(rejected_text))
+            table["logits"].extend(
+                gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
+            )
+            if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
+                break
+        df = pd.DataFrame(table)
+        if self.accelerator.process_index == 0:
+            print_rich_table(df[:num_print_samples])
+            if "wandb" in self.args.report_to:
+                import wandb
+
+                if wandb.run is not None:
+                    wandb.log({"completions": wandb.Table(dataframe=df)})
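Not part of the patch: `visualize_samples` uses `accelerate.utils.gather_object` to collect the decoded samples from every process before rank 0 renders and logs the table. A minimal standalone sketch of that gather-then-log-on-rank-0 pattern; the sample strings are made up.

# Sketch of the gathering pattern used in visualize_samples.
from accelerate import Accelerator
from accelerate.utils import gather_object

accelerator = Accelerator()
# Each process contributes its own (hypothetical) decoded samples.
local_rows = [f"sample {accelerator.process_index}-{i}" for i in range(2)]
all_rows = gather_object(local_rows)  # concatenated across all processes

if accelerator.is_main_process:
    print(all_rows)  # only rank 0 prints/logs the gathered rows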
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 31e11f84a4..9cb4d26b95 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -18,15 +18,21 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
+import pandas as pd
 import torch
 from accelerate import PartialState
 from rich.console import Console, Group
 from rich.live import Live
 from rich.panel import Panel
 from rich.progress import Progress
+from rich.table import Table
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import IterableDataset
-from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase
+from transformers import (
+    BitsAndBytesConfig,
+    DataCollatorForLanguageModeling,
+    PreTrainedTokenizerBase,
+)
 from transformers.trainer import TrainerCallback
 from transformers.trainer_utils import has_length
@@ -815,3 +821,13 @@ def on_train_end(self, args, state, control, **kwargs):
             self.rich_console = None
             self.training_status = None
             self.current_step = None
+
+
+def print_rich_table(df: pd.DataFrame) -> Table:
+    console = Console()
+    table = Table(show_lines=True)
+    for column in df.columns:
+        table.add_column(column)
+    for _, row in df.iterrows():
+        table.add_row(*row.astype(str).tolist())
+    console.print(table)
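Not part of the patch: a quick usage sketch for the new `print_rich_table` helper, assuming this branch is installed. The DataFrame contents are made up to mirror what `visualize_samples` builds.

# Sketch: rendering a small chosen/rejected/logits table with the new helper.
import pandas as pd
from trl.trainer.utils import print_rich_table

df = pd.DataFrame(
    {
        "chosen_text": ["Human: hi\n\nAssistant: Hello, how can I help?"],
        "rejected_text": ["Human: hi\n\nAssistant: Go away."],
        "logits": [[1.2345, -0.6789]],
    }
)
print_rich_table(df)  # prints a rich table with one column per DataFrame column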