support dataset cache
hiyouga committed Oct 26, 2023
1 parent 838ed9a commit 3fe7df6
Showing 2 changed files with 26 additions and 3 deletions.
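This commit lets preprocessed datasets be cached on disk: a new `cache_path` data argument saves the tokenized dataset the first time the script runs and reloads it on subsequent runs, skipping the map-based preprocessing. The option is rejected when `streaming` is enabled, since an IterableDataset cannot be saved with `save_to_disk`.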
22 changes: 19 additions & 3 deletions src/llmtuner/dsets/preprocess.py
@@ -1,8 +1,12 @@
+import os
 import tiktoken
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
 from itertools import chain
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
+
+from datasets import load_from_disk
 
 from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.logging import get_logger
 from llmtuner.extras.template import get_template_and_fix_tokenizer
 
 if TYPE_CHECKING:
@@ -12,14 +16,16 @@
     from llmtuner.hparams import DataArguments
 
 
+logger = get_logger(__name__)
+
+
 def preprocess_dataset(
     dataset: Union["Dataset", "IterableDataset"],
     tokenizer: "PreTrainedTokenizer",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Union["Dataset", "IterableDataset"]:
-    column_names = list(next(iter(dataset)).keys())
     template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
 
     if data_args.train_on_prompt and template.efficient_eos:
@@ -226,7 +232,12 @@ def print_unsupervised_dataset_example(example):
         preprocess_func = preprocess_unsupervised_dataset
         print_function = print_unsupervised_dataset_example
 
+    if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
+        logger.warning("Loading dataset from disk will ignore other data arguments.")
+        return load_from_disk(data_args.cache_path)
+
     with training_args.main_process_first(desc="dataset map pre-processing"):
+        column_names = list(next(iter(dataset)).keys())
         kwargs = {}
         if not data_args.streaming:
             kwargs = dict(
@@ -242,10 +253,15 @@ def print_unsupervised_dataset_example(example):
             **kwargs
         )
 
+    if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
+        if training_args.should_save:
+            dataset.save_to_disk(data_args.cache_path)
+        raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
+
     if training_args.should_log:
         try:
             print_function(next(iter(dataset)))
         except StopIteration:
-            raise ValueError("Empty dataset!")
+            raise RuntimeError("Empty dataset!")
 
     return dataset
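Together these hunks implement a two-pass cache: on the first run the preprocessed dataset is saved to `cache_path` and the script exits via SystemExit; on the rerun the cache is found and loaded, bypassing `dataset.map` entirely. A minimal standalone sketch of that control flow (the `load_or_build` helper and `build_fn` callback are illustrative names, not part of this commit):

import os
from datasets import Dataset, load_from_disk

def load_or_build(cache_path, build_fn):
    # Second pass: a cache already exists, so reuse it and skip preprocessing.
    if cache_path is not None and os.path.exists(cache_path):
        return load_from_disk(cache_path)
    dataset = build_fn()  # stands in for the expensive dataset.map(...) step above
    # First pass: persist the result, then stop so the user reruns the script.
    if cache_path is not None:
        dataset.save_to_disk(cache_path)
        raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
    return dataset

build = lambda: Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6]]})
dataset = load_or_build("/tmp/demo_cache", build)  # first call saves and exits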
7 changes: 7 additions & 0 deletions src/llmtuner/hparams/data_args.py
@@ -98,6 +98,10 @@ class DataArguments:
         default=False,
         metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
     )
+    cache_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to save or load the preprocessed datasets."}
+    )
 
     def __post_init__(self):
         if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
@@ -106,6 +110,9 @@ def __post_init__(self):
         if self.streaming and self.max_samples is not None:
             raise ValueError("`max_samples` is incompatible with `streaming`.")
 
+        if self.streaming and self.cache_path:
+            raise ValueError("`cache_path` is incompatible with `streaming`.")
+
     def init_for_training(self, seed: int): # support mixing multiple datasets
         self.seed = seed
         dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
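Because LLaMA-Factory parses these dataclasses with transformers' HfArgumentParser, the new field surfaces directly as a `--cache_path` command-line flag, and the `__post_init__` guard fires at parse time. A small sketch under that assumption (`MiniDataArguments` is a trimmed stand-in for the real DataArguments, not the actual class):

from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class MiniDataArguments:
    streaming: Optional[bool] = field(default=False)
    cache_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save or load the preprocessed datasets."}
    )

    def __post_init__(self):
        # Mirrors the guard added in this commit.
        if self.streaming and self.cache_path:
            raise ValueError("`cache_path` is incompatible with `streaming`.")

parser = HfArgumentParser(MiniDataArguments)
# e.g. python train_bash.py ... --cache_path ./cache/sft_data
(data_args,) = parser.parse_args_into_dataclasses(["--cache_path", "./cache/sft_data"])
print(data_args.cache_path)  # ./cache/sft_data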
