Skip to content

Commit

Permalink
[SFT] add SFT Trainer Config dataclass (#1530)
Browse files Browse the repository at this point in the history
* initial SFT Config

* remove pdb

* fix chat_template

* undo formatting

* add back removed commits

* fix the tests

* add back options to SftScriptArguments

* use sft_script_args

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/commands/cli_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* rename SFTScriptArguments and split names

* formatting docstrings

* docstring

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
  • Loading branch information
kashif and lewtun authored Apr 23, 2024
1 parent 24fd8dd commit f30daa4
Show file tree
Hide file tree
Showing 11 changed files with 519 additions and 285 deletions.
141 changes: 80 additions & 61 deletions docs/source/sft_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -12,42 +12,47 @@ The following code-snippet takes care of all the data pre-processing and trainin

```python
from datasets import load_dataset
from trl import SFTTrainer
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("imdb", split="train")

sft_config = SFTConfig(
dataset_text_field="text",
max_seq_length=512,
output_dir="/tmp",
)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
args=training_args,
)
trainer.train()
```
Make sure to pass a correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.

You can also construct a model outside of the trainer and pass it as follows:

```python
from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("imdb", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")

sft_config = SFTConfig(output_dir="/tmp")

trainer = SFTTrainer(
model,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
args=sft_config,
)

trainer.train()
```

The above snippets will use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. If you want to modify that, make sure to create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor as it is done on the [`supervised_finetuning.py` script](https://github.com/huggingface/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) on the stack-llama example.
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults pass in your modification to the `SFTConfig` constructor and pass them to the trainer via the `args` argument.

## Advanced usage

Expand All @@ -59,7 +64,7 @@ To instantiate that collator for instruction data, pass a response template and
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")

Expand All @@ -79,6 +84,7 @@ collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenize
trainer = SFTTrainer(
model,
train_dataset=dataset,
args=SFTConfig(output_dir="/tmp"),
formatting_func=formatting_prompts_func,
data_collator=collator,
)
Expand All @@ -91,7 +97,7 @@ To instantiate that collator for assistant style conversation data, pass a respo
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

Expand All @@ -104,8 +110,8 @@ collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_temp

trainer = SFTTrainer(
model,
args=SFTConfig(output_dir="/tmp"),
train_dataset=dataset,
dataset_text_field="text",
data_collator=collator,
)

Expand All @@ -116,7 +122,7 @@ Make sure to have a `pad_token_id` which is different from `eos_token_id` which

#### Using token_ids directly for `response_template`

Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending whether they have context or not. For example:
Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending on whether they have context or not. For example:

```python
from transformers import AutoTokenizer
Expand Down Expand Up @@ -146,7 +152,7 @@ RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs
```


To solve this, you can tokenize the `response_template` with the same context than in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:

```python
response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer
Expand Down Expand Up @@ -199,7 +205,7 @@ If your dataset uses one of the above formats, you can directly pass it to the t

```python
from datasets import load_dataset
from trl import SFTTrainer
from trl import SFTConfig, SFTTrainer

...

Expand All @@ -210,15 +216,15 @@ dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")

...

sft_config = STFConfig(packing=True)
trainer = SFTTrainer(
"facebook/opt-350m",
args=training_args,
args=sft_config,
train_dataset=dataset,
packing=True,
)
```

If the dataset is not in one those format you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
If the dataset is not in one of those format you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.


### Format your input prompts
Expand Down Expand Up @@ -246,13 +252,14 @@ def formatting_prompts_func(example):

trainer = SFTTrainer(
model,
args=sft_config,
train_dataset=dataset,
formatting_func=formatting_prompts_func,
)

trainer.train()
```
To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)

### Packing dataset ([`ConstantLengthDataset`])

Expand Down Expand Up @@ -283,10 +290,11 @@ def formatting_func(example):
text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
return text

sft_config = STFConfig(packing=True)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
packing=True,
args=sft_config,
formatting_func=formatting_func
)

Expand All @@ -300,18 +308,19 @@ You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTT

```python
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
```

```python
...

sft_config = SFTConfig(
model_init_kwargs={
"torch_dtype": "bfloat16",
},
output_dir="/tmp",
)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
model_init_kwargs={
"torch_dtype": torch.bfloat16,
},
args=sft_config,
)

trainer.train()
Expand All @@ -320,11 +329,11 @@ Note that all keyword arguments of `from_pretrained()` are supported.

### Training adapters

We also support a tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model
We also support tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model

```python
from datasets import load_dataset
from trl import SFTTrainer
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig

dataset = load_dataset("imdb", split="train")
Expand All @@ -340,7 +349,7 @@ peft_config = LoraConfig(
trainer = SFTTrainer(
"EleutherAI/gpt-neo-125m",
train_dataset=dataset,
dataset_text_field="text",
args=SFTConfig(output_dir="/tmp"),
peft_config=peft_config
)

Expand All @@ -351,7 +360,7 @@ You can also continue training your `PeftModel`. For that, first load a `PeftMod

### Training adapters with base 8 bit models

For that you need to first load your 8bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:
For that, you need to first load your 8 bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:

```python
...
Expand All @@ -373,7 +382,7 @@ model = AutoModelForCausalLM.from_pretrained(
trainer = SFTTrainer(
model,
train_dataset=dataset,
dataset_text_field="text",
args=STFConfig(),
peft_config=peft_config,
)

Expand Down Expand Up @@ -441,7 +450,7 @@ model = AutoModelForCausalLM.from_pretrained(
If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device.
After loading your model, you can either train it as it is, or attach adapters and train adapters on it in case your model is quantized.

In contrary to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
In contrast to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.


### Using model creation utility
Expand Down Expand Up @@ -479,10 +488,7 @@ trainer = SFTTrainer(
)
```




### Enhance model's performances using NEFTune
### Enhance the model's performances using NEFTune

NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://arxiv.org/abs/2310.05914) from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper:

Expand All @@ -492,20 +498,21 @@ NEFTune is a technique to boost the performance of chat models and was introduce
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/neft-screenshot.png">
</div>

To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTTrainer` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer.
To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer.

```python
from datasets import load_dataset
from trl import SFTTrainer
from trl import STFConfig, SFTTrainer

dataset = load_dataset("imdb", split="train")

sft_config = STFConfig(
neftune_noise_alpha=5,
)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=512,
neftune_noise_alpha=5,
args=sft_config,
)
trainer.train()
```
Expand Down Expand Up @@ -533,42 +540,50 @@ First install `unsloth` according to the [official documentation](https://github

```python
import torch
from transformers import TrainingArguments
from trl import SFTTrainer
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel

max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number

# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "unsloth/mistral-7b",
max_seq_length = max_seq_length,
dtype = None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True, # Use 4bit quantization to reduce memory usage. Can be False
model_name="unsloth/mistral-7b",
max_seq_length=max_seq_length,
dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
model,
r = 16,
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",],
lora_alpha = 16,
lora_dropout = 0, # Dropout = 0 is currently optimized
bias = "none", # Bias = "none" is currently optimized
use_gradient_checkpointing = True,
random_state = 3407,
r=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=16,
lora_dropout=0, # Dropout = 0 is currently optimized
bias="none", # Bias = "none" is currently optimized
use_gradient_checkpointing=True,
random_state=3407,
)

args = TrainingArguments(output_dir = "./output")
args = SFTConfig(
output_dir="./output",
max_seq_length=max_seq_length,
dataset_text_field="text",
)

trainer = SFTTrainer(
model = model,
args = args,
train_dataset = dataset,
dataset_text_field = "text",
max_seq_length = max_seq_length,
model=model,
args=args,
train_dataset=dataset,
)
trainer.train()
```
Expand All @@ -579,7 +594,7 @@ The saved model is fully compatible with Hugging Face's transformers library. Le

Pay attention to the following best practices when training a model with that trainer:

- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
Expand All @@ -606,6 +621,10 @@ You may experience some issues with GPTQ Quantization after completing training.

[[autodoc]] SFTTrainer

## SFTConfig

[[autodoc]] SFTConfig

## Datasets

In the SFTTrainer we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
Expand Down
Loading

0 comments on commit f30daa4

Please sign in to comment.