Skip to content

Commit

Permalink
Fix SFT for VLM example (#1865)
Browse files Browse the repository at this point in the history
* fix vsft example commands

* fix use_cache and get tokenizer from processor

* rm unused AutoTokenizer

* Squashed commit of the following:

commit 8bd2ab8
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 14:06:19 2024 +0200

    Refactor judges (#1856)

    * BaseJudge -> BasePairwiseJudge

    * hf judge asyncio

    * refactor judges

    * doc

    * doc

    * doc

    * memeber judge

    * :inherited-members:

    * :inherited-members:

    * doc

    * give up

    * judge tldr with judge class

    * fix rank in multithread

    * format

    * improve doc

    * update doc

    * typo doc

    * doc online dpo

    * Update judge_tldr.py

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 82b07d6
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:43:48 2024 +0200

    Llama in modelling value head tests (#1878)

commit 72bf6c2
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:33:07 2024 +0200

    Skip BigBird save and load test until next transformers version (#1874)

commit 74e54b5
Author: Edward Beeching <edbeeching@users.noreply.github.com>
Date:   Fri Jul 26 09:36:25 2024 +0200

    fix online dpo example (#1879)

commit 3930973
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:17:37 2024 +0530

    Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM (#1861)

    * Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM

    Added ```dataset_text_field``` in the SFTConfig while training

    * Update docs/source/sft_trainer.mdx

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit db8e09e
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:06:57 2024 +0530

    Import missing ```setup_chat_format``` (#1862)

commit 1dae55f
Author: elie <97572401+eliebak@users.noreply.github.com>
Date:   Thu Jul 25 10:27:34 2024 +0200

    add fsdp_qlora config and bnb_4bit_quant_storage (#1863)

commit c8cef79
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Jul 24 21:06:57 2024 +0200

    arXiv to HF Papers (#1870)

commit 7dcf437
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Wed Jul 24 12:27:50 2024 +0200

    [online-DPO] online dpo cleanups (#1864)

    * online dpo cleanups

    * remove unused self.policy

    * add OnlineDPOTrainer and config to __init__.py

    * import from trainer

    * online dpo test

    * rename policy to model and ref_policy to ref_model

    * renamed internally

    * formatting

commit 4e85bd7
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu Jul 18 14:35:31 2024 -0400

    Online DPO and Online trainer refactor (#1809)

    * online dpo trainer based on rloo trainer

    * push changes

    * refactor

    * use `batch_generation` method

    * precommit

    * remove breakpoint()

    * quick refactor

    * push the current changes

    * quick change

    * refactor

    * use the config name as the experiment name

    * fix logging

    * update online DPO docs

    * push docs

    * increment global step so tensorboard works again.

    * precommit

    * remove unused common online trainer

    * add online DPO docs

    * quick refactor

    * push changes

    * Update docs/source/online_dpo_trainer.md

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

    ---------

    Co-authored-by: Michael Noukhovitch <mnoukhov@gmail.com>
    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit c9d5636
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Thu Jul 18 18:28:49 2024 +0200

    rm token (#1852)

* add section in doc

* Squashed commit of the following:

commit 890232f
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Jul 30 14:29:47 2024 +0200

    update example overview (#1883)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 9929370
Author: Clara Pohland <54847419+claralp@users.noreply.github.com>
Date:   Sun Jul 28 21:10:08 2024 +0200

    Move BCO to separate BCOTrainer with fixes (#1869)

    * kto_trainer: skip KL data for BCO

    * kto_trainer: BCO allow no positives or no negatives in batch

    * kto_trainer: make RunningMoments object serializable

    * add BCOTrainer

    * fix BCO UDM for not interleaved data

    * kto_trainer: remove unused UDM part

    * bco_trainer: add tests and docs, minor fixes

    * code style fixes

    * Update docs/source/bco_trainer.mdx

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

    * fix BCO UDM for bfloat16

    * Update trl/trainer/bco_config.py

    * Update trl/trainer/bco_config.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/utils.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/bco_trainer.py

    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

    * Update trl/trainer/bco_config.py

    * Update _toctree.yml

    * Update trl/trainer/bco_config.py

    * Update trl/trainer/bco_trainer.py

    * RunningMoments, fix multi GPU serialization

    * fix tests

    ---------

    Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
    Co-authored-by: Seungjae Jung <seanexplode@gmail.com>

commit 6171cdd
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 15:51:38 2024 +0200

    Re-add BigBird Pegasus save/load test (#1882)

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 33d2151
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 15:07:10 2024 +0200

    Re-add BigBird Pegasus save/load test (#1876)

    * skip bigbird in ci

    * readd big bird test

    * pytest parametrize

    * dont check the version

    * rm model name

    * re add big bird

    * Merge branch 'main' into readd-bigbird-save-load-test

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 8bd2ab8
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Sun Jul 28 14:06:19 2024 +0200

    Refactor judges (#1856)

    * BaseJudge -> BasePairwiseJudge

    * hf judge asyncio

    * refactor judges

    * doc

    * doc

    * doc

    * memeber judge

    * :inherited-members:

    * :inherited-members:

    * doc

    * give up

    * judge tldr with judge class

    * fix rank in multithread

    * format

    * improve doc

    * update doc

    * typo doc

    * doc online dpo

    * Update judge_tldr.py

    ---------

    Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>

commit 82b07d6
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:43:48 2024 +0200

    Llama in modelling value head tests (#1878)

commit 72bf6c2
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Jul 26 11:33:07 2024 +0200

    Skip BigBird save and load test until next transformers version (#1874)

commit 74e54b5
Author: Edward Beeching <edbeeching@users.noreply.github.com>
Date:   Fri Jul 26 09:36:25 2024 +0200

    fix online dpo example (#1879)

commit 3930973
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:17:37 2024 +0530

    Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM (#1861)

    * Bug Fix while training using SFTTrainer with DataCollatorForCompletionOnlyLM

    Added ```dataset_text_field``` in the SFTConfig while training

    * Update docs/source/sft_trainer.mdx

    ---------

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit db8e09e
Author: Rishav Dash <57321948+Rishav-hub@users.noreply.github.com>
Date:   Thu Jul 25 14:06:57 2024 +0530

    Import missing ```setup_chat_format``` (#1862)

commit 1dae55f
Author: elie <97572401+eliebak@users.noreply.github.com>
Date:   Thu Jul 25 10:27:34 2024 +0200

    add fsdp_qlora config and bnb_4bit_quant_storage (#1863)

commit c8cef79
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Jul 24 21:06:57 2024 +0200

    arXiv to HF Papers (#1870)

commit 7dcf437
Author: Kashif Rasul <kashif.rasul@gmail.com>
Date:   Wed Jul 24 12:27:50 2024 +0200

    [online-DPO] online dpo cleanups (#1864)

    * online dpo cleanups

    * remove unused self.policy

    * add OnlineDPOTrainer and config to __init__.py

    * import from trainer

    * online dpo test

    * rename policy to model and ref_policy to ref_model

    * renamed internally

    * formatting

commit 4e85bd7
Author: Costa Huang <costa.huang@outlook.com>
Date:   Thu Jul 18 14:35:31 2024 -0400

    Online DPO and Online trainer refactor (#1809)

    * online dpo trainer based on rloo trainer

    * push changes

    * refactor

    * use `batch_generation` method

    * precommit

    * remove breakpoint()

    * quick refactor

    * push the current changes

    * quick change

    * refactor

    * use the config name as the experiment name

    * fix logging

    * update online DPO docs

    * push docs

    * increment global step so tensorboard works again.

    * precommit

    * remove unused common online trainer

    * add online DPO docs

    * quick refactor

    * push changes

    * Update docs/source/online_dpo_trainer.md

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

    ---------

    Co-authored-by: Michael Noukhovitch <mnoukhov@gmail.com>
    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit c9d5636
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Thu Jul 18 18:28:49 2024 +0200

    rm token (#1852)

* simplify script

* doc

* use traning args

* args instead of trianing args

* fix doc

* drop eval

* rm eval section

* re-add bigbirg

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
  • Loading branch information
qgallouedec and qgallouedec authored Aug 2, 2024
1 parent ddf4c8d commit df12913
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 96 deletions.
124 changes: 118 additions & 6 deletions docs/source/sft_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -533,12 +533,12 @@ Note however, that the amount of performance gain is _dataset dependent_ and in

You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:

| 1 A100 40GB | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|-----------------|-----------|-----|-------------------------|-----------------|----------------|
| Code Llama 34b | Slim Orca | 1x | 1.01x | **1.94x** | -22.7% |
| Llama-2 7b | Slim Orca | 1x | 0.96x | **1.87x** | -39.3% |
| Mistral 7b | Slim Orca | 1x | 1.17x | **1.88x** | -65.9% |
| Tiny Llama 1.1b | Alpaca | 1x | 1.55x | **2.74x** | -57.8% |
| 1 A100 40GB | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
| --------------- | --------- | --- | --------------------- | --------- | ------------ |
| Code Llama 34b | Slim Orca | 1x | 1.01x | **1.94x** | -22.7% |
| Llama-2 7b | Slim Orca | 1x | 0.96x | **1.87x** | -39.3% |
| Mistral 7b | Slim Orca | 1x | 1.17x | **1.88x** | -65.9% |
| Tiny Llama 1.1b | Alpaca | 1x | 1.55x | **2.74x** | -57.8% |

First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:

Expand Down Expand Up @@ -621,6 +621,118 @@ model = AutoModelForCausalLM.from_pretrained(

You may experience some issues with GPTQ Quantization after completing training. Lowering `gradient_accumulation_steps` to `4` will resolve most issues during the quantization process to GPTQ format.

## Extending `SFTTrainer` for Vision Language Models

`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py) which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.

### Preparing the Data

The data format is flexible, provided it is compatible with the custom collator that we will define later. A common approach is to use conversational data. Given that the data includes both text and images, the format needs to be adjusted accordingly. Below is an example of a conversational data format involving both text and images:

```python
images = ["obama.png"]
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Who is this?"},
{"type": "image"}
]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Barack Obama"}
]
},
{
"role": "user",
"content": [
{"type": "text", "text": "What is he famous for?"}
]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "He is the 44th President of the United States."}
]
}
]
```

To illustrate how this data format will be processed using the LLaVA model, you can use the following code:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(processor.apply_chat_template(messages, tokenize=False))
```

The output will be formatted as follows:

```txt
Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States.
```

<iframe src="https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft/embed/viewer/default/train" frameborder="0" width="100%" height="560px"></iframe>


### A custom collator for processing multi-modal data

Unlike the default behavior of `SFTTrainer`, processing multi-modal data is done on the fly during the data collation process. To do this, you need to define a custom collator that processes both the text and images. This collator must take a list of examples as input (see the previous section for an example of the data format) and return a batch of processed data. Below is an example of such a collator:

```python
def collate_fn(examples):
# Get the texts and images, and apply the chat template
texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
images = [example["images"][0] for example in examples]

# Tokenize the texts and process the images
batch = processor(texts, images, return_tensors="pt", padding=True)

# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels

return batch
```

We can verify that the collator works as expected by running the following code:

```python
from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")
examples = [dataset[0], dataset[1]] # Just two examples for the sake of the example
collated_data = collate_fn(examples)
print(collated_data.keys()) # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels'])
```

### Training the vision-language model

Now that we have prepared the data and defined the collator, we can proceed with training the model. To ensure that the data is not processed as text-only, we need to set a couple of arguments in the `SFTConfig`, specifically `dataset_text_field` and `remove_unused_columns`. We also need to set `skip_prepare_dataset` to `True` to avoid the default processing of the dataset. Below is an example of how to set up the `SFTTrainer`.

```python
args.dataset_text_field = "" # needs a dummy field
args.remove_unused_columns = False
args.dataset_kwargs = {"skip_prepare_dataset": True}

trainer = SFTTrainer(
model=model,
args=args,
data_collator=collate_fn,
train_dataset=train_dataset,
tokenizer=processor.tokenizer,
)
```

A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset can be found in the script [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py).

- [Experiment tracking](https://wandb.ai/huggingface/trl/runs/2b2c5l7s)
- [Trained model](https://huggingface.co/HuggingFaceH4/sft-llava-1.5-7b-hf)

## SFTTrainer

[[autodoc]] SFTTrainer
Expand Down
113 changes: 29 additions & 84 deletions examples/scripts/vsft_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,55 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
# regular:
python examples/scripts/vsft_llava.py \
--dataset_name="HuggingFaceH4/llava-instruct-mix-vsft" \
--model_name_or_path="llava-hf/llava-1.5-7b-hf" \
--report_to="wandb" \
--learning_rate=1.4e-5 \
--per_device_train_batch_size=8 \
--gradient_accumulation_steps=1 \
--output_dir="data/vsft-llava-1.5-7b-hf" \
--logging_steps=5 \
--num_train_epochs=1 \
--push_to_hub \
--gradient_checkpointing \
--remove_unused_columns=False \
--torch_dtype=float16 \
--fp16=True
# peft:
pip install pillow
python examples/scripts/vsft_llava.py \
--dataset_name="HuggingFaceH4/llava-instruct-mix-vsft" \
--model_name_or_path="llava-hf/llava-1.5-7b-hf" \
--report_to="wandb" \
--learning_rate=1.4e-5 \
--per_device_train_batch_size=8 \
--gradient_accumulation_steps=1 \
--output_dir="data/vsft-llava-1.5-7b-hf" \
--logging_steps=5 \
--num_train_epochs=1 \
--push_to_hub \
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
--model_name_or_path llava-hf/llava-1.5-7b-hf \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 8 \
--output_dir sft-llava-1.5-7b-hf \
--bf16 \
--torch_dtype bfloat16 \
--gradient_checkpointing \
--remove_unused_columns=False \
--torch_dtype=float16 \
--fp16=True \
--use_peft=True \
--lora_r=64 \
--lora_alpha=16 \
--lora_target_modules=all-linear"
# evaluation:
To evaluate, first install the lmms-eval framework: pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git
then run:
accelerate launch --num_processes=8 -m lmms_eval \
--model llava_hf \
--model_args pretrained=llava-hf/llava-1.5-7b-hf \
--tasks mmbench \
--batch_size 1 \
--output_path ./logs/ \
--log_sample
--use_peft \
--dataloader_num_workers 32 \
--lora_target_modules=all-linear
"""

import logging
Expand All @@ -85,7 +50,7 @@
from datasets import load_dataset

from tqdm.rich import tqdm
from transformers import AutoTokenizer, AutoProcessor, LlavaForConditionalGeneration
from transformers import AutoProcessor, LlavaForConditionalGeneration

from trl import (
ModelConfig,
Expand All @@ -107,6 +72,9 @@
parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
sft_script_args, training_args, model_config = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.dataset_text_field = "" # need a dummy field
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
# Force use our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
Expand All @@ -115,8 +83,6 @@
################
# Model, Tokenizer & Processor
################
LLAVA_CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""

torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
Expand All @@ -130,14 +96,9 @@
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
)
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE
processor = AutoProcessor.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
)
processor.tokenizer = tokenizer

model = LlavaForConditionalGeneration.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
Expand All @@ -146,34 +107,20 @@
################
# Create a data collator to encode text and image pairs
################
def collate_fn(examples):
# Get the texts and images, and apply the chat template
texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
images = [example["images"][0] for example in examples]

class LLavaDataCollator:
def __init__(self, processor):
self.processor = processor

def __call__(self, examples):
texts = []
images = []
for example in examples:
if len(example["images"]) > 1:
raise ValueError("This collator only supports one image per example")
messages = example["messages"]
text = self.processor.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=False
)
texts.append(text)
images.append(example["images"][0])

batch = self.processor(texts, images, return_tensors="pt", padding=True)

labels = batch["input_ids"].clone()
if self.processor.tokenizer.pad_token_id is not None:
labels[labels == self.processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels
# Tokenize the texts and process the images
batch = processor(texts, images, return_tensors="pt", padding=True)

return batch
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels

data_collator = LLavaDataCollator(processor)
return batch

################
# Dataset
Expand All @@ -199,14 +146,12 @@ def __call__(self, examples):
trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=collate_fn,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
dataset_text_field="text", # need a dummy field
tokenizer=tokenizer,
tokenizer=processor.tokenizer,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
data_collator=data_collator,
dataset_kwargs={"skip_prepare_dataset": True},
)

trainer.train()
Expand Down
4 changes: 2 additions & 2 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ def __init__(
PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset
will be preprocessed by removing the columns that are not used by the model. If none is passed,
a warning will be raised in a multi-GPU setting.
optimizer (Optional[`torch.optim.Optimizer`]):
optimizer (`Optional[torch.optim.Optimizer]`):
Optimizer used for training. If `None`, the `Adam` is used as default.
data_collator (Optional[function]):
Data collator function that is going to be used for `prepare_dataloader` method. Note this collator
is different from the one we use for training. Pass a valid `training_data_collator` instead.
num_shared_layers (Optional[int]):
Number of shared layers between the model and the reference model. If `None`, all layers are shared.
used only if `ref_model` is `None`.
lr_scheduler (Optional[`torch.optim.lr_scheduler`]):
lr_scheduler (`Optional[torch.optim.lr_scheduler]`):
Learning rate scheduler used for training.
training_data_collator (Optional[function]):
Custom data collator used for training.
Expand Down
8 changes: 4 additions & 4 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,16 @@ class SFTTrainer(Trainer):
The model to train, can be a `PreTrainedModel`, a `torch.nn.Module` or a string with the model name to
load from cache or download. The model can be also converted to a `PeftModel` if a `PeftConfig` object is
passed to the `peft_config` argument.
args (Optional[`SFTConfig`]):
args (`Optional[SFTConfig]`):
The arguments to tweak for training. Will default to a basic instance of [`SFTConfig`] with the `output_dir`
set to a directory named *tmp_trainer* in the current directory if not provided.
data_collator (Optional[`transformers.DataCollator`]):
data_collator (`Optional[transformers.DataCollator]`):
The data collator to use for training.
train_dataset (Optional[`datasets.Dataset`]):
train_dataset (`Optional[datasets.Dataset]`):
The dataset to use for training. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset.
eval_dataset (Optional[Union[`datasets.Dataset`, Dict[`str`, `datasets.Dataset`]]]):
The dataset to use for evaluation. We recommend users to use `trl.trainer.ConstantLengthDataset` to create their dataset.
tokenizer (Optional[`transformers.PreTrainedTokenizer`]):
tokenizer (`Optional[transformers.PreTrainedTokenizer]`):
The tokenizer to use for training. If not specified, the tokenizer associated to the model will be used.
model_init (`Callable[[], transformers.PreTrainedModel]`):
The model initializer to use for training. If None is specified, the default model initializer will be used.
Expand Down

0 comments on commit df12913

Please sign in to comment.