Skip to content

Commit

Permalink
fix: need to pass skip_prepare_dataset for pretokenized dataset due to breaking change in HF SFTTrainer (#326)
Browse files Browse the repository at this point in the history

* fix: need to pass skip_prepare_dataset for pretokenized dataset due to breaking change in HF SFTTrainer

Signed-off-by: Harikrishnan Balagopal <harikrishmenon@gmail.com>

* fix: wrong dataset paths, was using non-tokenized data in pre-tokenized dataset tests

Signed-off-by: Harikrishnan Balagopal <harikrishmenon@gmail.com>

---------

Signed-off-by: Harikrishnan Balagopal <harikrishmenon@gmail.com>
Signed-off-by: Angel Luu <angel.luu@us.ibm.com>
  • Loading branch information
HarikrishnanBalagopal authored and aluu317 committed Sep 13, 2024
1 parent 673a79c commit 1a24940
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
5 changes: 2 additions & 3 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
EMPTY_DATA,
MALFORMATTED_DATA,
MODEL_NAME,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_JSON,
TWITTER_COMPLAINTS_DATA_JSONL,
Expand Down Expand Up @@ -850,8 +849,8 @@ def test_run_with_good_experimental_metadata():
@pytest.mark.parametrize(
"dataset_path",
[
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
TWITTER_COMPLAINTS_TOKENIZED_JSONL,
TWITTER_COMPLAINTS_TOKENIZED_JSON,
],
)
### Tests for pretokenized data
Expand Down
7 changes: 7 additions & 0 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from tuning.utils.preprocessing_utils import (
format_dataset,
get_data_collator,
is_pretokenized_dataset,
validate_data_args,
)

Expand Down Expand Up @@ -318,6 +319,11 @@ def train(
}
training_args = SFTConfig(**transformer_kwargs)

dataset_kwargs = {}
if is_pretokenized_dataset(
data_args.training_data_path or data_args.validation_data_path
):
dataset_kwargs["skip_prepare_dataset"] = True
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
Expand All @@ -330,6 +336,7 @@ def train(
max_seq_length=max_seq_length,
callbacks=trainer_callbacks,
peft_config=peft_config,
dataset_kwargs=dataset_kwargs,
)

# We track additional metrics and experiment metadata after trainer object creation
Expand Down

0 comments on commit 1a24940

Please sign in to comment.