diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx
index 6128704ddf..85cef8922e 100644
--- a/docs/source/sft_trainer.mdx
+++ b/docs/source/sft_trainer.mdx
@@ -12,42 +12,47 @@ The following code-snippet takes care of all the data pre-processing and trainin
```python
from datasets import load_dataset
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer
dataset = load_dataset("imdb", split="train")
+sft_config = SFTConfig(
+ dataset_text_field="text",
+ max_seq_length=512,
+ output_dir="/tmp",
+)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
- dataset_text_field="text",
- max_seq_length=512,
+ args=training_args,
)
trainer.train()
```
-Make sure to pass a correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
+Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
You can also construct a model outside of the trainer and pass it as follows:
```python
from transformers import AutoModelForCausalLM
from datasets import load_dataset
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer
dataset = load_dataset("imdb", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
+sft_config = SFTConfig(output_dir="/tmp")
+
trainer = SFTTrainer(
model,
train_dataset=dataset,
- dataset_text_field="text",
- max_seq_length=512,
+ args=sft_config,
)
trainer.train()
```
-The above snippets will use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. If you want to modify that, make sure to create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor as it is done on the [`supervised_finetuning.py` script](https://github.com/huggingface/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) on the stack-llama example.
+The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults pass in your modification to the `SFTConfig` constructor and pass them to the trainer via the `args` argument.
## Advanced usage
@@ -59,7 +64,7 @@ To instantiate that collator for instruction data, pass a response template and
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
-from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
+from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")
@@ -79,6 +84,7 @@ collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenize
trainer = SFTTrainer(
model,
train_dataset=dataset,
+ args=SFTConfig(output_dir="/tmp"),
formatting_func=formatting_prompts_func,
data_collator=collator,
)
@@ -91,7 +97,7 @@ To instantiate that collator for assistant style conversation data, pass a respo
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
-from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
+from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
@@ -104,8 +110,8 @@ collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_temp
trainer = SFTTrainer(
model,
+ args=SFTConfig(output_dir="/tmp"),
train_dataset=dataset,
- dataset_text_field="text",
data_collator=collator,
)
@@ -116,7 +122,7 @@ Make sure to have a `pad_token_id` which is different from `eos_token_id` which
#### Using token_ids directly for `response_template`
-Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending whether they have context or not. For example:
+Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending on whether they have context or not. For example:
```python
from transformers import AutoTokenizer
@@ -146,7 +152,7 @@ RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs
```
-To solve this, you can tokenize the `response_template` with the same context than in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
+To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
```python
response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer
@@ -199,7 +205,7 @@ If your dataset uses one of the above formats, you can directly pass it to the t
```python
from datasets import load_dataset
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer
...
@@ -210,15 +216,15 @@ dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
...
+sft_config = STFConfig(packing=True)
trainer = SFTTrainer(
"facebook/opt-350m",
- args=training_args,
+ args=sft_config,
train_dataset=dataset,
- packing=True,
)
```
-If the dataset is not in one those format you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
+If the dataset is not in one of those format you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
### Format your input prompts
@@ -246,13 +252,14 @@ def formatting_prompts_func(example):
trainer = SFTTrainer(
model,
+ args=sft_config,
train_dataset=dataset,
formatting_func=formatting_prompts_func,
)
trainer.train()
```
-To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
+To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
### Packing dataset ([`ConstantLengthDataset`])
@@ -283,10 +290,11 @@ def formatting_func(example):
text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
return text
+sft_config = STFConfig(packing=True)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
- packing=True,
+ args=sft_config,
formatting_func=formatting_func
)
@@ -300,18 +308,19 @@ You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTT
```python
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
-```
-```python
...
+sft_config = SFTConfig(
+ model_init_kwargs={
+ "torch_dtype": "bfloat16",
+ },
+ output_dir="/tmp",
+)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
- dataset_text_field="text",
- model_init_kwargs={
- "torch_dtype": torch.bfloat16,
- },
+ args=sft_config,
)
trainer.train()
@@ -320,11 +329,11 @@ Note that all keyword arguments of `from_pretrained()` are supported.
### Training adapters
-We also support a tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model
+We also support tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model
```python
from datasets import load_dataset
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer
from peft import LoraConfig
dataset = load_dataset("imdb", split="train")
@@ -340,7 +349,7 @@ peft_config = LoraConfig(
trainer = SFTTrainer(
"EleutherAI/gpt-neo-125m",
train_dataset=dataset,
- dataset_text_field="text",
+ args=SFTConfig(output_dir="/tmp"),
peft_config=peft_config
)
@@ -351,7 +360,7 @@ You can also continue training your `PeftModel`. For that, first load a `PeftMod
### Training adapters with base 8 bit models
-For that you need to first load your 8bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:
+For that, you need to first load your 8 bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:
```python
...
@@ -373,7 +382,7 @@ model = AutoModelForCausalLM.from_pretrained(
trainer = SFTTrainer(
model,
train_dataset=dataset,
- dataset_text_field="text",
+ args=STFConfig(),
peft_config=peft_config,
)
@@ -441,7 +450,7 @@ model = AutoModelForCausalLM.from_pretrained(
If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device.
After loading your model, you can either train it as it is, or attach adapters and train adapters on it in case your model is quantized.
-In contrary to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
+In contrast to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
### Using model creation utility
@@ -479,10 +488,7 @@ trainer = SFTTrainer(
)
```
-
-
-
-### Enhance model's performances using NEFTune
+### Enhance the model's performances using NEFTune
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://arxiv.org/abs/2310.05914) from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
@@ -492,20 +498,21 @@ NEFTune is a technique to boost the performance of chat models and was introduce
-To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTTrainer` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer.
+To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer.
```python
from datasets import load_dataset
-from trl import SFTTrainer
+from trl import STFConfig, SFTTrainer
dataset = load_dataset("imdb", split="train")
+sft_config = STFConfig(
+ neftune_noise_alpha=5,
+)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
- dataset_text_field="text",
- max_seq_length=512,
- neftune_noise_alpha=5,
+ args=sft_config,
)
trainer.train()
```
@@ -533,42 +540,50 @@ First install `unsloth` according to the [official documentation](https://github
```python
import torch
-from transformers import TrainingArguments
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number
# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
- model_name = "unsloth/mistral-7b",
- max_seq_length = max_seq_length,
- dtype = None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
- load_in_4bit = True, # Use 4bit quantization to reduce memory usage. Can be False
+ model_name="unsloth/mistral-7b",
+ max_seq_length=max_seq_length,
+ dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
+ load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
model,
- r = 16,
- target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
- "gate_proj", "up_proj", "down_proj",],
- lora_alpha = 16,
- lora_dropout = 0, # Dropout = 0 is currently optimized
- bias = "none", # Bias = "none" is currently optimized
- use_gradient_checkpointing = True,
- random_state = 3407,
+ r=16,
+ target_modules=[
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ "o_proj",
+ "gate_proj",
+ "up_proj",
+ "down_proj",
+ ],
+ lora_alpha=16,
+ lora_dropout=0, # Dropout = 0 is currently optimized
+ bias="none", # Bias = "none" is currently optimized
+ use_gradient_checkpointing=True,
+ random_state=3407,
)
-args = TrainingArguments(output_dir = "./output")
+args = SFTConfig(
+ output_dir="./output",
+ max_seq_length=max_seq_length,
+ dataset_text_field="text",
+)
trainer = SFTTrainer(
- model = model,
- args = args,
- train_dataset = dataset,
- dataset_text_field = "text",
- max_seq_length = max_seq_length,
+ model=model,
+ args=args,
+ train_dataset=dataset,
)
trainer.train()
```
@@ -579,7 +594,7 @@ The saved model is fully compatible with Hugging Face's transformers library. Le
Pay attention to the following best practices when training a model with that trainer:
-- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
+- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
@@ -606,6 +621,10 @@ You may experience some issues with GPTQ Quantization after completing training.
[[autodoc]] SFTTrainer
+## SFTConfig
+
+[[autodoc]] SFTConfig
+
## Datasets
In the SFTTrainer we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py
index 6fa9a27f18..80ad65d96b 100644
--- a/examples/scripts/sft.py
+++ b/examples/scripts/sft.py
@@ -25,7 +25,7 @@
--num_train_epochs=3 \
--max_steps=-1 \
--push_to_hub \
- --gradient_checkpointing \
+ --gradient_checkpointing
# peft:
python examples/scripts/sft.py \
@@ -44,13 +44,14 @@
--lora_r=64 \
--lora_alpha=16
"""
+
import logging
import os
from contextlib import nullcontext
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)
-from trl.commands.cli_utils import init_zero_verbose, SftScriptArguments, TrlParser
+from trl.commands.cli_utils import init_zero_verbose, SFTScriptArguments, TrlParser
if TRL_USE_RICH:
init_zero_verbose()
@@ -63,11 +64,12 @@
from datasets import load_dataset
from tqdm.rich import tqdm
-from transformers import AutoTokenizer, TrainingArguments
+from transformers import AutoTokenizer
from trl import (
ModelConfig,
RichProgressCallback,
+ SFTConfig,
SFTTrainer,
get_peft_config,
get_quantization_config,
@@ -81,7 +83,7 @@
if __name__ == "__main__":
- parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig))
+ parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()
# Force use our print callback
@@ -115,8 +117,8 @@
################
raw_datasets = load_dataset(args.dataset_name)
- train_dataset = raw_datasets[args.dataset_train_name]
- eval_dataset = raw_datasets[args.dataset_test_name]
+ train_dataset = raw_datasets[args.dataset_train_split]
+ eval_dataset = raw_datasets[args.dataset_test_split]
################
# Optional rich context managers
@@ -138,10 +140,7 @@
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
- dataset_text_field=args.dataset_text_field,
- max_seq_length=args.max_seq_length,
tokenizer=tokenizer,
- packing=args.packing,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)
diff --git a/examples/scripts/vsft_llava.py b/examples/scripts/vsft_llava.py
index 64ef239d61..85cb98d5f3 100644
--- a/examples/scripts/vsft_llava.py
+++ b/examples/scripts/vsft_llava.py
@@ -14,7 +14,8 @@
# limitations under the License.
"""
# regular:
-python examples/scripts/vsft.py \
+python examples/scripts/vsft_llava.py \
+ --dataset_name="HuggingFaceH4/llava-instruct-mix-vsft" \
--model_name_or_path="llava-hf/llava-1.5-7b-hf" \
--report_to="wandb" \
--learning_rate=1.4e-5 \
@@ -27,11 +28,11 @@
--gradient_checkpointing \
--remove_unused_columns=False \
--torch_dtype=float16 \
- --fp16=True \
- --dataset_name=HuggingFaceH4/llava-instruct-mix-vsft \
+ --fp16=True
# peft:
-python examples/scripts/vsft.py \
+python examples/scripts/vsft_llava.py \
+ --dataset_name="HuggingFaceH4/llava-instruct-mix-vsft" \
--model_name_or_path="llava-hf/llava-1.5-7b-hf" \
--report_to="wandb" \
--learning_rate=1.4e-5 \
@@ -45,7 +46,6 @@
--remove_unused_columns=False \
--torch_dtype=float16 \
--fp16=True \
- --dataset_name=HuggingFaceH4/llava-instruct-mix-vsft \
--use_peft=True \
--lora_r=64 \
--lora_alpha=16 \
@@ -63,13 +63,14 @@
--output_path ./logs/ \
--log_sample
"""
+
import logging
import os
from contextlib import nullcontext
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)
-from trl.commands.cli_utils import init_zero_verbose, SftScriptArguments, TrlParser
+from trl.commands.cli_utils import init_zero_verbose, SFTScriptArguments, TrlParser
if TRL_USE_RICH:
init_zero_verbose()
@@ -83,11 +84,12 @@
from datasets import load_dataset
from tqdm.rich import tqdm
-from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration
+from transformers import AutoTokenizer, AutoProcessor, LlavaForConditionalGeneration
from trl import (
ModelConfig,
RichProgressCallback,
+ SFTConfig,
SFTTrainer,
get_peft_config,
get_quantization_config,
@@ -101,8 +103,8 @@
if __name__ == "__main__":
- parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig))
- args, training_args, model_config = parser.parse_args_and_config()
+ parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
+ sft_script_args, training_args, model_config = parser.parse_args_and_config()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
# Force use our print callback
if TRL_USE_RICH:
@@ -170,9 +172,9 @@ def __call__(self, examples):
################
# Dataset
################
- raw_datasets = load_dataset(args.dataset_name)
- train_dataset = raw_datasets["train"]
- eval_dataset = raw_datasets["test"]
+ raw_datasets = load_dataset(sft_script_args.dataset_name)
+ train_dataset = raw_datasets[sft_script_args.dataset_train_split]
+ eval_dataset = raw_datasets[sft_script_args.dataset_test_split]
################
# Optional rich context managers
diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py
index 8d455fcab7..300204e4b4 100644
--- a/tests/slow/test_sft_slow.py
+++ b/tests/slow/test_sft_slow.py
@@ -20,9 +20,9 @@
from accelerate.utils.memory import release_memory
from datasets import load_dataset
from parameterized import parameterized
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
-from trl import SFTTrainer, is_peft_available
+from trl import SFTConfig, SFTTrainer, is_peft_available
from trl.models.utils import setup_chat_format
from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu, require_torch_multi_gpu
@@ -61,12 +61,15 @@ def test_sft_trainer_str(self, model_name, packing):
as expected.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
- args = TrainingArguments(
+ args = SFTConfig(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
+ packing=packing,
+ dataset_text_field=self.dataset_text_field,
+ max_seq_length=self.max_seq_length,
)
trainer = SFTTrainer(
@@ -74,9 +77,6 @@ def test_sft_trainer_str(self, model_name, packing):
args=args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=packing,
- dataset_text_field=self.dataset_text_field,
- max_seq_length=self.max_seq_length,
)
trainer.train()
@@ -88,12 +88,15 @@ def test_sft_trainer_transformers(self, model_name, packing):
as expected.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
- args = TrainingArguments(
+ args = SFTConfig(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
+ packing=packing,
+ dataset_text_field=self.dataset_text_field,
+ max_seq_length=self.max_seq_length,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -105,9 +108,6 @@ def test_sft_trainer_transformers(self, model_name, packing):
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=packing,
- dataset_text_field=self.dataset_text_field,
- max_seq_length=self.max_seq_length,
)
trainer.train()
@@ -122,13 +122,16 @@ def test_sft_trainer_peft(self, model_name, packing):
as expected.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
- args = TrainingArguments(
+ args = SFTConfig(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
fp16=True,
+ packing=packing,
+ dataset_text_field=self.dataset_text_field,
+ max_seq_length=self.max_seq_length,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -140,9 +143,6 @@ def test_sft_trainer_peft(self, model_name, packing):
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=packing,
- dataset_text_field=self.dataset_text_field,
- max_seq_length=self.max_seq_length,
peft_config=self.peft_config,
)
@@ -159,13 +159,16 @@ def test_sft_trainer_transformers_mp(self, model_name, packing):
as expected in mixed precision.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
- args = TrainingArguments(
+ args = SFTConfig(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
fp16=True, # this is sufficient to enable amp
+ packing=packing,
+ dataset_text_field=self.dataset_text_field,
+ max_seq_length=self.max_seq_length,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -177,9 +180,6 @@ def test_sft_trainer_transformers_mp(self, model_name, packing):
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=packing,
- dataset_text_field=self.dataset_text_field,
- max_seq_length=self.max_seq_length,
)
trainer.train()
@@ -193,12 +193,15 @@ def test_sft_trainer_transformers_mp_gc(self, model_name, packing, gradient_chec
as expected in mixed precision + different scenarios of gradient_checkpointing.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
- args = TrainingArguments(
+ args = SFTConfig(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
+ packing=packing,
+ dataset_text_field=self.dataset_text_field,
+ max_seq_length=self.max_seq_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -213,9 +216,6 @@ def test_sft_trainer_transformers_mp_gc(self, model_name, packing, gradient_chec
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=packing,
- dataset_text_field=self.dataset_text_field,
- max_seq_length=self.max_seq_length,
)
trainer.train()
@@ -230,12 +230,15 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient
as expected in mixed precision + different scenarios of gradient_checkpointing.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
- args = TrainingArguments(
+ args = SFTConfig(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
+ packing=packing,
+ dataset_text_field=self.dataset_text_field,
+ max_seq_length=self.max_seq_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -250,9 +253,6 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=packing,
- dataset_text_field=self.dataset_text_field,
- max_seq_length=self.max_seq_length,
peft_config=self.peft_config,
)
@@ -274,12 +274,15 @@ def test_sft_trainer_transformers_mp_gc_device_map(
as expected in mixed precision + different scenarios of gradient_checkpointing (single, multi-gpu, etc).
"""
with tempfile.TemporaryDirectory() as tmp_dir:
- args = TrainingArguments(
+ args = SFTConfig(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
+ packing=packing,
+ dataset_text_field=self.dataset_text_field,
+ max_seq_length=self.max_seq_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -294,9 +297,6 @@ def test_sft_trainer_transformers_mp_gc_device_map(
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=packing,
- dataset_text_field=self.dataset_text_field,
- max_seq_length=self.max_seq_length,
)
trainer.train()
@@ -312,12 +312,15 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
as expected in mixed precision + different scenarios of gradient_checkpointing.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
- args = TrainingArguments(
+ args = SFTConfig(
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
per_device_train_batch_size=2,
max_steps=10,
+ packing=packing,
+ dataset_text_field=self.dataset_text_field,
+ max_seq_length=self.max_seq_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -334,9 +337,6 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
tokenizer=tokenizer,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=packing,
- dataset_text_field=self.dataset_text_field,
- max_seq_length=self.max_seq_length,
peft_config=self.peft_config,
)
@@ -357,7 +357,9 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
with tempfile.TemporaryDirectory() as tmp_dir:
train_dataset = load_dataset("trl-internal-testing/dolly-chatml-sft", split="train")
- args = TrainingArguments(
+ args = SFTConfig(
+ packing=packing,
+ max_seq_length=self.max_seq_length,
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
@@ -378,8 +380,6 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
args=args,
tokenizer=tokenizer,
train_dataset=train_dataset,
- packing=packing,
- max_seq_length=self.max_seq_length,
peft_config=self.peft_config,
)
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 225621f2d7..379f187e4f 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -25,10 +25,9 @@
AutoProcessor,
AutoTokenizer,
LlavaForConditionalGeneration,
- TrainingArguments,
)
-from trl import SFTTrainer
+from trl import SFTConfig, SFTTrainer
from trl.import_utils import is_peft_available, is_pil_available
from trl.trainer import ConstantLengthDataset, DataCollatorForCompletionOnlyLM
@@ -216,7 +215,7 @@ def test_constant_length_dataset(self):
def test_sft_trainer(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -224,6 +223,7 @@ def test_sft_trainer(self):
eval_steps=2,
save_steps=2,
per_device_train_batch_size=2,
+ packing=True,
)
trainer = SFTTrainer(
@@ -231,7 +231,6 @@ def test_sft_trainer(self):
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=True,
)
trainer.train()
@@ -243,7 +242,7 @@ def test_sft_trainer(self):
def test_sft_trainer_uncorrect_data(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -251,6 +250,7 @@ def test_sft_trainer_uncorrect_data(self):
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
+ packing=True,
)
with pytest.raises(ValueError):
@@ -258,13 +258,16 @@ def test_sft_trainer_uncorrect_data(self):
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
- packing=True,
)
# this should work since the dummy chatml include the correct format
- _ = SFTTrainer(
- model=self.model,
- args=training_args,
- train_dataset=self.dummy_chatml_dataset,
+ training_args = SFTConfig(
+ output_dir=tmp_dir,
+ dataloader_drop_last=True,
+ evaluation_strategy="steps",
+ max_steps=2,
+ eval_steps=1,
+ save_steps=1,
+ per_device_train_batch_size=2,
max_seq_length=32, # make sure there is at least 1 packed sequence
num_of_sequences=32,
packing=True,
@@ -273,13 +276,33 @@ def test_sft_trainer_uncorrect_data(self):
model=self.model,
args=training_args,
train_dataset=self.dummy_chatml_dataset,
+ )
+
+ training_args = SFTConfig(
+ output_dir=tmp_dir,
+ dataloader_drop_last=True,
+ evaluation_strategy="steps",
+ max_steps=2,
+ eval_steps=1,
+ save_steps=1,
+ per_device_train_batch_size=2,
packing=False,
)
- # this should work since the dummy instruction dataset is the correct format
_ = SFTTrainer(
model=self.model,
args=training_args,
- train_dataset=self.dummy_instruction_dataset,
+ train_dataset=self.dummy_chatml_dataset,
+ )
+ # this should work since the dummy instruction dataset is the correct format
+
+ training_args = SFTConfig(
+ output_dir=tmp_dir,
+ dataloader_drop_last=True,
+ evaluation_strategy="steps",
+ max_steps=2,
+ eval_steps=1,
+ save_steps=1,
+ per_device_train_batch_size=2,
max_seq_length=16, # make sure there is at least 1 packed sequence
packing=True,
)
@@ -287,51 +310,103 @@ def test_sft_trainer_uncorrect_data(self):
model=self.model,
args=training_args,
train_dataset=self.dummy_instruction_dataset,
+ )
+
+ training_args = SFTConfig(
+ output_dir=tmp_dir,
+ dataloader_drop_last=True,
+ evaluation_strategy="steps",
+ max_steps=2,
+ eval_steps=1,
+ save_steps=1,
+ per_device_train_batch_size=2,
packing=False,
)
+ _ = SFTTrainer(
+ model=self.model,
+ args=training_args,
+ train_dataset=self.dummy_instruction_dataset,
+ )
+
+ training_args = SFTConfig(
+ output_dir=tmp_dir,
+ dataloader_drop_last=True,
+ evaluation_strategy="steps",
+ max_steps=2,
+ eval_steps=1,
+ save_steps=1,
+ per_device_train_batch_size=2,
+ max_seq_length=32, # make sure there is at least 1 packed sequence
+ packing=True,
+ )
# This should work
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
formatting_func=formatting_prompts_func,
- max_seq_length=32, # make sure there is at least 1 packed sequence
- packing=True,
)
with pytest.raises(ValueError):
# This should not work because not enough data for one sample
+ training_args = SFTConfig(
+ output_dir=tmp_dir,
+ dataloader_drop_last=True,
+ evaluation_strategy="steps",
+ max_steps=2,
+ eval_steps=1,
+ save_steps=1,
+ per_device_train_batch_size=2,
+ max_seq_length=1024, # make sure there is NOT at least 1 packed sequence
+ packing=True,
+ )
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
formatting_func=formatting_prompts_func,
- max_seq_length=1024, # make sure there is NOT at least 1 packed sequence
- packing=True,
)
# This should not work as well
with pytest.raises(ValueError):
+ training_args = SFTConfig(
+ output_dir=tmp_dir,
+ dataloader_drop_last=True,
+ evaluation_strategy="steps",
+ max_steps=2,
+ eval_steps=1,
+ save_steps=1,
+ per_device_train_batch_size=2,
+ packing=False,
+ )
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
formatting_func=formatting_prompts_func,
- packing=False,
)
# but this should work
+ training_args = SFTConfig(
+ output_dir=tmp_dir,
+ dataloader_drop_last=True,
+ evaluation_strategy="steps",
+ max_steps=2,
+ eval_steps=1,
+ save_steps=1,
+ per_device_train_batch_size=2,
+ packing=False,
+ )
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
formatting_func=formatting_prompts_func_batched,
- packing=False,
)
def test_sft_trainer_with_model_num_train_epochs(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -340,6 +415,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):
save_steps=1,
num_train_epochs=2,
per_device_train_batch_size=2,
+ packing=True,
)
trainer = SFTTrainer(
@@ -347,7 +423,6 @@ def test_sft_trainer_with_model_num_train_epochs(self):
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=True,
)
trainer.train()
@@ -358,7 +433,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -366,16 +441,16 @@ def test_sft_trainer_with_model_num_train_epochs(self):
save_steps=1,
num_train_epochs=2,
per_device_train_batch_size=2,
+ dataset_text_field="text",
+ max_seq_length=16,
+ num_of_sequences=16,
+ packing=True,
)
trainer = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
- dataset_text_field="text",
- max_seq_length=16,
- num_of_sequences=16,
- packing=True,
)
trainer.train()
@@ -385,7 +460,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -393,14 +468,14 @@ def test_sft_trainer_with_model_num_train_epochs(self):
save_steps=1,
num_train_epochs=2,
per_device_train_batch_size=2,
+ dataset_text_field="text",
+ max_seq_length=16,
)
trainer = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
- dataset_text_field="text",
- max_seq_length=16,
)
trainer.train()
@@ -411,7 +486,7 @@ def test_sft_trainer_with_model_num_train_epochs(self):
def test_sft_trainer_with_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -419,6 +494,7 @@ def test_sft_trainer_with_model(self):
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
+ packing=True,
)
trainer = SFTTrainer(
@@ -426,7 +502,6 @@ def test_sft_trainer_with_model(self):
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=True,
)
trainer.train()
@@ -437,23 +512,23 @@ def test_sft_trainer_with_model(self):
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
+ dataset_text_field="text",
+ max_seq_length=16,
+ num_of_sequences=16,
+ packing=True,
)
trainer = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
- dataset_text_field="text",
- max_seq_length=16,
- num_of_sequences=16,
- packing=True,
)
trainer.train()
@@ -464,13 +539,16 @@ def test_sft_trainer_with_model(self):
# with formatting_func + packed
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
+ max_seq_length=16,
+ num_of_sequences=16,
+ packing=True,
)
trainer = SFTTrainer(
@@ -478,9 +556,6 @@ def test_sft_trainer_with_model(self):
args=training_args,
train_dataset=self.dummy_dataset,
formatting_func=formatting_prompts_func,
- max_seq_length=16,
- num_of_sequences=16,
- packing=True,
)
trainer.train()
@@ -491,13 +566,14 @@ def test_sft_trainer_with_model(self):
# with formatting_func + packed
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
+ max_seq_length=16,
)
trainer = SFTTrainer(
@@ -505,7 +581,6 @@ def test_sft_trainer_with_model(self):
args=training_args,
train_dataset=self.dummy_dataset,
formatting_func=formatting_prompts_func_batched,
- max_seq_length=16,
)
trainer.train()
@@ -515,21 +590,21 @@ def test_sft_trainer_with_model(self):
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
+ dataset_text_field="text",
+ max_seq_length=16,
)
trainer = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
- dataset_text_field="text",
- max_seq_length=16,
)
trainer.train()
@@ -540,7 +615,7 @@ def test_sft_trainer_with_model(self):
def test_sft_trainer_with_multiple_eval_datasets(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -548,6 +623,7 @@ def test_sft_trainer_with_multiple_eval_datasets(self):
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
+ packing=True,
)
trainer = SFTTrainer(
@@ -558,7 +634,6 @@ def test_sft_trainer_with_multiple_eval_datasets(self):
"data1": self.eval_dataset,
"data2": self.eval_dataset,
},
- packing=True,
)
trainer.train()
@@ -662,7 +737,7 @@ def test_data_collator_chat_completion_lm_with_multiple_text(self):
def test_sft_trainer_infinite_with_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -670,6 +745,8 @@ def test_sft_trainer_infinite_with_model(self):
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
+ packing=True,
+ max_seq_length=500,
)
trainer = SFTTrainer(
@@ -677,8 +754,6 @@ def test_sft_trainer_infinite_with_model(self):
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=True,
- max_seq_length=500,
)
assert trainer.train_dataset.infinite
@@ -693,12 +768,14 @@ def test_sft_trainer_infinite_with_model(self):
def test_sft_trainer_infinite_with_model_epochs(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
num_train_epochs=1,
per_device_train_batch_size=2,
save_strategy="epoch",
+ packing=True,
+ max_seq_length=500,
)
trainer = SFTTrainer(
@@ -706,8 +783,6 @@ def test_sft_trainer_infinite_with_model_epochs(self):
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=True,
- max_seq_length=500,
)
assert not trainer.train_dataset.infinite
@@ -721,7 +796,7 @@ def test_sft_trainer_infinite_with_model_epochs(self):
def test_sft_trainer_with_model_neftune(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -729,6 +804,8 @@ def test_sft_trainer_with_model_neftune(self):
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
+ neftune_noise_alpha=5,
+ packing=True,
)
trainer = SFTTrainer(
@@ -736,8 +813,6 @@ def test_sft_trainer_with_model_neftune(self):
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- neftune_noise_alpha=5,
- packing=True,
)
trainer.model = trainer._trl_activate_neftune(trainer.model)
@@ -764,27 +839,29 @@ def test_sft_trainer_with_model_neftune(self):
@require_peft
def test_peft_sft_trainer_str(self):
- peft_config = LoraConfig(
- r=16,
- lora_alpha=32,
- lora_dropout=0.05,
- bias="none",
- task_type="CAUSAL_LM",
- )
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ peft_config = LoraConfig(
+ r=16,
+ lora_alpha=32,
+ lora_dropout=0.05,
+ bias="none",
+ task_type="CAUSAL_LM",
+ )
- _ = SFTTrainer(
- model=self.model_id,
- args=None,
- train_dataset=self.train_dataset,
- eval_dataset=self.eval_dataset,
- peft_config=peft_config,
- packing=True,
- )
+ training_args = SFTConfig(packing=True, output_dir=tmp_dir)
+
+ _ = SFTTrainer(
+ model=self.model_id,
+ args=training_args,
+ train_dataset=self.train_dataset,
+ eval_dataset=self.eval_dataset,
+ peft_config=peft_config,
+ )
@require_peft
def test_peft_sft_trainer(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -792,6 +869,7 @@ def test_peft_sft_trainer(self):
eval_steps=2,
save_steps=2,
per_device_train_batch_size=2,
+ packing=True,
)
peft_config = LoraConfig(
@@ -808,7 +886,6 @@ def test_peft_sft_trainer(self):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
peft_config=peft_config,
- packing=True,
)
assert isinstance(trainer.model, PeftModel)
@@ -825,7 +902,7 @@ def test_peft_sft_trainer(self):
@require_peft
def test_peft_sft_trainer_gc(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -834,6 +911,7 @@ def test_peft_sft_trainer_gc(self):
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
+ packing=True,
)
peft_config = LoraConfig(
@@ -850,7 +928,6 @@ def test_peft_sft_trainer_gc(self):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
peft_config=peft_config,
- packing=True,
)
assert isinstance(trainer.model, PeftModel)
@@ -867,7 +944,7 @@ def test_peft_sft_trainer_gc(self):
@require_peft
def test_peft_sft_trainer_neftune(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -875,6 +952,8 @@ def test_peft_sft_trainer_neftune(self):
eval_steps=2,
save_steps=2,
per_device_train_batch_size=2,
+ neftune_noise_alpha=5,
+ packing=True,
)
peft_config = LoraConfig(
@@ -891,8 +970,6 @@ def test_peft_sft_trainer_neftune(self):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
peft_config=peft_config,
- neftune_noise_alpha=5,
- packing=True,
)
trainer.model = trainer._trl_activate_neftune(trainer.model)
@@ -929,7 +1006,7 @@ def test_peft_sft_trainer_neftune(self):
@require_peft
def test_peft_sft_trainer_tag(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -938,6 +1015,7 @@ def test_peft_sft_trainer_tag(self):
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
+ packing=True,
)
peft_config = LoraConfig(
@@ -954,7 +1032,6 @@ def test_peft_sft_trainer_tag(self):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
peft_config=peft_config,
- packing=True,
)
assert trainer.model.model_tags == trainer._tag_names
@@ -962,7 +1039,7 @@ def test_peft_sft_trainer_tag(self):
@require_peft
def test_sft_trainer_tag(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -971,6 +1048,7 @@ def test_sft_trainer_tag(self):
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
+ packing=True,
)
trainer = SFTTrainer(
@@ -978,14 +1056,13 @@ def test_sft_trainer_tag(self):
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
- packing=True,
)
assert trainer.model.model_tags == trainer._tag_names
def test_sft_trainer_eval_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -994,6 +1071,9 @@ def test_sft_trainer_eval_packing(self):
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
+ packing=True,
+ max_seq_length=32, # make sure there is at least 1 packed sequence
+ eval_packing=False,
)
trainer = SFTTrainer(
@@ -1001,33 +1081,50 @@ def test_sft_trainer_eval_packing(self):
args=training_args,
train_dataset=self.dummy_chatml_dataset,
eval_dataset=self.dummy_chatml_dataset,
- packing=True,
- max_seq_length=32, # make sure there is at least 1 packed sequence
- eval_packing=False,
)
assert len(trainer.train_dataset["input_ids"]) == 1
assert len(trainer.eval_dataset["input_ids"]) != 1
+ training_args = SFTConfig(
+ output_dir=tmp_dir,
+ dataloader_drop_last=True,
+ evaluation_strategy="steps",
+ max_steps=4,
+ eval_steps=2,
+ save_steps=2,
+ per_device_train_batch_size=2,
+ gradient_checkpointing=True,
+ max_seq_length=32, # make sure there is at least 1 packed sequence
+ packing=True,
+ )
trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.dummy_chatml_dataset,
eval_dataset=self.dummy_chatml_dataset,
- max_seq_length=32, # make sure there is at least 1 packed sequence
- packing=True,
)
assert len(trainer.train_dataset["input_ids"]) == 1
assert len(trainer.eval_dataset["input_ids"]) == 1
+ training_args = SFTConfig(
+ output_dir=tmp_dir,
+ dataloader_drop_last=True,
+ evaluation_strategy="steps",
+ max_steps=4,
+ eval_steps=2,
+ save_steps=2,
+ per_device_train_batch_size=2,
+ gradient_checkpointing=True,
+ max_seq_length=32, # make sure there is at least 1 packed sequence
+ packing=False,
+ )
trainer = SFTTrainer(
model=self.model_id,
args=training_args,
train_dataset=self.dummy_chatml_dataset,
eval_dataset=self.dummy_chatml_dataset,
- max_seq_length=32, # make sure there is at least 1 packed sequence
- packing=False,
)
assert len(trainer.train_dataset["input_ids"]) != 1
@@ -1036,7 +1133,7 @@ def test_sft_trainer_eval_packing(self):
@requires_pil
def test_sft_trainer_skip_prepare_dataset(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -1046,6 +1143,8 @@ def test_sft_trainer_skip_prepare_dataset(self):
per_device_train_batch_size=2,
gradient_checkpointing=True,
remove_unused_columns=False,
+ dataset_text_field="text", # need a dummy field
+ dataset_kwargs={"skip_prepare_dataset": True},
)
trainer = SFTTrainer(
@@ -1053,8 +1152,6 @@ def test_sft_trainer_skip_prepare_dataset(self):
args=training_args,
train_dataset=self.dummy_vsft_instruction_dataset,
eval_dataset=self.dummy_vsft_instruction_dataset,
- dataset_text_field="text", # need a dummy field
- dataset_kwargs={"skip_prepare_dataset": True},
)
assert trainer.train_dataset.features == self.dummy_vsft_instruction_dataset.features
assert trainer.eval_dataset.features == self.dummy_vsft_instruction_dataset.features
@@ -1062,7 +1159,7 @@ def test_sft_trainer_skip_prepare_dataset(self):
@requires_pil
def test_sft_trainer_llava(self):
with tempfile.TemporaryDirectory() as tmp_dir:
- training_args = TrainingArguments(
+ training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
evaluation_strategy="steps",
@@ -1072,13 +1169,15 @@ def test_sft_trainer_llava(self):
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
remove_unused_columns=False,
+ dataset_text_field="text", # need a dummy field
+ dataset_kwargs={"skip_prepare_dataset": True},
)
tiny_llava = LlavaForConditionalGeneration.from_pretrained(
"trl-internal-testing/tiny-random-LlavaForConditionalGeneration"
)
processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-random-LlavaForConditionalGeneration")
- processor.tokenizer.chat_template = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""
+ processor.tokenizer.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
class LLavaDataCollator:
def __init__(self, processor):
@@ -1113,8 +1212,6 @@ def __call__(self, examples):
args=training_args,
train_dataset=self.dummy_vsft_instruction_dataset,
eval_dataset=self.dummy_vsft_instruction_dataset,
- dataset_text_field="text", # need a dummy field
- dataset_kwargs={"skip_prepare_dataset": True},
data_collator=data_collator,
)
diff --git a/trl/__init__.py b/trl/__init__.py
index 06284b4b12..6b33eca27f 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -49,10 +49,11 @@
"PPOTrainer",
"RewardConfig",
"RewardTrainer",
+ "SFTConfig",
"SFTTrainer",
],
"commands": [],
- "commands.cli_utils": ["init_zero_verbose", "SftScriptArguments", "DPOScriptArguments", "TrlParser"],
+ "commands.cli_utils": ["init_zero_verbose", "SFTScriptArguments", "DPOScriptArguments", "TrlParser"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "RichProgressCallback"],
"multitask_prompt_tuning": [
"MultitaskPromptEmbedding",
@@ -114,10 +115,11 @@
PPOTrainer,
RewardConfig,
RewardTrainer,
+ SFTConfig,
SFTTrainer,
)
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config, RichProgressCallback
- from .commands.cli_utils import init_zero_verbose, SftScriptArguments, DPOScriptArguments, TrlParser
+ from .commands.cli_utils import init_zero_verbose, SFTScriptArguments, DPOScriptArguments, TrlParser
try:
if not is_diffusers_available():
diff --git a/trl/commands/__init__.py b/trl/commands/__init__.py
index e3955e11ee..1b88357e9a 100644
--- a/trl/commands/__init__.py
+++ b/trl/commands/__init__.py
@@ -25,7 +25,7 @@
if TYPE_CHECKING:
- from .cli_utils import SftScriptArguments, init_zero_verbose, DPOScriptArguments, TrlParser, YamlConfigParser
+ from .cli_utils import SFTScriptArguments, init_zero_verbose, DPOScriptArguments, TrlParser, YamlConfigParser
else:
import sys
diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py
index 6674b11cb0..7f421efa9c 100644
--- a/trl/commands/cli_utils.py
+++ b/trl/commands/cli_utils.py
@@ -140,13 +140,10 @@ def warning_handler(message, category, filename, lineno, file=None, line=None):
@dataclass
-class SftScriptArguments:
+class SFTScriptArguments:
dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"})
- dataset_text_field: str = field(default=None, metadata={"help": "the text field of the dataset"})
- dataset_train_name: str = field(default="train", metadata={"help": "the name of the training set of the dataset"})
- dataset_test_name: str = field(default="test", metadata={"help": "the name of the training set of the dataset"})
- max_seq_length: int = field(default=512, metadata={"help": "The maximum sequence length for SFT Trainer"})
- packing: bool = field(default=False, metadata={"help": "Whether to apply data packing or not during training"})
+ dataset_train_split: str = field(default="train", metadata={"help": "The dataset split to train on"})
+ dataset_test_split: str = field(default="test", metadata={"help": "The dataset split to evaluate on"})
config: str = field(default=None, metadata={"help": "Path to the optional config file"})
gradient_checkpointing_use_reentrant: bool = field(
default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
@@ -267,7 +264,7 @@ def post_process_dataclasses(self, dataclasses):
if dataclass_obj.__class__.__name__ == "TrainingArguments":
training_args = dataclass_obj
training_args_index = i
- elif dataclass_obj.__class__.__name__ in ("SftScriptArguments", "DPOScriptArguments"):
+ elif dataclass_obj.__class__.__name__ in ("SFTScriptArguments", "DPOScriptArguments"):
trl_args = dataclass_obj
else:
...
diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
index a3cc16048f..c291ff31ec 100644
--- a/trl/trainer/__init__.py
+++ b/trl/trainer/__init__.py
@@ -43,6 +43,7 @@
"ppo_trainer": ["PPOTrainer"],
"reward_config": ["RewardConfig"],
"reward_trainer": ["RewardTrainer", "compute_accuracy"],
+ "sft_config": ["SFTConfig"],
"sft_trainer": ["SFTTrainer"],
"base": ["BaseTrainer"],
"ddpo_config": ["DDPOConfig"],
@@ -89,6 +90,7 @@
from .ppo_trainer import PPOTrainer
from .reward_config import RewardConfig
from .reward_trainer import RewardTrainer, compute_accuracy
+ from .sft_config import SFTConfig
from .sft_trainer import SFTTrainer
try:
diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py
new file mode 100644
index 0000000000..0a3289b29b
--- /dev/null
+++ b/trl/trainer/sft_config.py
@@ -0,0 +1,65 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+from transformers import TrainingArguments
+
+
+@dataclass
+class SFTConfig(TrainingArguments):
+ r"""
+ Initialize SFTConfig.
+
+ Args:
+ dataset_text_field (`Optional[str]`):
+ The name of the text field of the dataset, in case this is passed by a user, the trainer will automatically create a
+ `ConstantLengthDataset` based on the `dataset_text_field` argument. Defaults to None.
+ packing (`Optional[bool]`):
+ Used only in case `dataset_text_field` is passed. This argument is used by the `ConstantLengthDataset` to pack the sequences
+ of the dataset. Defaults to False.
+ max_seq_length (`Optional[int]`):
+ The maximum sequence length to use for the `ConstantLengthDataset` and for automatically creating the Dataset. Defaults to min of the smaller of the `tokenizer.model_max_length` and `1024`.
+ dataset_num_proc (`Optional[int]`):
+ The number of workers to use to tokenize the data. Only used when `packing=False`. Defaults to None.
+ dataset_batch_size (`int`):
+ The number of examples to tokenize per batch. If batch_size <= 0 or batch_size == None,
+ tokenize the full dataset as a single batch. Defaults to 1000.
+ neftune_noise_alpha (`Optional[float]`):
+ If not `None`, this will activate NEFTune noise embeddings. This has been proven to drastically improve model performances for instruction
+ fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune
+ model_init_kwargs: (`Optional[Dict]`, *optional*):
+ Dict of Optional kwargs to pass when instantiating the model from a string.
+ dataset_kwargs: (`Optional[Dict]`, *optional*):
+ Dict of Optional kwargs to pass when creating packed or non-packed datasets
+ eval_packing: (`Optional[bool]`, *optional*):
+ Whether to pack the eval dataset as well. Defaults to `packing` if `None` is passed.
+ num_of_sequences (`Optional[int]`):
+ The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`.
+ chars_per_token (`Optional[float]`):
+ The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check how this is computed in the
+ stack-llama example: https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53.
+ """
+
+ dataset_text_field: Optional[str] = None
+ packing: Optional[bool] = False
+ max_seq_length: Optional[int] = None
+ dataset_num_proc: Optional[int] = None
+ dataset_batch_size: int = 1000
+ neftune_noise_alpha: Optional[float] = None
+ model_init_kwargs: Optional[Dict] = None
+ dataset_kwargs: Optional[Dict] = None
+ eval_packing: Optional[bool] = None
+ num_of_sequences: Optional[int] = 1024
+ chars_per_token: Optional[float] = 3.6
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index f27163b00a..8ad72e8dc1 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -24,6 +24,7 @@
from datasets import Dataset
from datasets.arrow_writer import SchemaInferenceError
from datasets.builder import DatasetGenerationError
+from huggingface_hub.utils._deprecation import _deprecate_arguments
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@@ -32,7 +33,6 @@
PreTrainedModel,
PreTrainedTokenizerBase,
Trainer,
- TrainingArguments,
)
from transformers.modeling_utils import unwrap_model
from transformers.trainer_callback import TrainerCallback
@@ -40,6 +40,7 @@
from ..extras.dataset_formatting import get_formatting_func_from_dataset
from ..import_utils import is_peft_available
+from .sft_config import SFTConfig
from .utils import (
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
@@ -89,45 +90,33 @@ class SFTTrainer(Trainer):
The function to use to preprocess the logits before computing the metrics.
peft_config (`Optional[PeftConfig]`):
The PeftConfig object to use to initialize the PeftModel.
- dataset_text_field (`Optional[str]`):
- The name of the text field of the dataset, in case this is passed by a user, the trainer will automatically create a
- `ConstantLengthDataset` based on the `dataset_text_field` argument.
formatting_func (`Optional[Callable]`):
The formatting function to be used for creating the `ConstantLengthDataset`.
- max_seq_length (`Optional[int]`):
- The maximum sequence length to use for the `ConstantLengthDataset` and for automatically creating the Dataset. Defaults to `512`.
- infinite (`Optional[bool]`):
- Whether to use an infinite dataset or not. Defaults to `False`.
- num_of_sequences (`Optional[int]`):
- The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`.
- chars_per_token (`Optional[float]`):
- The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check how this is computed in the
- stack-llama example: https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53.
- packing (`Optional[bool]`):
- Used only in case `dataset_text_field` is passed. This argument is used by the `ConstantLengthDataset` to pack the sequences
- of the dataset.
- dataset_num_proc (`Optional[int]`):
- The number of workers to use to tokenize the data. Only used when `packing=False`. Defaults to None.
- dataset_batch_size (`int`):
- The number of examples to tokenize per batch. If batch_size <= 0 or batch_size == None,
- tokenize the full dataset as a single batch. Defaults to 1000.
- neftune_noise_alpha (`Optional[float]`):
- If not `None`, this will activate NEFTune noise embeddings. This has been proven to drastically improve model performances for instruction
- fine-tuning. Check out the original paper here: https://arxiv.org/abs/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune
- model_init_kwargs: (`Optional[Dict]`, *optional*):
- Dict of Optional kwargs to pass when instantiating the model from a string
- dataset_kwargs: (`Optional[Dict]`, *optional*):
- Dict of Optional kwargs to pass when creating packed or non-packed datasets
- eval_packing: (`Optional[bool]`, *optional*):
- Whether to pack the eval dataset as well. Defaults to `packing` if `None` is passed.
"""
_tag_names = ["trl", "sft"]
+ @_deprecate_arguments(
+ version="1.0.0",
+ deprecated_args=[
+ "dataset_text_field",
+ "packing",
+ "max_seq_length",
+ "dataset_num_proc",
+ "dataset_batch_size",
+ "neftune_noise_alpha",
+ "model_init_kwargs",
+ "dataset_kwargs",
+ "eval_packing",
+ "num_of_sequences",
+ "chars_per_token",
+ ],
+ custom_message="Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.",
+ )
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
- args: Optional[TrainingArguments] = None,
+ args: Optional[SFTConfig] = None,
data_collator: Optional[DataCollator] = None, # type: ignore
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
@@ -152,10 +141,22 @@ def __init__(
dataset_kwargs: Optional[Dict] = None,
eval_packing: Optional[bool] = None,
):
- if model_init_kwargs is None:
+ if model_init_kwargs is not None:
+ warnings.warn(
+ "You passed `model_init_kwargs` to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
+ )
+ args.model_init_kwargs = model_init_kwargs
+ if args.model_init_kwargs is None:
model_init_kwargs = {}
elif not isinstance(model, str):
- raise ValueError("You passed model_kwargs to the SFTTrainer. But your model is already instantiated.")
+ raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.")
+ else:
+ model_init_kwargs = args.model_init_kwargs
+ model_init_kwargs["torch_dtype"] = (
+ model_init_kwargs["torch_dtype"]
+ if model_init_kwargs["torch_dtype"] in ["auto", None]
+ else getattr(torch, model_init_kwargs["torch_dtype"])
+ )
if infinite is not None:
warnings.warn(
@@ -169,7 +170,18 @@ def __init__(
)
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
- if packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM):
+ if packing:
+ warnings.warn(
+ "You passed a `packing` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
+ )
+ args.packing = packing
+ if eval_packing is not None:
+ warnings.warn(
+ "You passed a `eval_packing` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
+ )
+ args.eval_packing = eval_packing
+
+ if args.packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM):
raise ValueError(
"You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument."
)
@@ -239,7 +251,13 @@ def make_inputs_require_grad(module, input, output):
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token
- if max_seq_length is None:
+ if max_seq_length is not None:
+ warnings.warn(
+ "You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
+ )
+ args.max_seq_length = max_seq_length
+
+ if args.max_seq_length is None:
# to overcome some issues with broken tokenizers
max_seq_length = min(tokenizer.model_max_length, 1024)
@@ -247,21 +265,37 @@ def make_inputs_require_grad(module, input, output):
f"You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to {max_seq_length}"
)
- self.dataset_num_proc = dataset_num_proc
- self.dataset_batch_size = dataset_batch_size
+ if dataset_num_proc is not None:
+ warnings.warn(
+ "You passed a `dataset_num_proc` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
+ )
+ args.dataset_num_proc = dataset_num_proc
+ self.dataset_num_proc = args.dataset_num_proc
- self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha")
+ if dataset_batch_size != args.dataset_batch_size:
+ warnings.warn(
+ "You passed a `dataset_batch_size` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
+ )
+ args.dataset_batch_size = dataset_batch_size
+ self.dataset_batch_size = args.dataset_batch_size
+ self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha")
if neftune_noise_alpha is not None and self._trainer_supports_neftune:
args.neftune_noise_alpha = neftune_noise_alpha
warnings.warn(
- "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `TrainingArguments`."
+ "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
)
# self.neftune_noise_alpha is done at Trainer level
elif not self._trainer_supports_neftune:
self.neftune_noise_alpha = neftune_noise_alpha
- if formatting_func is None and dataset_text_field is None:
+ if dataset_text_field is not None:
+ warnings.warn(
+ "You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
+ )
+ args.dataset_text_field = dataset_text_field
+
+ if formatting_func is None and args.dataset_text_field is None:
# check if dataset has ChatML format or instruction format and is supported
# if not stays #None
formatting_func = get_formatting_func_from_dataset(train_dataset, tokenizer)
@@ -272,50 +306,67 @@ def make_inputs_require_grad(module, input, output):
else:
dataset_kwargs["add_special_tokens"] = False
- if not packing:
- if dataset_text_field is None and formatting_func is None:
+ if not args.packing:
+ if args.dataset_text_field is None and formatting_func is None:
raise ValueError(
- "You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
+ "You passed `packing=False` to the SFTTrainer/SFTConfig, but you didn't pass a `dataset_text_field` or `formatting_func` argument."
)
if data_collator is None:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+ if num_of_sequences != args.num_of_sequences:
+ warnings.warn(
+ "You passed a `num_of_sequences` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
+ )
+ args.num_of_sequences = num_of_sequences
+
+ if chars_per_token != args.chars_per_token:
+ warnings.warn(
+ "You passed a `chars_per_token` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
+ )
+ args.chars_per_token = chars_per_token
+
# Pre-process the datasets only once per node. The remaining processes will use the cache.
with PartialState().local_main_process_first():
- if dataset_kwargs is None:
- dataset_kwargs = {}
+ if dataset_kwargs is not None:
+ warnings.warn(
+ "You passed a `dataset_kwargs` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
+ )
+ args.dataset_kwargs = dataset_kwargs
+ if args.dataset_kwargs is None:
+ args.dataset_kwargs = {}
if train_dataset is not None:
train_dataset = self._prepare_dataset(
train_dataset,
tokenizer,
- packing,
- dataset_text_field,
- max_seq_length,
+ args.packing,
+ args.dataset_text_field,
+ args.max_seq_length,
formatting_func,
- num_of_sequences,
- chars_per_token,
+ args.num_of_sequences,
+ args.chars_per_token,
remove_unused_columns=args.remove_unused_columns if args is not None else True,
- **dataset_kwargs,
+ **args.dataset_kwargs,
)
if eval_dataset is not None:
_multiple = isinstance(eval_dataset, dict)
_eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}
- eval_packing = packing if eval_packing is None else eval_packing
+ eval_packing = args.packing if args.eval_packing is None else args.eval_packing
for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
_eval_datasets[_eval_dataset_name] = self._prepare_dataset(
_eval_dataset,
tokenizer,
eval_packing,
- dataset_text_field,
- max_seq_length,
+ args.dataset_text_field,
+ args.max_seq_length,
formatting_func,
- num_of_sequences,
- chars_per_token,
+ args.num_of_sequences,
+ args.chars_per_token,
remove_unused_columns=args.remove_unused_columns if args is not None else True,
- **dataset_kwargs,
+ **args.dataset_kwargs,
)
if not _multiple:
eval_dataset = _eval_datasets["singleton"]
@@ -344,12 +395,12 @@ def make_inputs_require_grad(module, input, output):
if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names)
- if self.args.max_steps > 0 and packing:
+ if self.args.max_steps > 0 and args.packing:
warnings.warn(
- "You passed `packing=True` to the SFTTrainer, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached."
+ "You passed `packing=True` to the SFTTrainer/SFTConfig, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached."
)
self.train_dataset.infinite = True
- elif self.args.max_steps == -1 and packing:
+ elif self.args.max_steps == -1 and args.packing:
self.train_dataset.infinite = False
if any(isinstance(callback, RichProgressCallback) for callback in self.callback_handler.callbacks):