diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py
index 49a3edffff..61c2a65b8e 100644
--- a/examples/scripts/kto.py
+++ b/examples/scripts/kto.py
@@ -13,49 +13,45 @@
 # limitations under the License.
 """
-Run the KTO training script with the following command with some example arguments.
-In general, the optimal configuration for KTO will be similar to that of DPO:
+Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO.
 
-# regular:
+# Full training:
 python examples/scripts/kto.py \
-    --model_name_or_path=gpt2 \
+    --model_name_or_path=stabilityai/stablelm-2-zephyr-1_6b \
     --per_device_train_batch_size 16 \
-    --max_steps 1000 \
+    --num_train_epochs 1 \
     --learning_rate 2e-5 \
     --gradient_accumulation_steps 1 \
     --logging_steps 10 \
     --eval_steps 500 \
-    --output_dir="kto_anthropic_hh" \
+    --output_dir="kto-aligned-model" \
     --warmup_steps 150 \
     --report_to wandb \
     --bf16 \
-    --logging_first_step \
-    --no_remove_unused_columns
+    --logging_first_step
 
-# peft:
+# LoRA:
 python examples/scripts/kto.py \
-    --model_name_or_path=gpt2 \
+    --model_name_or_path=stabilityai/stablelm-2-zephyr-1_6b \
     --per_device_train_batch_size 16 \
-    --max_steps 1000 \
+    --num_train_epochs 1 \
     --learning_rate 2e-4 \
     --gradient_accumulation_steps 1 \
     --logging_steps 10 \
     --eval_steps 500 \
-    --output_dir="kto_anthropic_hh" \
+    --output_dir="kto-aligned-model-lora" \
     --warmup_steps 150 \
     --report_to wandb \
     --bf16 \
     --logging_first_step \
-    --no_remove_unused_columns \
     --use_peft \
     --lora_r=16 \
     --lora_alpha=16
 """
 
-from dataclasses import dataclass, field
-from typing import Optional
+from dataclasses import dataclass
 
-from datasets import Dataset, load_dataset
+from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
 
 from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config
@@ -68,85 +64,48 @@ class ScriptArguments:
     The arguments for the KTO training script.
     """
 
-    # debugging
-    sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
-
-
-def extract_anthropic_prompt(prompt_and_response):
-    """Extract the anthropic prompt from a prompt and response pair."""
-    search_term = "\n\nAssistant:"
-    search_term_idx = prompt_and_response.rfind(search_term)
-
-    if search_term_idx == -1:
-        raise ValueError(f"Prompt and response does not contain '{search_term}'")
-
-    return prompt_and_response[: search_term_idx + len(search_term)]
-
-
-def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
-    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
-
-    The dataset is converted to a dictionary with the following structure:
-    {
-        'prompt': List[str],
-        'completion': List[str],
-        'label': List[bool],
-    }
-
-    Prompts should be structured as follows:
-      \n\nHuman: <prompt>\n\nAssistant:
-    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
-    """
-    dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
-    if sanity_check:
-        dataset = dataset.select(range(min(len(dataset), 1000)))
-
-    flat_data = {
-        "prompt": [],
-        "completion": [],
-        "label": [],
-    }
-    for sample in dataset:
-        prompt = extract_anthropic_prompt(sample["chosen"])
-        flat_data["prompt"].append(prompt)
-        flat_data["completion"].append(sample["chosen"][len(prompt) :])
-        flat_data["label"].append(True)
-        flat_data["prompt"].append(prompt)
-        flat_data["completion"].append(sample["rejected"][len(prompt) :])
-        flat_data["label"].append(False)
-
-    return dataset.from_dict(flat_data)
+    dataset_name: str = "trl-lib/kto-mix-14k"
 
 
 if __name__ == "__main__":
     parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
     script_args, kto_args, model_args = parser.parse_args_into_dataclasses()
 
-    # 1. load a pretrained model
+    # Load a pretrained model
     model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
     model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
 
     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
+    if tokenizer.chat_template is None:
+        raise ValueError(
+            "Tokenizer must have a chat template in order to format the examples. Alternatively, adjust this script to format the examples differently."
+        )
+
+    # Load the dataset
+    dataset = load_dataset(script_args.dataset_name)
 
-    # 2. Load the Anthropic Helpful-Harmless dataset
-    train_dataset = get_hh("train", sanity_check=script_args.sanity_check)
+    # Apply chat template
+    def format_dataset(example):
+        example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
+        example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
+        return example
 
-    # 3. Load evaluation dataset
-    eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)
+    formatted_dataset = dataset.map(format_dataset)
 
-    # 4. initialize the KTO trainer
+    # Initialize the KTO trainer
     kto_trainer = KTOTrainer(
        model,
        model_ref,
        args=kto_args,
-        train_dataset=train_dataset,
-        eval_dataset=eval_dataset,
+        train_dataset=formatted_dataset["train"],
+        eval_dataset=formatted_dataset["test"],
        tokenizer=tokenizer,
        peft_config=get_peft_config(model_args),
     )
 
-    # 5. train and save the model
+    # Train and push the model to the Hub
     kto_trainer.train()
     kto_trainer.save_model(kto_args.output_dir)
+    kto_trainer.push_to_hub()