visualize rm prediction (#1636)
* visualize rm prediction

* quick update

* quick check

* quick fix

* update eval steps
vwxyzjn authored May 10, 2024
1 parent 3b4c249 commit 8799952
Showing 3 changed files with 69 additions and 10 deletions.
20 changes: 12 additions & 8 deletions examples/scripts/reward_modeling.py
@@ -15,16 +15,17 @@
python examples/scripts/reward_modeling.py \
--model_name_or_path=facebook/opt-350m \
--output_dir="reward_modeling_anthropic_hh" \
- --per_device_train_batch_size=64 \
+ --per_device_train_batch_size=16 \
--num_train_epochs=1 \
- --gradient_accumulation_steps=16 \
+ --gradient_accumulation_steps=2 \
--gradient_checkpointing=True \
--learning_rate=1.41e-5 \
--report_to="wandb" \
--remove_unused_columns=False \
--optim="adamw_torch" \
--logging_steps=10 \
--evaluation_strategy="steps" \
+ --eval_steps=500 \
--max_length=512 \
"""
import warnings
@@ -42,8 +43,8 @@

if __name__ == "__main__":
parser = HfArgumentParser((RewardConfig, ModelConfig))
- reward_config, model_config = parser.parse_args_into_dataclasses()
- reward_config.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+ config, model_config = parser.parse_args_into_dataclasses()
+ config.gradient_checkpointing_kwargs = dict(use_reentrant=False)

################
# Model & Tokenizer
@@ -103,8 +104,7 @@ def preprocess_function(examples):
num_proc=4,
)
raw_datasets = raw_datasets.filter(
- lambda x: len(x["input_ids_chosen"]) <= reward_config.max_length
- and len(x["input_ids_rejected"]) <= reward_config.max_length
+ lambda x: len(x["input_ids_chosen"]) <= config.max_length and len(x["input_ids_rejected"]) <= config.max_length
)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
@@ -115,10 +115,14 @@ def preprocess_function(examples):
trainer = RewardTrainer(
model=model,
tokenizer=tokenizer,
- args=reward_config,
+ args=config,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=get_peft_config(model_config),
)
trainer.train()
- trainer.save_model(reward_config.output_dir)
+ trainer.save_model(config.output_dir)
+ trainer.push_to_hub()
+ metrics = trainer.evaluate()
+ trainer.log_metrics("eval", metrics)
+ print(metrics)
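For context, here is a minimal, hypothetical sketch (an editor addition, not part of this commit) of reloading the reward model that trainer.save_model(config.output_dir) writes out and scoring one chosen/rejected pair; the model directory and the texts are illustrative assumptions.

# Illustrative only: reload the saved reward model and score a pair of completions;
# a well-trained model assigns the higher logit to the chosen-style response.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_dir = "reward_modeling_anthropic_hh"  # matches --output_dir in the docstring above
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir, num_labels=1)

texts = [
    "Human: How do I bake bread? Assistant: Mix flour, water, yeast and salt, then knead and proof.",
    "Human: How do I bake bread? Assistant: Figure it out yourself.",
]
inputs = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
with torch.no_grad():
    rewards = model(**inputs).logits.squeeze(-1)
print(rewards)  # two scalars; the first text should score higher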
41 changes: 40 additions & 1 deletion trl/trainer/reward_trainer.py
@@ -13,11 +13,14 @@
# limitations under the License.
import inspect
import warnings
+ from collections import defaultdict
from dataclasses import FrozenInstanceError, replace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

+ import pandas as pd
import torch
import torch.nn as nn
+ from accelerate.utils import gather_object
from datasets import Dataset
from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
from transformers.trainer_callback import TrainerCallback
@@ -26,7 +29,7 @@

from ..import_utils import is_peft_available
from .reward_config import RewardConfig
- from .utils import RewardDataCollatorWithPadding, compute_accuracy
+ from .utils import RewardDataCollatorWithPadding, compute_accuracy, print_rich_table


if is_peft_available():
@@ -279,3 +282,39 @@ def prediction_step(
        labels = self._prepare_inputs(labels)

        return loss, logits, labels

+     def evaluate(self, *args, **kwargs):
+         num_print_samples = kwargs.pop("num_print_samples", 4)
+         self.visualize_samples(num_print_samples)
+         return super().evaluate(*args, **kwargs)
+
+     def visualize_samples(self, num_print_samples: int):
+         """
+         Visualize the reward model logits prediction
+
+         Args:
+             num_print_samples (`int`, defaults to `4`):
+                 The number of samples to print. Set to `-1` to print all samples.
+         """
+         eval_dataloader = self.get_eval_dataloader()
+         table = defaultdict(list)
+         for _, inputs in enumerate(eval_dataloader):
+             _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
+             chosen_text = self.tokenizer.batch_decode(inputs["input_ids_chosen"], skip_special_tokens=True)
+             rejected_text = self.tokenizer.batch_decode(inputs["input_ids_rejected"], skip_special_tokens=True)
+             table["chosen_text"].extend(gather_object(chosen_text))
+             table["rejected_text"].extend(gather_object(rejected_text))
+             table["logits"].extend(
+                 gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
+             )
+             if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
+                 break
+         df = pd.DataFrame(table)
+         if self.accelerator.process_index == 0:
+             print_rich_table(df[:num_print_samples])
+             if "wandb" in self.args.report_to:
+                 import wandb
+
+                 if wandb.run is not None:
+                     wandb.log({"completions": wandb.Table(dataframe=df)})
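A short usage sketch of the new hook (an editor addition, not part of the diff), assuming trainer is a RewardTrainer built as in examples/scripts/reward_modeling.py; num_print_samples is popped from the kwargs before they are forwarded to Trainer.evaluate, so it can be passed directly.

# Hypothetical call site: print 8 chosen/rejected pairs with their per-pair scores
# (the "logits" column) before the usual eval metrics are computed. Default is 4.
metrics = trainer.evaluate(num_print_samples=8)
print(metrics)  # includes eval_loss and eval_accuracy from the default compute_accuracy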
18 changes: 17 additions & 1 deletion trl/trainer/utils.py
@@ -18,15 +18,21 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
+ import pandas as pd
import torch
from accelerate import PartialState
from rich.console import Console, Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import Progress
+ from rich.table import Table
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import IterableDataset
- from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase
+ from transformers import (
+     BitsAndBytesConfig,
+     DataCollatorForLanguageModeling,
+     PreTrainedTokenizerBase,
+ )
from transformers.trainer import TrainerCallback
from transformers.trainer_utils import has_length

@@ -815,3 +821,13 @@ def on_train_end(self, args, state, control, **kwargs):
        self.rich_console = None
        self.training_status = None
        self.current_step = None
+
+
+ def print_rich_table(df: pd.DataFrame) -> Table:
+     console = Console()
+     table = Table(show_lines=True)
+     for column in df.columns:
+         table.add_column(column)
+     for _, row in df.iterrows():
+         table.add_row(*row.astype(str).tolist())
+     console.print(table)
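Finally, a tiny standalone sketch (editor addition) of the new helper on a toy DataFrame shaped like the one visualize_samples gathers; trl is assumed to be installed and the values are made up.

# The "logits" column holds one list of rounded scores per chosen/rejected pair,
# mirroring what RewardTrainer.visualize_samples collects.
import pandas as pd

from trl.trainer.utils import print_rich_table

df = pd.DataFrame(
    {
        "chosen_text": ["Human: hi Assistant: Hello, how can I help?"],
        "rejected_text": ["Human: hi Assistant: Go away."],
        "logits": [[0.8716, 0.1284]],
    }
)
print_rich_table(df)  # renders a rich table with one row per pair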
