diff --git a/docs/source/kto_trainer.mdx b/docs/source/kto_trainer.mdx
index 179880503f..713cc6c622 100644
--- a/docs/source/kto_trainer.mdx
+++ b/docs/source/kto_trainer.mdx
@@ -84,6 +84,13 @@ After this one can then call:
 kto_trainer.train()
 ```
 
+## Loss Functions
+
+Given binary signal data indicating whether a completion is desirable or undesirable for a prompt, we can optimize an implicit reward function that aligns with the key principles of Kahneman-Tversky's prospect theory, such as reference dependence, loss aversion, and diminishing sensitivity.
+
+The [BCO](https://arxiv.org/abs/2404.04656) authors train a binary classifier whose logit serves as a reward, so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
+The `KTOTrainer` can be switched to this loss via the `loss_type="bco"` argument.
+
 ## KTOTrainer
 
 [[autodoc]] KTOTrainer
diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py
new file mode 100644
index 0000000000..9f96e3a914
--- /dev/null
+++ b/examples/scripts/bco.py
@@ -0,0 +1,223 @@
+"""
+Run the BCO training script with the commands below. In general, the optimal configuration for BCO will be similar to that of KTO.
+
+# Full training:
+python examples/scripts/bco.py \
+    --model_name_or_path=nnheui/stablelm-2-1_6b-sft-full \
+    --per_device_train_batch_size 16 \
+    --per_device_eval_batch_size 32 \
+    --num_train_epochs 1 \
+    --learning_rate 1e-6 \
+    --gradient_checkpointing \
+    --gradient_accumulation_steps 1 \
+    --logging_steps 0.01 \
+    --eval_steps 0.2 \
+    --save_strategy no \
+    --output_dir=bco-aligned-model \
+    --logging_first_step \
+    --max_length 2048 \
+    --max_prompt_length 1536 \
+    --max_completion_length 1024 \
+    --no_remove_unused_columns \
+    --warmup_ratio 0.1 \
+    --bf16 \
+    --loss_type bco \
+    --report_to wandb
+
+# QLoRA:
+python examples/scripts/bco.py \
+    --model_name_or_path=nnheui/stablelm-2-1_6b-sft-full \
+    --per_device_train_batch_size 16 \
+    --per_device_eval_batch_size 32 \
+    --num_train_epochs 1 \
+    --learning_rate 1e-6 \
+    --gradient_checkpointing \
+    --gradient_accumulation_steps 1 \
+    --logging_steps 0.01 \
+    --eval_steps 0.2 \
+    --save_strategy no \
+    --output_dir=bco-aligned-model-lora \
+    --logging_first_step \
+    --warmup_ratio 0.1 \
+    --report_to wandb \
+    --max_length 2048 \
+    --max_prompt_length 1536 \
+    --max_completion_length 1024 \
+    --no_remove_unused_columns \
+    --bf16 \
+    --loss_type bco \
+    --use_peft \
+    --load_in_4bit \
+    --lora_target_modules=all-linear \
+    --lora_r=16 \
+    --lora_alpha=16
+"""
+
+import logging
+from dataclasses import dataclass
+from functools import partial
+from typing import Literal
+
+import torch
+import torch.nn.functional as F
+from accelerate import Accelerator, PartialState
+from datasets import Dataset, load_dataset
+from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, PreTrainedModel
+
+from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format
+
+
+# Define and parse arguments.
+@dataclass
+class ScriptArguments:
+    """
+    The arguments for the BCO training script.
+    """
+
+    llm_name: Literal["gpt-3.5-turbo", "llama-2-7b-chat", "llama-2-70b-chat"] = "gpt-3.5-turbo"
+
+
+def build_helpfulness_dataset(llm_name: str) -> Dataset:
+    """
+    Filter `llm_name` completions and binarize given their helpfulness score.
+    If the helpfulness score is 5, the completion is desirable; otherwise, it is undesirable.
+ """ + + def get_model_rating(example, metric: str, llm_name: str): + try: + model_index = example["models"].index(llm_name) + return {metric: int(example["completions"][model_index]["annotations"][metric]["Rating"])} + except ValueError as e: + logging.warning(e) + return -1 + + def get_model_response(example, llm_name: str): + try: + model_index = example["models"].index(llm_name) + return {"response": example["completions"][model_index]["response"]} + except ValueError as e: + logging.warning(e) + return -1 + + dataset = load_dataset("openbmb/UltraFeedback")["train"] + + ds = dataset.filter(lambda example: llm_name in example["models"], batched=False, num_proc=8) + ds = ds.filter(lambda example: len(example["models"]) == len(example["completions"]), batched=False, num_proc=8) + + METRIC = "helpfulness" + + ds = ds.map( + get_model_rating, + batched=False, + num_proc=8, + fn_kwargs={"metric": METRIC, "llm_name": llm_name}, + ) + + ds = ds.map( + get_model_response, + batched=False, + num_proc=8, + fn_kwargs={"llm_name": llm_name}, + ) + + ds = ds.select_columns(["source", "instruction", "response", "helpfulness"]) + + ds = ds.rename_columns({"instruction": "prompt", "response": "completion"}) + ds = ds.map(lambda example: {"label": example["helpfulness"] >= 5}, batched=False, num_proc=8) + + ds = ds.map( + lambda example: {"prompt": [{"role": "user", "content": example["prompt"]}]}, + batched=False, + num_proc=8, + ) + dataset = ds.train_test_split(test_size=0.05, seed=42) + + return dataset + + +def embed_prompt(input_ids: torch.LongTensor, attention_mask: torch.LongTensor, model: PreTrainedModel): + """ + Borrowed from https://huggingface.co/nomic-ai/nomic-embed-text-v1.5#transformers + """ + + def mean_pooling(model_output, attention_mask): + token_embeddings = model_output[0] + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + + with torch.no_grad(): + model_output = model(input_ids=input_ids, attention_mask=attention_mask) + embeddings = mean_pooling(model_output, attention_mask) + + matryoshka_dim = 512 + # normalize embeddings + embeddings = F.normalize(embeddings, p=2, dim=1) + embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],)) + embeddings = embeddings[:, :matryoshka_dim] + + return embeddings + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig)) + script_args, kto_args, model_args = parser.parse_args_into_dataclasses() + + kto_args.gradient_checkpointing_kwargs = {"use_reentrant": True} + + # Load a pretrained model + model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path) + model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path) + + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # If we are aligning a base model, we use ChatML as the default template + if tokenizer.chat_template is None: + model, tokenizer = setup_chat_format(model, tokenizer) + + # Load the dataset + dataset = build_helpfulness_dataset(script_args.llm_name) + + # Apply chat template + def format_dataset(example): + example["prompt"] = tokenizer.apply_chat_template( + example["prompt"], tokenize=False, add_generation_prompt=True + ) + return example + + with PartialState().local_main_process_first(): + formatted_dataset = 
+
+    accelerator = Accelerator()
+    embedding_model = AutoModel.from_pretrained(
+        "nomic-ai/nomic-embed-text-v1.5",
+        trust_remote_code=True,
+        safe_serialization=True,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+    )
+    embedding_model = accelerator.prepare_model(embedding_model)
+    embedding_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+    embedding_func = partial(
+        embed_prompt,
+        model=embedding_model,
+    )
+
+    # Initialize the KTO trainer
+    kto_trainer = KTOTrainer(
+        model,
+        model_ref,
+        args=kto_args,
+        train_dataset=formatted_dataset["train"],
+        eval_dataset=formatted_dataset["test"],
+        tokenizer=tokenizer,
+        peft_config=get_peft_config(model_args),
+        embedding_func=embedding_func,
+        embedding_tokenizer=embedding_tokenizer,
+    )
+
+    # Train and push the model to the Hub
+    kto_trainer.train()
+    kto_trainer.save_model(kto_args.output_dir)
diff --git a/setup.py b/setup.py
index b80580061d..008e6912a6 100644
--- a/setup.py
+++ b/setup.py
@@ -69,7 +69,7 @@
     "tyro>=0.5.11",
 ]
 EXTRAS = {
-    "test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "pytest-cov", "pytest-xdist"],
+    "test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "pytest-cov", "pytest-xdist", "scikit-learn"],
     "peft": ["peft>=0.4.0"],
     "diffusers": ["diffusers>=0.18.0"],
     "deepspeed": ["deepspeed>=0.9.5"],
diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py
index d4f5d686f6..b622e7ea7b 100644
--- a/tests/test_kto_trainer.py
+++ b/tests/test_kto_trainer.py
@@ -13,12 +13,14 @@
 # limitations under the License.
 import tempfile
 import unittest
+from functools import partial
 
 import torch
+from accelerate import Accelerator
 from datasets import Dataset
 from parameterized import parameterized
 from pytest import mark
-from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
 
 from trl import KTOConfig, KTOTrainer
 from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize
@@ -41,6 +43,11 @@ def setUpClass(cls):
         cls.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
         cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
 
+        # get embedding model
+        model_id = "facebook/bart-base"
+        cls.embedding_model = AutoModel.from_pretrained(model_id)
+        cls.embedding_tokenizer = AutoTokenizer.from_pretrained(model_id)
+
     def _init_dummy_dataset(self):
         # fmt: off
         dummy_dataset_dict = {
@@ -77,15 +84,19 @@ def _init_dummy_dataset(self):
 
     @parameterized.expand(
         [
-            ["gpt2", True, True],
-            ["gpt2", True, False],
+            ["gpt2", "kto", True, True],
+            ["gpt2", "kto", True, False],
             # ["t5", True],
-            ["gpt2", False, True],
-            ["gpt2", False, False],
+            ["gpt2", "kto", False, True],
+            ["gpt2", "kto", False, False],
             # ["t5", False],
+            ["gpt2", "bco", True, True],
+            ["gpt2", "bco", True, False],
+            ["gpt2", "bco", False, True],
+            ["gpt2", "bco", False, False],
         ]
     )
-    def test_kto_trainer(self, name, pre_compute, eval_dataset):
+    def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = KTOConfig(
                 output_dir=tmp_dir,
@@ -97,6 +108,7 @@ def test_kto_trainer(self, name, pre_compute, eval_dataset):
                 evaluation_strategy="steps",
                 beta=0.1,
                 precompute_ref_log_probs=pre_compute,
+                loss_type=loss_type,
             )
 
             dummy_dataset = self._init_dummy_dataset()
@@ -250,6 +262,54 @@ def test_kto_trainer_without_providing_ref_model(self):
             if param.sum() != 0:
                 self.assertFalse(torch.equal(param, new_param))
 
+    def test_kto_trainer_bco_udm(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = KTOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=2,
+                max_steps=3,
+                remove_unused_columns=False,
+                gradient_accumulation_steps=4,
+                learning_rate=9e-1,
+                evaluation_strategy="steps",
+                beta=0.1,
+                loss_type="bco",
+            )
+
+            dummy_dataset = self._init_dummy_dataset()
+
+            def embed_prompt(input_ids, attention_mask, model):
+                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+
+                return outputs.last_hidden_state.mean(dim=1)
+
+            embedding_model = Accelerator().prepare_model(self.embedding_model)
+            embedding_func = partial(embed_prompt, model=embedding_model)
+
+            trainer = KTOTrainer(
+                model=self.model,
+                ref_model=None,
+                args=training_args,
+                tokenizer=self.tokenizer,
+                train_dataset=dummy_dataset,
+                eval_dataset=dummy_dataset,
+                embedding_func=embedding_func,
+                embedding_tokenizer=self.embedding_tokenizer,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                # check the params have changed - ignore 0 biases
+                if param.sum() != 0:
+                    self.assertFalse(torch.equal(param, new_param))
+
     @require_peft
     @mark.peft_test
     def test_kto_trainer_without_providing_ref_model_with_lora(self):
diff --git a/trl/import_utils.py b/trl/import_utils.py
index 78cd03ed7b..df44a38aa8 100644
--- a/trl/import_utils.py
+++ b/trl/import_utils.py
@@ -97,6 +97,10 @@ def is_wandb_available() -> bool:
     return find_spec("wandb") is not None
 
 
+def is_sklearn_available() -> bool:
+    return find_spec("sklearn") is not None
+
+
 def is_xpu_available() -> bool:
     if is_accelerate_greater_20_0():
         import accelerate
diff --git a/trl/trainer/kto_config.py b/trl/trainer/kto_config.py
index 52eb635b47..1743b93bb7 100644
--- a/trl/trainer/kto_config.py
+++ b/trl/trainer/kto_config.py
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Dict, Literal, Optional
 
 from transformers import TrainingArguments
 
+from ..import_utils import is_sklearn_available
+
 
 @dataclass
 class KTOConfig(TrainingArguments):
@@ -58,6 +60,14 @@ class KTOConfig(TrainingArguments):
             Dict of Optional kwargs to pass when instantiating the ref model from a string.
         dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
             Number of processes to use for processing the datasets.
+        loss_type: (`Literal["kto", "bco"]`, *optional*, defaults to `"kto"`):
+            The type of loss to use: either `"kto"`, the default KTO loss, or `"bco"`, the loss from the [BCO](https://arxiv.org/abs/2404.04656) paper.
+        prompt_sample_size: (`int`, defaults to 1024):
+            Number of prompts that are fed to the density ratio classifier.
+        min_density_ratio: (`float`, defaults to 0.5):
+            The minimum value of the density ratio. The estimated density ratio is clamped to this value.
+        max_density_ratio: (`float`, defaults to 10.0):
+            The maximum value of the density ratio. The estimated density ratio is clamped to this value.
""" max_length: Optional[int] = None @@ -82,3 +92,19 @@ class KTOConfig(TrainingArguments): model_init_kwargs: Optional[Dict] = None ref_model_init_kwargs: Optional[Dict] = None dataset_num_proc: Optional[int] = None + + loss_type: Literal["kto", "bco"] = "kto" + + # BCO config + prompt_sample_size: int = 1024 + min_density_ratio: float = 0.5 + max_density_ratio: float = 10.0 + + def __post_init__(self): + super().__post_init__() + + if self.loss_type == "bco" and not is_sklearn_available(): + raise ImportError( + "You need to install scikit-learn to use loss_type='bco' " + "You can install it with `pip install scikit-learn`." + ) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 44232b59fe..e66d4a24d0 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -28,7 +28,8 @@ import torch.nn.functional as F from accelerate import PartialState from accelerate.utils import is_deepspeed_available, tqdm -from datasets import Dataset, concatenate_datasets +from datasets import Dataset, concatenate_datasets, interleave_datasets +from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader, SequentialSampler from transformers import ( AutoModelForCausalLM, @@ -41,11 +42,12 @@ from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput, has_length -from ..import_utils import is_peft_available, is_wandb_available +from ..import_utils import is_peft_available, is_sklearn_available, is_wandb_available from ..models import PreTrainedModelWrapper, create_reference_model from .kto_config import KTOConfig from .utils import ( DPODataCollatorWithPadding, + RunningMoments, disable_dropout_in_model, pad_to_length, peft_module_casting_to_bf16, @@ -60,6 +62,9 @@ if is_wandb_available(): import wandb +if is_sklearn_available(): + from sklearn.linear_model import LogisticRegression + if is_deepspeed_available(): import deepspeed @@ -74,8 +79,12 @@ def _get_kl_dataset(batch: Dict[str, List[Any]]) -> Dict[str, List[Any]]: return batch -def _tokenize(batch: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer") -> Dict[str, List[Any]]: - """Tokenize a batch from a KTO specific dataset.""" +def _tokenize( + batch: Dict[str, List[Any]], + tokenizer: "PreTrainedTokenizer", + embedding_tokenizer: Optional["PreTrainedTokenizer"] = None, +) -> Dict[str, List[Any]]: + """Tokenize a batch from a KTO/BCO specific dataset.""" prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False) prompt_input_ids = prompt_tokenized["input_ids"] prompt_attention_mask = prompt_tokenized["attention_mask"] @@ -117,13 +126,25 @@ def _tokenize(batch: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer") -> answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx)] answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx)] - return dict( + output = dict( prompt_input_ids=prompt_input_ids, prompt_attention_mask=prompt_attention_mask, answer_input_ids=answer_input_ids, answer_attention_mask=answer_attention_mask, ) + if embedding_tokenizer is not None: + embedding_tokenized = embedding_tokenizer(batch["prompt"], truncation=True, add_special_tokens=False) + + output.update( + { + "embedding_input_ids": embedding_tokenized["input_ids"], + "embedding_attention_mask": embedding_tokenized["attention_mask"], + } + ) + + return output + def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, **kwargs) -> Dict: """Process tokens of a 
@@ -282,6 +303,8 @@ def __init__(
         compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
         model_adapter_name: Optional[str] = None,
         ref_adapter_name: Optional[str] = None,
+        embedding_func: Optional[Callable] = None,
+        embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
     ):
         if type(args) == TrainingArguments:
             raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
@@ -492,6 +515,11 @@ def make_inputs_require_grad(module, input, output):
         self.desirable_weight = args.desirable_weight
         self.undesirable_weight = args.undesirable_weight
 
+        self.loss_type = args.loss_type
+        # Underlying Distribution Matching arguments
+        self.embedding_func = embedding_func
+        self.embedding_tokenizer = embedding_tokenizer
+
         with PartialState().local_main_process_first():
             # Shuffle the datasets
             train_dataset = train_dataset.shuffle(seed=args.data_seed)
@@ -500,7 +528,7 @@ def make_inputs_require_grad(module, input, output):
             # Tokenize and prepare the training datasets
             train_dataset = train_dataset.map(
                 _tokenize,
-                fn_kwargs={"tokenizer": self.tokenizer},
+                fn_kwargs={"tokenizer": self.tokenizer, "embedding_tokenizer": self.embedding_tokenizer},
                 batched=True,
                 desc="Tokenizing train dataset",
             )
@@ -549,7 +577,7 @@ def make_inputs_require_grad(module, input, output):
                 # Tokenize
                 eval_dataset = eval_dataset.map(
                     _tokenize,
-                    fn_kwargs={"tokenizer": self.tokenizer},
+                    fn_kwargs={"tokenizer": self.tokenizer, "embedding_tokenizer": self.embedding_tokenizer},
                     batched=True,
                     desc="Tokenizing eval dataset",
                 )
@@ -613,6 +641,17 @@ def make_inputs_require_grad(module, input, output):
                 UserWarning,
             )
 
+        if self.loss_type == "bco":
+            desirable = desirable.shuffle(seed=args.data_seed)
+            undesirable = undesirable.shuffle(seed=args.data_seed)
+
+            # interleave the desirable and undesirable splits with equal probability of choosing chosen or rejected
+            interleaved_train_dataset = interleave_datasets(
+                [desirable, undesirable],
+                stopping_strategy="all_exhausted",
+            )
+            train_dataset = interleaved_train_dataset
+
         super().__init__(
             model=model,
             args=args,
@@ -654,6 +693,117 @@ def make_inputs_require_grad(module, input, output):
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
+        if self.loss_type == "bco":
+            self.running = RunningMoments(self.accelerator)
+
+            if self.embedding_func is None:
+                return
+
+            chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size)
+            rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size)
+
+            embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0)
+            labels = torch.cat(
+                (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0
+            )
+
+            self.clf = LogisticRegression(class_weight="balanced").fit(embeddings.cpu().numpy(), labels.cpu().numpy())
+
+    @property
+    def match_underlying_distribution(self):
+        return self.embedding_func is not None and self.embedding_tokenizer is not None
+
+    def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        Calculates the probability that the given prompt embedding comes from the desirable dataset.
+        The probability is computed on each process and then averaged (ensembled) across processes.
+ """ + dtype = prompt_embeddings.dtype + device = prompt_embeddings.device + sample_size = prompt_embeddings.shape[0] + + padded_prompt_embeddings = self.accelerator.pad_across_processes(prompt_embeddings) + nonzero = padded_prompt_embeddings.sum(dim=1) != 0 + prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings) + + prob = self.clf.predict_proba(prompt_embeddings.cpu().numpy())[:, 1] + prob = torch.as_tensor(prob, dtype=dtype, device=device) + prob = self.accelerator.reduce(prob, reduction="mean") + + rank = self.accelerator.process_index + prob = prob[sample_size * rank : sample_size * (rank + 1)] + prob = prob[nonzero] + + return prob + + def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor: + """ + Replaces tokenizer.pad_token_id to embedding_tokenizer.pad_token_id + and applies self.embedding_func + """ + input_ids = torch.where( + input_ids == self.tokenizer.pad_token_id, + self.embedding_tokenizer.pad_token_id, + input_ids, + ) + + with torch.no_grad(): + embeddings = self.embedding_func( + input_ids=input_ids, + attention_mask=attention_mask, + ) + + return embeddings + + def _get_prompt_embeddings( + self, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Extract embeddings from frozen embedding model""" + + if not self.match_underlying_distribution: + return None, None + + embeddings = self._vectorize_prompt( + input_ids=batch["embedding_input_ids"], + attention_mask=batch["embedding_attention_mask"], + ) + + chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True] + rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False] + + chosen_embeddings = embeddings[chosen_idx, ...] + rejected_embeddings = embeddings[rejected_idx, ...] + + return (chosen_embeddings, rejected_embeddings) + + def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor: + """ + Sample instances from dataset and get prompt embeddings. + Used for density ratio classifier training. 
+ """ + n_samples = min(len(dataset), sample_size) + rand_indices = np.random.choice(len(dataset), size=(n_samples,)) + + batch = dataset.select(rand_indices) + input_ids = pad_sequence( + [torch.as_tensor(ids) for ids in batch["embedding_input_ids"]], + batch_first=True, + padding_value=self.embedding_tokenizer.pad_token_id, + ).to(self.accelerator.device) + attention_mask = pad_sequence( + [torch.as_tensor(ids) for ids in batch["embedding_attention_mask"]], + batch_first=True, + padding_value=0, + ).to(self.accelerator.device) + + with torch.no_grad(): + embeddings = self.embedding_func( + input_ids=input_ids, + attention_mask=attention_mask, + ) + + return embeddings + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 deepspeed_plugin = self.accelerator.state.deepspeed_plugin @@ -1019,6 +1169,65 @@ def kto_loss( return losses, chosen_rewards, rejected_rewards, kl + def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor: + prob_desirable = self._get_chosen_prob(rejected_embeddings) + min_ratio = self.args.min_density_ratio + max_ratio = self.args.max_density_ratio + + weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio) + + return weight + + def bco_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + chosen_embeddings: Optional[torch.FloatTensor], + rejected_embeddings: Optional[torch.FloatTensor], + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the BCO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,) + reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,) + reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,) + chosen_embeddings: embeddings of desirable prompts + rejected_embeddings: embeddings of undesirable prompts + + Returns: + A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL). + The losses tensor contains the KTO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + The delta value contains the moving average of all implicit rewards. 
+ """ + assert policy_chosen_logps.shape[0] > 0, f"no chosen data at {self.accelerator.local_process_index}" + assert policy_rejected_logps.shape[0] > 0, f"no rejected data at {self.accelerator.local_process_index}" + chosen_logratios = policy_chosen_logps - reference_chosen_logps + rejected_logratios = policy_rejected_logps - reference_rejected_logps + + chosen_rewards = self.beta * chosen_logratios + rejected_rewards = self.beta * rejected_logratios + + rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach() + self.running.update(rewards) + rewards_mean = self.running.mean + + chosen_losses = -F.logsigmoid(chosen_rewards - rewards_mean) + rejected_losses = -F.logsigmoid(-(rejected_rewards - rewards_mean)) + + if self.match_underlying_distribution: + chosen_weight = torch.ones_like(chosen_losses) + rejected_weight = self._get_udm_weight(rejected_embeddings) + + losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0) + else: + losses = torch.cat((chosen_losses, rejected_losses), dim=0) + + return losses, chosen_rewards, rejected_rewards, torch.as_tensor(rewards_mean) + def get_batch_loss_metrics( self, model, @@ -1064,14 +1273,26 @@ def get_batch_loss_metrics( reference_KL_logps, ) = self.forward(self.ref_model, batch) - losses, chosen_rewards, rejected_rewards, kl = self.kto_loss( - policy_chosen_logps, - policy_rejected_logps, - policy_KL_logps, - reference_chosen_logps, - reference_rejected_logps, - reference_KL_logps, - ) + if self.loss_type == "kto": + losses, chosen_rewards, rejected_rewards, kl = self.kto_loss( + policy_chosen_logps, + policy_rejected_logps, + policy_KL_logps, + reference_chosen_logps, + reference_rejected_logps, + reference_KL_logps, + ) + elif self.loss_type == "bco": + chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch) + + losses, chosen_rewards, rejected_rewards, kl = self.bco_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + chosen_embeddings, + rejected_embeddings, + ) num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device) num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)