Skip to content

Commit

Permalink
[KTOTrainer] add BCO (reward shift and underlying distribution matc…
Browse files Browse the repository at this point in the history
…hing) (huggingface#1599)

* add `Loss Functions` section in the doc.

* add bce loss with reward shift in KTOTrainer

* add underlying distribution matching

* update example to use underlying distribution matching

* add config description

* fix 'referenced before assignment' error

* add 'bco' and 'udm' test cases

* run pre-commit

* add `scikit-learn` dependency

* raise error is sklearn is not available

* call TrainingArguments().__post_init__() for proper init
  • Loading branch information
seanexp authored Apr 30, 2024
1 parent d88ec14 commit d1aa0b6
Show file tree
Hide file tree
Showing 7 changed files with 564 additions and 23 deletions.
7 changes: 7 additions & 0 deletions docs/source/kto_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ After this one can then call:
kto_trainer.train()
```

## Loss Functions

Given the binary signal data indicating whether a completion is desirable or undesirable for a prompt, we can optimize an implicit reward function that aligns with the key principles of Kahneman-Tversky's prospect theory, such as reference dependence, loss aversion, and diminishing sensitivity.

The [BCO](https://arxiv.org/abs/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
The `KTOTrainer` can be switched to this loss via the `loss_type="bco"` argument.

## KTOTrainer

[[autodoc]] KTOTrainer
Expand Down
223 changes: 223 additions & 0 deletions examples/scripts/bco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
"""
Run the BCO training script with the commands below. In general, the optimal configuration for BCO will be similar to that of KTO.
# Full training:
python examples/scripts/bco.py \
--model_name_or_path=nnheui/stablelm-2-1_6b-sft-full \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 32 \
--num_train_epochs 1 \
--learning_rate 1e-6 \
--gradient_checkpointing \
--gradient_accumulation_steps 1 \
--logging_steps 0.01 \
--eval_steps 0.2 \
--save_strategy no \
--output_dir=bco-aligned-model \
--logging_first_step \
--max_length 2048 \
--max_prompt_length 1536 \
--max_completion_length 1024 \
--no_remove_unused_columns \
--warmup_ratio 0.1 \
--bf16 \
--loss_type bco \
--report_to wandb
# QLoRA:
python examples/scripts/bco.py \
--model_name_or_path=nnheui/stablelm-2-1_6b-sft-full \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 32 \
--num_train_epochs 1 \
--learning_rate 1e-6 \
--gradient_checkpointing \
--gradient_accumulation_steps 1 \
--logging_steps 0.01 \
--eval_steps 0.2 \
--save_strategy no \
--output_dir=bco-aligned-model-lora \
--logging_first_step \
--warmup_ratio 0.1 \
--report_to wandb \
--max_length 2048 \
--max_prompt_length 1536 \
--max_completion_length 1024 \
--no_remove_unused_columns \
--warmup_ratio 0.1 \
--bf16 \
--loss_type bco \
--use_peft \
--load_in_4bit \
--lora_target_modules=all-linear \
--lora_r=16 \
--lora_alpha=16
"""

import logging
from dataclasses import dataclass
from functools import partial
from typing import Literal

import torch
import torch.nn.functional as F
from accelerate import Accelerator, PartialState
from datasets import Dataset, load_dataset
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, PreTrainedModel

from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format


# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the KTO training script.
"""

llm_name: Literal["gpt-3.5-turbo", "llama-2-7b-chat", "llama-2-70b-chat"] = "gpt-3.5-turbo"


def build_helpfulness_dataset(llm_name: str) -> Dataset:
"""
Filter `llm_name` completions and binarize given their helpfulness score.
If helpfulness score is 5, it is desirable. Otherwise, it is undesirable.
"""

def get_model_rating(example, metric: str, llm_name: str):
try:
model_index = example["models"].index(llm_name)
return {metric: int(example["completions"][model_index]["annotations"][metric]["Rating"])}
except ValueError as e:
logging.warning(e)
return -1

def get_model_response(example, llm_name: str):
try:
model_index = example["models"].index(llm_name)
return {"response": example["completions"][model_index]["response"]}
except ValueError as e:
logging.warning(e)
return -1

dataset = load_dataset("openbmb/UltraFeedback")["train"]

ds = dataset.filter(lambda example: llm_name in example["models"], batched=False, num_proc=8)
ds = ds.filter(lambda example: len(example["models"]) == len(example["completions"]), batched=False, num_proc=8)

METRIC = "helpfulness"

ds = ds.map(
get_model_rating,
batched=False,
num_proc=8,
fn_kwargs={"metric": METRIC, "llm_name": llm_name},
)

ds = ds.map(
get_model_response,
batched=False,
num_proc=8,
fn_kwargs={"llm_name": llm_name},
)

ds = ds.select_columns(["source", "instruction", "response", "helpfulness"])

ds = ds.rename_columns({"instruction": "prompt", "response": "completion"})
ds = ds.map(lambda example: {"label": example["helpfulness"] >= 5}, batched=False, num_proc=8)

ds = ds.map(
lambda example: {"prompt": [{"role": "user", "content": example["prompt"]}]},
batched=False,
num_proc=8,
)
dataset = ds.train_test_split(test_size=0.05, seed=42)

return dataset


def embed_prompt(input_ids: torch.LongTensor, attention_mask: torch.LongTensor, model: PreTrainedModel):
"""
Borrowed from https://huggingface.co/nomic-ai/nomic-embed-text-v1.5#transformers
"""

def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

with torch.no_grad():
model_output = model(input_ids=input_ids, attention_mask=attention_mask)
embeddings = mean_pooling(model_output, attention_mask)

matryoshka_dim = 512
# normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
embeddings = embeddings[:, :matryoshka_dim]

return embeddings


if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
script_args, kto_args, model_args = parser.parse_args_into_dataclasses()

kto_args.gradient_checkpointing_kwargs = {"use_reentrant": True}

# Load a pretrained model
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

# If we are aligning a base model, we use ChatML as the default template
if tokenizer.chat_template is None:
model, tokenizer = setup_chat_format(model, tokenizer)

# Load the dataset
dataset = build_helpfulness_dataset(script_args.llm_name)

# Apply chat template
def format_dataset(example):
example["prompt"] = tokenizer.apply_chat_template(
example["prompt"], tokenize=False, add_generation_prompt=True
)
return example

with PartialState().local_main_process_first():
formatted_dataset = dataset.map(format_dataset, batched=False, num_proc=8)

accelerator = Accelerator()
embedding_model = AutoModel.from_pretrained(
"nomic-ai/nomic-embed-text-v1.5",
trust_remote_code=True,
safe_serialization=True,
torch_dtype=torch.bfloat16,
device_map="auto",
)
embedding_model = accelerator.prepare_model(embedding_model)
embedding_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
embedding_func = partial(
embed_prompt,
model=embedding_model,
)

# Initialize the KTO trainer
kto_trainer = KTOTrainer(
model,
model_ref,
args=kto_args,
train_dataset=formatted_dataset["train"],
eval_dataset=formatted_dataset["test"],
tokenizer=tokenizer,
peft_config=get_peft_config(model_args),
embedding_func=embedding_func,
embedding_tokenizer=embedding_tokenizer,
)

# Train and push the model to the Hub
kto_trainer.train()
kto_trainer.save_model(kto_args.output_dir)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"tyro>=0.5.11",
]
EXTRAS = {
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "pytest-cov", "pytest-xdist"],
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "pytest-cov", "pytest-xdist", "scikit-learn"],
"peft": ["peft>=0.4.0"],
"diffusers": ["diffusers>=0.18.0"],
"deepspeed": ["deepspeed>=0.9.5"],
Expand Down
72 changes: 66 additions & 6 deletions tests/test_kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.
import tempfile
import unittest
from functools import partial

import torch
from accelerate import Accelerator
from datasets import Dataset
from parameterized import parameterized
from pytest import mark
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

from trl import KTOConfig, KTOTrainer
from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize
Expand All @@ -41,6 +43,11 @@ def setUpClass(cls):
cls.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)

# get embedding model
model_id = "facebook/bart-base"
cls.embedding_model = AutoModel.from_pretrained(model_id)
cls.embedding_tokenizer = AutoTokenizer.from_pretrained(model_id)

def _init_dummy_dataset(self):
# fmt: off
dummy_dataset_dict = {
Expand Down Expand Up @@ -77,15 +84,19 @@ def _init_dummy_dataset(self):

@parameterized.expand(
[
["gpt2", True, True],
["gpt2", True, False],
["gpt2", "kto", True, True],
["gpt2", "kto", True, False],
# ["t5", True],
["gpt2", False, True],
["gpt2", False, False],
["gpt2", "kto", False, True],
["gpt2", "kto", False, False],
# ["t5", False],
["gpt2", "bco", True, True],
["gpt2", "bco", True, False],
["gpt2", "bco", False, True],
["gpt2", "bco", False, False],
]
)
def test_kto_trainer(self, name, pre_compute, eval_dataset):
def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
Expand All @@ -97,6 +108,7 @@ def test_kto_trainer(self, name, pre_compute, eval_dataset):
evaluation_strategy="steps",
beta=0.1,
precompute_ref_log_probs=pre_compute,
loss_type=loss_type,
)

dummy_dataset = self._init_dummy_dataset()
Expand Down Expand Up @@ -250,6 +262,54 @@ def test_kto_trainer_without_providing_ref_model(self):
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

def test_kto_trainer_bco_udm(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = KTOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
beta=0.1,
loss_type="bco",
)

dummy_dataset = self._init_dummy_dataset()

def embed_prompt(input_ids, attention_mask, model):
outputs = model(input_ids=input_ids, attention_mask=attention_mask)

return outputs.last_hidden_state.mean(dim=1)

embedding_model = Accelerator().prepare_model(self.embedding_model)
embedding_func = partial(embed_prompt, model=embedding_model)

trainer = KTOTrainer(
model=self.model,
ref_model=None,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
embedding_func=embedding_func,
embedding_tokenizer=self.embedding_tokenizer,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

@require_peft
@mark.peft_test
def test_kto_trainer_without_providing_ref_model_with_lora(self):
Expand Down
4 changes: 4 additions & 0 deletions trl/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ def is_wandb_available() -> bool:
return find_spec("wandb") is not None


def is_sklearn_available() -> bool:
return find_spec("sklearn") is not None


def is_xpu_available() -> bool:
if is_accelerate_greater_20_0():
import accelerate
Expand Down
Loading

0 comments on commit d1aa0b6

Please sign in to comment.