FEAT: Add CLIs in TRL ! (huggingface#1419)
* CLI V1

* v1 CLI

* add rich enhancements

* revert unindented change

* some comments

* cleaner CLI

* fix

* fix

* remove print callback

* move to cli instead of trl_cli

* revert unneeded changes

* fix test

* Update trl/commands/sft.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* remove redundant strings

* fix import issue

* fix other issues

* add packing

* add config parser

* some refactor

* cleaner

* add example config yaml file

* small refactor

* change a bit the logic

* fix issues here and there

* add CLI in docs

* move to examples/sft

* remove redundant licenses

* make it work on dpo

* set to None

* switch to accelerate and fix many things

* add docs

* more docs

* added tests

* doc clarification

* more docs

* fix CI for windows and python 3.8

* fix

* attempt to fix CI

* fix?

* test

* fix

* tweak?

* fix

* test

* another test

* fix

* test

* fix

* fix

* fix

* skip tests for windows

* test @lvwerra approach

* make dev

* revert unneeded changes

* fix sft dpo

* optimize a bit

* address final comments

* update docs

* final comment

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
younesbelkada and lvwerra authored Mar 18, 2024
1 parent 304e208 commit a2aa0f0
Showing 24 changed files with 1,085 additions and 233 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -5,7 +5,7 @@
Before you start contributing make sure you installed all the dev tools:

```bash
pip install -e ".[dev]"
make dev
```

## Did you find a bug?
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -2,4 +2,4 @@ include settings.ini
include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__
recursive-exclude * __pycache__
5 changes: 5 additions & 0 deletions Makefile
@@ -5,6 +5,11 @@ check_dirs := examples tests trl
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
COMMAND_FILES_PATH = `pwd`/commands


dev:
	[ -L "$(pwd)/trl/commands/scripts" ] && unlink "$(pwd)/trl/commands/scripts" || true
	pip install -e ".[dev]"

test:
	python -m pytest -n auto --dist=loadfile -s -v ./tests/

2 changes: 2 additions & 0 deletions commands/run_dpo.sh
@@ -3,6 +3,7 @@
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_dpo/"
MODEL_NAME="HuggingFaceM4/tiny-random-LlamaForCausalLM"
DATASET_NAME="trl-internal-testing/Anthropic-hh-rlhf-processed"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
@@ -36,6 +37,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
--mixed_precision 'fp16' \
`pwd`/examples/scripts/dpo.py \
--model_name_or_path $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -5,6 +5,8 @@
title: Quickstart
- local: installation
title: Installation
- local: clis
title: Get started with Command Line Interfaces (CLIs)
- local: how_to_train
title: PPO Training FAQ
- local: use_model
87 changes: 87 additions & 0 deletions docs/source/clis.mdx
@@ -0,0 +1,87 @@
# Command Line Interfaces (CLIs)

You can use TRL to fine-tune your language model with Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO) using the TRL CLIs.

Currently supported CLIs are:

- `trl sft`
- `trl dpo`

## Get started

Before getting started, pick a language model from the Hugging Face Hub; supported models can be found by filtering for the "text-generation" task on the Hub. You will also need a dataset that is relevant for your task.

Also make sure to run:
```bash
accelerate config
```
and select the configuration that matches your training setup (single GPU, multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.

We also recommend passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with the `trl sft` command.

```yaml
model_name_or_path:
  HuggingFaceM4/tiny-random-LlamaForCausalLM
dataset_name:
  imdb
dataset_text_field:
  text
report_to:
  none
learning_rate:
  0.0001
lr_scheduler_type:
  cosine
```
Save that config in a `.yaml` file and you can get started right away! Note that you can override the arguments from the config file by explicitly passing them to the CLI, e.g.:

```bash
trl sft --config example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
```

This will force the use of `cosine_with_restarts` for `lr_scheduler_type`.

## Supported Arguments

We support all arguments from `transformers.TrainingArguments`. For loading your model, we support all arguments from `~trl.ModelConfig`:

[[autodoc]] ModelConfig

You can pass any of these arguments either to the CLI or the YAML file.
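
For instance, the PEFT-related fields of `ModelConfig` can be mixed with `TrainingArguments` on the command line. A minimal sketch, assuming `--use_peft`, `--lora_r`, and `--lora_alpha` are exposed by `ModelConfig` as in the example scripts below; the LoRA values are illustrative only:

```bash
# Sketch: mixing ModelConfig fields (PEFT/LoRA) with TrainingArguments on the CLI.
# The LoRA hyper-parameters are illustrative values, not recommendations.
trl sft \
    --config example_config.yaml \
    --output_dir test-trl-cli \
    --use_peft \
    --lora_r 16 \
    --lora_alpha 16
```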

### Supervised Fine-tuning (SFT)

Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:

```bash
trl sft --config config.yaml --output_dir your-output-dir
```

The SFT CLI is based on the `examples/scripts/sft.py` script.
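
If you prefer not to use a config file, every argument can also be passed directly on the command line. A minimal sketch reusing the values from the example YAML above:

```bash
# Equivalent to the YAML example above, but expressed entirely on the command line.
trl sft \
    --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM \
    --dataset_name imdb \
    --dataset_text_field text \
    --learning_rate 0.0001 \
    --lr_scheduler_type cosine \
    --report_to none \
    --output_dir test-trl-cli
```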

### Direct Preference Optimization (DPO)

First, follow the basic instructions above and run `trl dpo --output_dir <output_dir> <*args>`. Make sure to process your DPO dataset in the TRL format as follows:

1- Make sure to pre-tokenize the dataset using chat templates:

```bash
python examples/datasets/tokenize_ds.py --model gpt2 --dataset yourdataset
```

You might need to adapt `examples/datasets/tokenize_ds.py` to use your chat template.

2- Format the dataset into the TRL format (you can adapt `examples/datasets/anthropic_hh.py`):

```bash
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
```

Once your dataset has been pushed, run the DPO CLI as follows:

```bash
trl dpo --config config.yaml --output_dir your-output-dir
```

The DPO CLI is based on the `examples/scripts/dpo.py` script.
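
Putting the steps together, here is a sketch of a full DPO invocation without a config file; the dataset name below mirrors the pre-processed dataset used in `commands/run_dpo.sh` and should be replaced with the one you pushed in the previous step:

```bash
# Sketch of a DPO run on a dataset already processed into the TRL format.
# Replace the dataset name with the dataset you pushed to the Hub.
trl dpo \
    --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM \
    --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed \
    --output_dir test-dpo-cli
```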
20 changes: 20 additions & 0 deletions example_config.yaml
@@ -0,0 +1,20 @@
# This is an example configuration file for the TRL CLI. You can use it for
# SFT like this: `trl sft --config config.yaml --output_dir test-sft`
# The YAML file supports environment variables by adding an `env` field
# as below

# env:
# CUDA_VISIBLE_DEVICES: 0

model_name_or_path:
  HuggingFaceM4/tiny-random-LlamaForCausalLM
dataset_name:
  imdb
dataset_text_field:
  text
report_to:
  none
learning_rate:
  1e-4
lr_scheduler_type:
  cosine
139 changes: 63 additions & 76 deletions examples/scripts/dpo.py
@@ -1,3 +1,4 @@
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -48,76 +49,47 @@
--lora_r=16 \
--lora_alpha=16
"""
from dataclasses import dataclass, field
from typing import Dict, Optional
import logging
import os
from contextlib import nullcontext

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments

from trl import DPOTrainer, ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config


@dataclass
class ScriptArguments:
beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
max_length: int = field(default=512, metadata={"help": "max length of each sample"})
max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"})
max_target_length: int = field(
default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
)
sanity_check: bool = field(default=True, metadata={"help": "only train on 1000 samples"})
ignore_bias_buffers: bool = field(
default=False,
metadata={
"help": "debug argument for distributed training;"
"fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"})
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)

from trl.commands.cli_utils import DpoScriptArguments, init_zero_verbose, TrlParser

def extract_anthropic_prompt(prompt_and_response):
"""Extract the anthropic prompt from a prompt and response pair."""
search_term = "\n\nAssistant:"
search_term_idx = prompt_and_response.rfind(search_term)
assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
return prompt_and_response[: search_term_idx + len(search_term)]
if TRL_USE_RICH:
init_zero_verbose()
FORMAT = "%(message)s"

from rich.console import Console
from rich.logging import RichHandler

def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset:
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
The dataset is converted to a dictionary with the following structure:
{
'prompt': List[str],
'chosen': List[str],
'rejected': List[str],
}
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

Prompts should be structured as follows:
\n\nHuman: <prompt>\n\nAssistant:
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
"""
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))
from trl import (
DPOTrainer,
ModelConfig,
RichProgressCallback,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)

def split_prompt_and_responses(sample) -> Dict[str, str]:
prompt = extract_anthropic_prompt(sample["chosen"])
return {
"prompt": prompt,
"chosen": sample["chosen"][len(prompt) :],
"rejected": sample["rejected"][len(prompt) :],
}

return dataset.map(split_prompt_and_responses)
if TRL_USE_RICH:
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)


if __name__ == "__main__":
parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_into_dataclasses()
parser = TrlParser((DpoScriptArguments, TrainingArguments, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()

# Force use our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
console = Console()

################
# Model & Tokenizer
@@ -152,28 +124,43 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]

################
# Optional rich context managers
###############
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...")
save_context = (
nullcontext()
if not TRL_USE_RICH
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
)

################
# Dataset
################
train_dataset = get_hh("train", sanity_check=args.sanity_check)
eval_dataset = get_hh("test", sanity_check=args.sanity_check)
train_dataset = load_dataset(args.dataset_name, split="train")
eval_dataset = load_dataset(args.dataset_name, split="test")

################
# Training
################
trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
max_length=args.max_length,
max_target_length=args.max_target_length,
max_prompt_length=args.max_prompt_length,
generate_during_eval=args.generate_during_eval,
peft_config=get_peft_config(model_config),
)
with init_context:
trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
max_length=args.max_length,
max_target_length=args.max_target_length,
max_prompt_length=args.max_prompt_length,
generate_during_eval=args.generate_during_eval,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)

trainer.train()
trainer.save_model(training_args.output_dir)

with save_context:
trainer.save_model(training_args.output_dir)
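
With the hard-coded `get_hh` loader removed, the refactored script reads its data from `--dataset_name`. A minimal sketch of a direct invocation, mirroring the values in `commands/run_dpo.sh` above:

```bash
# Sketch: launching the refactored dpo.py directly; values mirror commands/run_dpo.sh.
accelerate launch examples/scripts/dpo.py \
    --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM \
    --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed \
    --output_dir test_dpo/ \
    --max_steps 5 \
    --per_device_train_batch_size 2
```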
