diff --git a/README.md b/README.md
index b0f22cf07e..0e1cfa74c5 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,7 @@ The library is built on top of the [`transformers`](https://github.com/huggingfa
 - [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows to train even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA.
 - [`unsloth`](https://github.com/unslothai/unsloth) is also integrated and allows to significantly speed up training with dedicated kernels.
 - **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune and chat with LLMs without writing any code using a single command and a flexible config system.
-- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer), and [`CPOTrainer`]((https://huggingface.co/docs/trl/trainer#trl.CPOTrainer).
+- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer), [`CPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.CPOTrainer), and [`ORPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.ORPOTrainer).
 - **`AutoModels`**: The [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead) classes add an additional value head to the model which allows to train them with RL algorithms such as PPO.
 - **`Examples`**: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [StackLlama example](https://huggingface.co/blog/stackllama), etc. following the [examples](https://github.com/huggingface/trl/tree/main/examples).
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index b9e53815cf..e69a418f30 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -37,6 +37,8 @@
     title: CPO Trainer
   - local: ddpo_trainer
     title: Denoising Diffusion Policy Optimization
+  - local: orpo_trainer
+    title: ORPO Trainer
   - local: iterative_sft_trainer
     title: Iterative Supervised Fine-Tuning
   - local: text_environments
diff --git a/docs/source/orpo_trainer.md b/docs/source/orpo_trainer.md
new file mode 100644
index 0000000000..774151d733
--- /dev/null
+++ b/docs/source/orpo_trainer.md
@@ -0,0 +1,98 @@
+# ORPO Trainer
+
+[Odds Ratio Preference Optimization](https://arxiv.org/abs/2403.07691) (ORPO) by Jiwoo Hong, Noah Lee, and James Thorne studies the crucial role of SFT within the context of preference alignment. Using preference data, the method posits that a minor penalty for the disfavored generation, together with a strong adaptation signal to the chosen response via a simple log-odds-ratio term appended to the NLL loss, is sufficient for preference-aligned SFT.
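+
+As a rough sketch of the idea (illustrative code following Eqs. (4), (6) and (7) of the paper; the function name and inputs here are hypothetical, not part of the library API), the objective adds a scaled log-sigmoid of the log odds ratio between the chosen and rejected responses to the usual NLL loss:
+
+```py
+import torch
+import torch.nn.functional as F
+
+
+def orpo_objective_sketch(chosen_logps, rejected_logps, nll_loss, beta=0.1):
+    """Illustrative only: `chosen_logps`/`rejected_logps` are the policy's average
+    per-token log probabilities for each response (shape: `(batch_size,)`)."""
+    # log odds(y|x) = log(P(y|x) / (1 - P(y|x))), computed in log space for stability
+    log_odds = (chosen_logps - rejected_logps) - (
+        torch.log1p(-torch.exp(chosen_logps)) - torch.log1p(-torch.exp(rejected_logps))
+    )
+    # the odds-ratio term is weighted by beta (lambda in the paper) and added to the SFT loss
+    return nll_loss - beta * F.logsigmoid(log_odds).mean()
+```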
+
+ORPO is thus a reference-model-free preference optimization algorithm, eliminating the need for an additional preference alignment phase and thereby saving compute and memory.
+
+The official code can be found at [xfactlab/orpo](https://github.com/xfactlab/orpo).
+
+## Expected dataset format
+
+The ORPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
+
+- `prompt`
+- `chosen`
+- `rejected`
+
+For example:
+
+```py
+orpo_dataset_dict = {
+    "prompt": [
+        "hello",
+        "how are you",
+        "What is your name?",
+        "What is your name?",
+        "Which is the best programming language?",
+        "Which is the best programming language?",
+        "Which is the best programming language?",
+    ],
+    "chosen": [
+        "hi nice to meet you",
+        "I am fine",
+        "My name is Mary",
+        "My name is Mary",
+        "Python",
+        "Python",
+        "Java",
+    ],
+    "rejected": [
+        "leave me alone",
+        "I am not fine",
+        "Whats it to you?",
+        "I dont have a name",
+        "Javascript",
+        "C++",
+        "C++",
+    ],
+}
+```
+
+where `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses, and `rejected` contains the corresponding negative (rejected) responses. Note that a prompt can have multiple responses, and this is reflected in the entries being repeated in the dictionary's value arrays.
+
+## Expected model format
+The ORPO trainer expects a model of `AutoModelForCausalLM`, unlike PPO, which expects `AutoModelForCausalLMWithValueHead` for the value function.
+
+## Using the `ORPOTrainer`
+For a detailed example, have a look at the `examples/scripts/orpo.py` script. At a high level, we need to initialize the `ORPOTrainer` with a `model` we wish to train. **Note that ORPOTrainer eliminates the need for a reference model, simplifying the optimization process.** The `beta` corresponds to the hyperparameter `lambda` in eq. (6) of the paper and controls the weighting of the relative odds ratio loss with respect to the standard cross-entropy loss used for SFT.
+
+```py
+orpo_config = ORPOConfig(
+    beta=0.1, # the lambda/alpha hyperparameter in the paper/code
+)
+
+orpo_trainer = ORPOTrainer(
+    model,
+    args=orpo_config,
+    train_dataset=train_dataset,
+    tokenizer=tokenizer,
+)
+```
+
+After this, one can then call:
+
+```py
+orpo_trainer.train()
+```
+
+## Logging
+
+While training and evaluating, we record the following reward metrics:
+
+* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
+* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
+* `rewards/accuracies`: the mean of how often the chosen rewards are greater than the corresponding rejected rewards
+* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
+* `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses
+* `log_odds_ratio`: the mean of `log(sigmoid(log_odds_chosen))`
+* `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over the chosen responses
+
+## ORPOTrainer
+
+[[autodoc]] ORPOTrainer
+
+## ORPOConfig
+
+[[autodoc]] ORPOConfig
diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py
new file mode 100644
index 0000000000..de2c77e417
--- /dev/null
+++ b/examples/scripts/orpo.py
@@ -0,0 +1,121 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Run the ORPO training script with the following command with some example arguments. +In general, the optimal configuration for ORPO will be similar to that of DPO without the need for a reference model: + +# regular: +python examples/scripts/orpo.py \ + --model_name_or_path=gpt2 \ + --per_device_train_batch_size 4 \ + --max_steps 1000 \ + --learning_rate 8e-6 \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir="gpt2-aligned-orpo" \ + --warmup_steps 150 \ + --report_to wandb \ + --bf16 \ + --logging_first_step \ + --no_remove_unused_columns + +# peft: +python examples/scripts/orpo.py \ + --model_name_or_path=gpt2 \ + --per_device_train_batch_size 4 \ + --max_steps 1000 \ + --learning_rate 8e-5 \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir="gpt2-lora-aligned-orpo" \ + --optim rmsprop \ + --warmup_steps 150 \ + --report_to wandb \ + --bf16 \ + --logging_first_step \ + --no_remove_unused_columns \ + --use_peft \ + --lora_r=16 \ + --lora_alpha=16 +""" + +import multiprocessing +from dataclasses import dataclass, field + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser + +from trl import ModelConfig, ORPOConfig, ORPOTrainer, get_peft_config + + +@dataclass +class ScriptArguments: + dataset: str = field( + default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."} + ) + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig)) + args, orpo_args, model_config = parser.parse_args_into_dataclasses() + + ################ + # Model & Tokenizer + ################ + model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path) + peft_config = get_peft_config(model_config) + tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ################ + # Dataset + ################ + ds = load_dataset(args.dataset) + if orpo_args.debug: + for key in ds: + ds[key] = ds[key].select(range(50)) + if tokenizer.chat_template is None: + tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + + def process(row): + row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False) + row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False) + return row + + ds = ds.map( + process, + num_proc=1 if orpo_args.debug else multiprocessing.cpu_count(), + load_from_cache_file=False, + ) + train_dataset = ds["train"] + eval_dataset = ds["test"] + + ################ + # Training + ################ + trainer = ORPOTrainer( 
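+        # get_peft_config(model_config) returns None when --use_peft is not passed;
+        # in that case the ORPOTrainer fine-tunes the full model rather than LoRA adapters.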
+ model, + args=orpo_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + peft_config=get_peft_config(model_config), + ) + + # train and save the model + trainer.train() + trainer.save_model(orpo_args.output_dir) diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py new file mode 100644 index 0000000000..368c32d6c8 --- /dev/null +++ b/tests/test_orpo_trainer.py @@ -0,0 +1,174 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest + +import torch +from datasets import Dataset +from parameterized import parameterized +from pytest import mark +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer + +from trl import ORPOConfig, ORPOTrainer + +from .testing_utils import require_peft + + +class ORPOTrainerTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + cls.tokenizer.pad_token = cls.tokenizer.eos_token + + # get t5 as seq2seq example: + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab" + cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + + def _init_dummy_dataset(self): + # fmt: off + dummy_dataset_dict = { + "prompt": [ + "hello", + "how are you", + "What is your name?", + "What is your name?", + "Which is the best programming language?", + "Which is the best programming language?", + "Which is the best programming language?", + "[INST] How is the stock price? [/INST]", + "[INST] How is the stock price? 
[/INST] ", + ], + "chosen": [ + "hi nice to meet you", + "I am fine", + "My name is Mary", + "My name is Mary", + "Python", + "Python", + "Python", + "$46 as of 10am EST", + "46 as of 10am EST", + ], + "rejected": [ + "leave me alone", + "I am not fine", + "Whats it to you?", + "I dont have a name", + "Javascript", + "C++", + "Java", + " $46 as of 10am EST", + " 46 as of 10am EST", + ], + } + # fmt: on + return Dataset.from_dict(dummy_dataset_dict) + + @parameterized.expand([["gpt2"], ["t5"]]) + def test_orpo_trainer(self, name): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = ORPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + evaluation_strategy="steps", + beta=0.1, + ) + + dummy_dataset = self._init_dummy_dataset() + + if name == "gpt2": + model = self.model + tokenizer = self.tokenizer + elif name == "t5": + model = self.t5_model + tokenizer = self.t5_tokenizer + training_args.is_encoder_decoder = True + + trainer = ORPOTrainer( + model=model, + args=training_args, + tokenizer=tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + assert not torch.equal(param, new_param) + + @require_peft + @mark.peft_test + def test_orpo_trainer_with_lora(self): + from peft import LoraConfig + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = ORPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + evaluation_strategy="steps", + beta=0.1, + ) + + dummy_dataset = self._init_dummy_dataset() + + trainer = ORPOTrainer( + model=self.model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + # check the params have changed + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + assert not torch.equal(param, new_param) diff --git a/trl/__init__.py b/trl/__init__.py index 5d5a43a7be..d1c4b4cb01 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -41,6 +41,8 @@ "KTOConfig", "KTOTrainer", "ModelConfig", + "ORPOConfig", + "ORPOTrainer", "PPOConfig", "PPOTrainer", "RewardConfig", @@ -102,6 +104,8 @@ KTOConfig, KTOTrainer, ModelConfig, + ORPOConfig, + ORPOTrainer, PPOConfig, PPOTrainer, RewardConfig, diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 3ab1da1f4b..6310e2a0f9 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -40,6 +40,8 @@ "kto_config": ["KTOConfig"], "kto_trainer": ["KTOTrainer"], "model_config": ["ModelConfig"], + "orpo_config": ["ORPOConfig"], + 
"orpo_trainer": ["ORPOTrainer"], "ppo_config": ["PPOConfig"], "ppo_trainer": ["PPOTrainer"], "reward_config": ["RewardConfig"], @@ -83,6 +85,8 @@ from .kto_config import KTOConfig from .kto_trainer import KTOTrainer from .model_config import ModelConfig + from .orpo_config import ORPOConfig + from .orpo_trainer import ORPOTrainer from .ppo_config import PPOConfig from .ppo_trainer import PPOTrainer from .reward_config import RewardConfig diff --git a/trl/trainer/orpo_config.py b/trl/trainer/orpo_config.py new file mode 100644 index 0000000000..14be7ee1a2 --- /dev/null +++ b/trl/trainer/orpo_config.py @@ -0,0 +1,71 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Dict, Optional + +from transformers import TrainingArguments + + +@dataclass +class ORPOConfig(TrainingArguments): + r""" + ORPOConfig collects all training arguments related to the [`ORPOTrainer`] class. + + Using [`HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int`, defaults to `None`): + The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator. + max_prompt_length (`int`, defaults to `None`): + The maximum length of the prompt. This argument is required if you want to use the default data collator. + max_completion_length (`int`, defaults to `None`): + The maximum length of the completions. This argument is required if you want to use the default data collator and your model is an encoder-decoder. + beta (`float`, defaults to 0.1): + The beta factor in ORPO loss (lambda/alpha in paper/code) that is the weight of the relative loss ratio in the SFT loss. + label_pad_token_id (`int`, defaults to `-100`): + The label pad token id. This argument is required if you want to use the default data collator. + padding_value (`int`, defaults to `None`): + The padding value if it is different to the tokenizer's pad_token_id. + truncation_mode (`str`, defaults to `keep_end`): + The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator. + generate_during_eval (`bool`, defaults to `False`): + Whether to sample and log generations during evaluation step. + is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`): + If no model is provided, we need to know if the model_init returns an encoder-decoder. + disable_dropout (`bool`, defaults to `True`): + Whether or not to disable dropouts in `model`. + model_init_kwargs (`Optional[Dict]`, *optional*): + Dict of Optional kwargs to pass when instantiating the model from a string + dataset_num_proc (`Optional[int]`, *optional*): + The number of workers to use to tokenize the data. Defaults to None. 
+ """ + + max_length: Optional[int] = None + max_prompt_length: Optional[int] = None + max_completion_length: Optional[int] = None + + beta: float = 0.1 + disable_dropout: bool = True + + label_pad_token_id: int = -100 + padding_value: int = None + truncation_mode: str = "keep_end" + generate_during_eval: bool = False + is_encoder_decoder: Optional[bool] = None + + model_init_kwargs: Optional[Dict] = None + + dataset_num_proc: Optional[int] = None diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py new file mode 100644 index 0000000000..70ddfa2408 --- /dev/null +++ b/trl/trainer/orpo_trainer.py @@ -0,0 +1,949 @@ +# ORPO Authors: Jiwoo Hong, Noah Lee, and James Thorne +# Official code: https://github.com/xfactlab/orpo +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import random +import warnings +from collections import defaultdict +from contextlib import nullcontext +from copy import deepcopy +from functools import wraps +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate import PartialState +from accelerate.utils import is_deepspeed_available +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput +from transformers.utils import is_torch_fx_proxy + +from ..import_utils import is_peft_available, is_wandb_available +from ..models import PreTrainedModelWrapper +from .orpo_config import ORPOConfig +from .utils import ( + DPODataCollatorWithPadding, + disable_dropout_in_model, + pad_to_length, + peft_module_casting_to_bf16, + trl_sanitze_kwargs_for_tagging, +) + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +if is_wandb_available(): + import wandb + +if is_deepspeed_available(): + import deepspeed + + +class ORPOTrainer(Trainer): + r""" + Initialize ORPOTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + args (`ORPOConfig`): + The ORPO config arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + tokenizer (`transformers.PreTrainedTokenizerBase`): + The tokenizer to use for training. This argument is required if you want to use the default data collator. 
+ model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`Dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return + a dictionary string to metric values. + """ + + _tag_names = ["trl", "orpo"] + + def __init__( + self, + model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + args: Optional[ORPOConfig] = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional[Dict] = None, + compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, + ): + if args.model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.") + else: + model_init_kwargs = args.model_init_kwargs + model_init_kwargs["torch_dtype"] = ( + model_init_kwargs["torch_dtype"] + if model_init_kwargs["torch_dtype"] in ["auto", None] + else getattr(torch, model_init_kwargs["torch_dtype"]) + ) + + if isinstance(model, str): + warnings.warn( + "You passed a model_id to the ORPOTrainer. This will automatically create an " + "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. 
+ self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_config, we merge and unload it first + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True + + # For models that use gradient_checkpointing, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if args.generate_during_eval and not is_wandb_available(): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases to be installed." + " Please install `wandb` to resolve." 
+ ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif args.is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = args.is_encoder_decoder + + if self.is_encoder_decoder: + self.decoder_start_token_id = model.config.decoder_start_token_id + self.pad_token_id = model.config.pad_token_id + + if tokenizer is None: + raise ValueError("tokenizer must be specified to tokenize a ORPO dataset.") + if args.max_length is None: + warnings.warn( + "`max_length` is not set in the ORPOConfig's init" + " it will default to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + else: + max_length = args.max_length + if args.max_prompt_length is None: + warnings.warn( + "`max_prompt_length` is not set in the ORPOConfig's init" + " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + else: + max_prompt_length = args.max_prompt_length + + if args.max_completion_length is None and self.is_encoder_decoder: + warnings.warn( + "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init" + " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + self.max_completion_length = 128 + else: + self.max_completion_length = args.max_completion_length + + if data_collator is None: + data_collator = DPODataCollatorWithPadding( + pad_token_id=tokenizer.pad_token_id, + label_pad_token_id=args.label_pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + if args.disable_dropout: + disable_dropout_in_model(model) + + self.max_length = max_length + self.generate_during_eval = args.generate_during_eval + self.label_pad_token_id = args.label_pad_token_id + self.padding_value = args.padding_value if args.padding_value is not None else tokenizer.pad_token_id + self.max_prompt_length = max_prompt_length + self.truncation_mode = args.truncation_mode + self.tokenizer = tokenizer + + self.beta = args.beta + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # Compute that only on the main process for faster data processing. 
+ # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + # tokenize the dataset + train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): + # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) + + if model is not None: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model. + # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model + + def build_tokenized_answer(self, prompt, answer): + """ + Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. + It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`. 
+ Reference: + https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + """ + + full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False) + prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"] + + answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] + answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] + + # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` + full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids]) + + # Prepare input tokens for token by token comparison + full_input_ids = np.array(full_tokenized["input_ids"]) + + if len(full_input_ids) != len(full_concat_input_ids): + raise ValueError("Prompt input ids and answer input ids should have the same length.") + + # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens + # can be merged together when tokenizing prompt+answer. This could result + # on the last token from the prompt being different when tokenized on its own + # vs when done as prompt+answer. + response_token_ids_start_idx = len(prompt_input_ids) + + # If tokenized prompt is different than both prompt+answer, then it means the + # last token has changed due to merging. + if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: + response_token_ids_start_idx -= 1 + + prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] + prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] + + if len(prompt_input_ids) != len(prompt_attention_mask): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] + answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] + + return dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + input_ids=answer_input_ids, + attention_mask=answer_attention_mask, + ) + + def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict: + """Tokenize a single row from a ORPO specific dataset. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation + in case the prompt + chosen or prompt + rejected responses is/are too long. First + we truncate the prompt; if we're still too long, we truncate the chosen/rejected. + + We also create the labels for the chosen/rejected responses, which are of length equal to + the sum of the length of the prompt and the chosen/rejected response, with + label_pad_token_id for the prompt tokens. + """ + batch = {} + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + + if not self.is_encoder_decoder: + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. 
https://github.com/LianjiaTech/BELLE/issues/337 + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + prompt_tokens = self.tokenizer(prompt, add_special_tokens=False) + prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} + + if not isinstance(chosen, str): + raise ValueError(f"chosen should be an str but got {type(chosen)}") + chosen_tokens = self.build_tokenized_answer(prompt, chosen) + + if not isinstance(rejected, str): + raise ValueError(f"rejected should be an str but got {type(rejected)}") + rejected_tokens = self.build_tokenized_answer(prompt, rejected) + + # Last prompt token might get merged by tokenizer and + # it should not be included for generation if that happens + prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) + + chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) + rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) + prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) + + for k, v in prompt_tokens.items(): + prompt_tokens[k] = v[:prompt_len_input_ids] + + # Make sure prompts only have one different token at most an + # and length only differs by 1 at most + num_diff_tokens = sum( + [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])] + ) + num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) + if num_diff_tokens > 1 or num_diff_len > 1: + raise ValueError( + "Chosen and rejected prompt_input_ids might only differ on the " + "last token due to tokenizer merge ops." + ) + + # add BOS token to head of prompt + prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"] + chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"] + rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"] + + prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"] + chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"] + rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"] + + # add EOS token to end of answer + chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id) + chosen_tokens["attention_mask"].append(1) + + rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id) + rejected_tokens["attention_mask"].append(1) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + # if combined sequence is too long, truncate the prompt + for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + if self.truncation_mode == "keep_start": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_prompt_length] + elif self.truncation_mode == "keep_end": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :] + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + # if that's still too long, truncate the response + for answer_tokens in [chosen_tokens, rejected_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + for k in ["input_ids", "attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_length 
- self.max_prompt_length] + + # Create labels + chosen_sequence_tokens = { + k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] + } + rejected_sequence_tokens = { + k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] + } + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(chosen_tokens["prompt_input_ids"]) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(rejected_tokens["prompt_input_ids"]) + + for k, toks in { + "chosen_": chosen_sequence_tokens, + "rejected_": rejected_sequence_tokens, + "": prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}{type_key}"] = tokens + + else: + chosen_tokens = self.tokenizer( + chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + rejected_tokens = self.tokenizer( + rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + prompt_tokens = self.tokenizer( + prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True + ) + + batch["chosen_labels"] = chosen_tokens["input_ids"] + batch["rejected_labels"] = rejected_tokens["input_ids"] + batch["prompt_input_ids"] = prompt_tokens["input_ids"] + batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] + + if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): + batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["rejected_labels"]) + ) + batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["chosen_labels"]) + ) + + return batch + + @staticmethod + def concatenated_inputs( + batch: Dict[str, Union[List, torch.LongTensor]], + is_encoder_decoder: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + device: Optional[torch.device] = None, + ) -> Dict[str, torch.LongTensor]: + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). + is_encoder_decoder: Whether the model is an encoder-decoder model. + label_pad_token_id: The label pad token id. + padding_value: The padding value to use for the concatenated inputs_ids. + device: The device for the concatenated inputs. + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. 
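+            For illustration: with a batch of 2 preference pairs, 'concatenated_input_ids' has shape
+            (4, max_seq_len); the first two rows are the padded chosen sequences and the last two rows
+            the padded rejected sequences.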
+ """ + concatenated_batch = {} + + if is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(device=device) + + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1).to(device=device) + ) + + return concatenated_batch + + def odds_ratio_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the ORPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + The log odds ratio of the chosen responses over the rejected responses ratio for logging purposes. + The `log(sigmoid(log_odds_chosen))` for logging purposes. + """ + + # Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x) + log_odds = (policy_chosen_logps - policy_rejected_logps) - ( + torch.log(1 - torch.exp(policy_chosen_logps)) - torch.log(1 - torch.exp(policy_rejected_logps)) + ) + sig_ratio = F.sigmoid(log_odds) + ratio = torch.log(sig_ratio) + losses = self.beta * ratio + + chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach() + rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach() + + return losses, chosen_rewards, rejected_rewards, torch.mean(ratio).item(), torch.mean(log_odds).item() + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). 
Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + label_pad_token_id: The label pad token id. + is_encoder_decoder: Whether the model is an encoder-decoder model. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == label_pad_token_id] = 0 + + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + ) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), + } + if self.is_encoder_decoder + else {} + ) + + outputs = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + **model_kwargs, + ) + all_logits = outputs.logits + + def cross_entropy_loss(logits, labels): + if not self.is_encoder_decoder: + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + labels = concatenated_batch["concatenated_labels"].clone() + chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) + + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=True, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss) + + def get_batch_loss_metrics( + self, + model, + batch: Dict[str, Union[List, torch.LongTensor]], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the ORPO loss and other metrics for the 
given batch of inputs for train or test.""" + metrics = {} + + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = self.concatenated_forward(model, batch) + + losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( + policy_chosen_logps, policy_rejected_logps + ) + # full ORPO loss + loss = policy_nll_loss - losses.mean() + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu() + metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu() + metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu() + metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu() + metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu() + metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu() + metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu() + metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu() + metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu() + metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio + metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen + + return loss, metrics + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + if not self.use_dpo_data_collator: + warnings.warn( + "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + + compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext + + with compute_loss_context_manager(): + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + + # force log the metrics + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch cuda amp context manager as some hidden states are silently casted to full precision. 
+ generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast + + with generate_context_manager(): + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.tokenizer.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id) + policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True) + + return policy_output_decoded + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ): + if not self.use_dpo_data_collator: + warnings.warn( + "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext + + with torch.no_grad(), prediction_context_manager(): + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") + + # force log the metrics + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = { + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys) + logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. 
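+
+        When `generate_during_eval=True`, one randomly selected eval batch is additionally decoded via
+        `get_batch_samples` and logged to Weights & Biases as a `game_log` table before the standard
+        evaluation loop runs.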
+ """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + policy_output_decoded = self.get_batch_samples(self.model, random_batch) + + self.log( + { + "game_log": wandb.Table( + columns=["Prompt", "Policy"], + rows=[ + [prompt, pol[len(prompt) :]] + for prompt, pol in zip(random_batch["prompt"], policy_output_decoded) + ], + ) + } + ) + self.state.log_history.pop() + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs) + + def _shift_right(self, input_ids): + if self.decoder_start_token_id is None: + raise ValueError( + "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = self.decoder_start_token_id + + if self.pad_token_id is None: + raise ValueError("model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id) + + return shifted_input_ids + + @wraps(Trainer.push_to_hub) + def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tag "orpo" when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + """ + kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) + + return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)