diff --git a/tests/data/test_data_handlers.py b/tests/data/test_data_handlers.py index 403881676..d2a390fe9 100644 --- a/tests/data/test_data_handlers.py +++ b/tests/data/test_data_handlers.py @@ -24,7 +24,10 @@ from tests.artifacts.testdata import MODEL_NAME, TWITTER_COMPLAINTS_DATA_JSONL # Local -from tuning.data.data_handlers import apply_custom_data_formatting_template +from tuning.data.data_handlers import ( + apply_custom_data_formatting_template, + combine_sequence, +) def test_apply_custom_formatting_template(): @@ -71,3 +74,37 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys(): "template": template, }, ) + + +@pytest.mark.parametrize( + "input_element,output_element,expected_res", + [ + ("foo ", "bar", "foo bar"), + ("foo\n", "bar", "foo\nbar"), + ("foo\t", "bar", "foo\tbar"), + ("foo", "bar", "foo bar"), + ], +) +def test_combine_sequence(input_element, output_element, expected_res): + """Ensure that input / output elements are combined with correct whitespace handling.""" + comb_seq = combine_sequence(input_element, output_element) + assert isinstance(comb_seq, str) + assert comb_seq == expected_res + + +@pytest.mark.parametrize( + "input_element,output_element,expected_res", + [ + ("foo ", "bar", "foo bar"), + ("foo\n", "bar", "foo\nbar"), + ("foo\t", "bar", "foo\tbar"), + ("foo", "bar", "foo bar"), + ], +) +def test_combine_sequence_adds_eos(input_element, output_element, expected_res): + """Ensure that input / output elements are combined with correct whitespace handling.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token) + expected_res += tokenizer.eos_token + assert isinstance(comb_seq, str) + assert comb_seq == expected_res diff --git a/tests/data/test_data_preprocessing_utils.py b/tests/data/test_data_preprocessing_utils.py index f4ebf9ab5..02308b2f5 100644 --- a/tests/data/test_data_preprocessing_utils.py +++ b/tests/data/test_data_preprocessing_utils.py @@ -43,12 +43,8 @@ # Local from tuning.config import configs from tuning.data.data_config import DataPreProcessorConfig, DataSetConfig -from tuning.data.data_preprocessing_utils import ( - combine_sequence, - get_data_collator, - validate_data_args, -) -from tuning.data.data_processors import HFBasedDataPreProcessor, get_datapreprocessor +from tuning.data.data_preprocessing_utils import get_data_collator +from tuning.data.data_processors import DataPreProcessor, get_datapreprocessor from tuning.data.setup_dataprocessor import ( _process_dataconfig_file, is_pretokenized_dataset, @@ -56,40 +52,6 @@ ) -@pytest.mark.parametrize( - "input_element,output_element,expected_res", - [ - ("foo ", "bar", "foo bar"), - ("foo\n", "bar", "foo\nbar"), - ("foo\t", "bar", "foo\tbar"), - ("foo", "bar", "foo bar"), - ], -) -def test_combine_sequence(input_element, output_element, expected_res): - """Ensure that input / output elements are combined with correct whitespace handling.""" - comb_seq = combine_sequence(input_element, output_element) - assert isinstance(comb_seq, str) - assert comb_seq == expected_res - - -@pytest.mark.parametrize( - "input_element,output_element,expected_res", - [ - ("foo ", "bar", "foo bar"), - ("foo\n", "bar", "foo\nbar"), - ("foo\t", "bar", "foo\tbar"), - ("foo", "bar", "foo bar"), - ], -) -def test_combine_sequence_adds_eos(input_element, output_element, expected_res): - """Ensure that input / output elements are combined with correct whitespace handling.""" - tokenizer = 
AutoTokenizer.from_pretrained(MODEL_NAME) - comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token) - expected_res += tokenizer.eos_token - assert isinstance(comb_seq, str) - assert comb_seq == expected_res - - @pytest.mark.parametrize( "datafile, column_names", [ @@ -222,7 +184,6 @@ def test_load_dataset_without_dataconfig_and_datafile(): ) def test_is_pretokenized_data(data, result): """Ensure that the correct collator type is fetched based on the data args""" - assert is_pretokenized_dataset(data=data) == result @@ -361,43 +322,16 @@ def test_get_data_collator( ), ], ) -def test_validate_args(data_args, packing): +def test_process_data_args_throws_error_where_needed(data_args, packing): """Ensure that respective errors are thrown for incorrect data arguments""" with pytest.raises(ValueError): - is_traindata_tokenized = is_pretokenized_dataset(data_args.training_data_path) - is_evaldata_tokenized = is_pretokenized_dataset(data_args.validation_data_path) - validate_data_args( - data_args, packing, is_traindata_tokenized, is_evaldata_tokenized + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + TRAIN_ARGS = configs.TrainingArguments( + packing=packing, + max_seq_length=1024, + output_dir="tmp", # Not needed but positional ) - - -@pytest.mark.parametrize( - "data_args, packing", - [ - # pretokenized train dataset and no validation dataset passed - ( - configs.DataArguments( - training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, - ), - False, - ), - # pretokenized train and validation datasets - ( - configs.DataArguments( - training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, - validation_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL, - ), - False, - ), - ], -) -def test_validate_args_pretokenized(data_args, packing): - """Ensure that supported data args do not error out when passing pretokenized datasets""" - is_traindata_tokenized = is_pretokenized_dataset(data_args.training_data_path) - is_evaldata_tokenized = is_pretokenized_dataset(data_args.validation_data_path) - validate_data_args( - data_args, packing, is_traindata_tokenized, is_evaldata_tokenized - ) + (_, _, _, _, _, _) = process_dataargs(data_args, tokenizer, TRAIN_ARGS) @pytest.mark.parametrize( @@ -448,11 +382,7 @@ def test_process_dataconfig_file(data_config_path, data_path): data_args = configs.DataArguments(data_config_path=temp_yaml_file_path) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - packing = (False,) - max_seq_length = 1024 - (train_set, _, _, _, _, _) = _process_dataconfig_file( - data_args, tokenizer, packing, max_seq_length - ) + (train_set, _, _) = _process_dataconfig_file(data_args, tokenizer) assert isinstance(train_set, Dataset) if datasets_name == "text_dataset_input_output_masking": column_names = set(["input_ids", "attention_mask", "labels"]) @@ -625,7 +555,7 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname): """Test process_dataset_configs for expected output.""" dataprocessor_config = DataPreProcessorConfig() tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - processor = HFBasedDataPreProcessor( + processor = DataPreProcessor( processor_config=dataprocessor_config, tokenizer=tokenizer, ) diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index 5a0a40ff1..f0100072b 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -15,26 +15,55 @@ # Definition of some predefined data preprocessing functions that we need. 
# Standard -from typing import Dict +from typing import Dict, List +import re # Third Party from transformers import AutoTokenizer -# Local -from tuning.data.data_preprocessing_utils import combine_sequence, custom_data_formatter + +### Utils for custom masking / manipulating input / output strs, etc +def combine_sequence(input_element: str, output_element: str, eos_token: str = ""): + """Combines / concatenates input & output element. + + Args: + input_element: str + Input component of the combined sequence. + output_element: str + Output component of the combined sequence. + eos_token: str + EOS token associated with the tokenizer. \ + If passed, it will be concatenated at end + + Returns: + str + Sequence combined with whitespace. + """ + if not input_element.endswith((" ", "\n", "\t")) and not output_element.startswith( + (" ", "\n", "\t") + ): + return input_element + " " + output_element + eos_token + return input_element + output_element + eos_token def tokenize_and_apply_input_masking( element: Dict[str, str], tokenizer: AutoTokenizer, + column_names: List[str], input_field_name: str, output_field_name: str, **tokenizer_kwargs, ): + if (input_field_name or output_field_name) not in column_names: + raise ValueError( + f"Dataset should contain {input_field_name} \ + and {output_field_name} field if \ + no dataset_text_field or data_formatter_template specified" + ) + input_text = element[input_field_name] output_text = element[output_field_name] - # TODO: Eventually move the code here combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token) fn_kwargs = tokenizer_kwargs.get("fn_kwargs", {}) @@ -56,7 +85,10 @@ def tokenize_and_apply_input_masking( def apply_dataset_formatting( - element: Dict[str, str], tokenizer: AutoTokenizer, dataset_text_field: str, **kwargs + element: Dict[str, str], + tokenizer: AutoTokenizer, + dataset_text_field: str, + **kwargs, ): return { f"{dataset_text_field}": element[f"{dataset_text_field}"] + tokenizer.eos_token @@ -85,8 +117,22 @@ def apply_custom_data_formatting_template( template += tokenizer.eos_token - # TODO: Eventually move the code here. - return custom_data_formatter(element, template, dataset_text_field) + def replace_text(match_obj): + captured_groups = match_obj.groups() + if len(captured_groups) != 1: + raise ValueError( + "Unexpectedly captured multiple groups in template formatting" + ) + + index_object = captured_groups[0] + if index_object not in element: + raise KeyError("Requested template string is not a valid key in dict") + + return element[index_object] + + return { + dataset_text_field: re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template) + } AVAILABLE_DATA_HANDLERS = { diff --git a/tuning/data/data_preprocessing_utils.py b/tuning/data/data_preprocessing_utils.py index 9efa04e38..589e4c9ef 100644 --- a/tuning/data/data_preprocessing_utils.py +++ b/tuning/data/data_preprocessing_utils.py @@ -13,127 +13,14 @@ # limitations under the License. 
# Standard from typing import Callable, Optional -import re # Third Party from transformers import AutoTokenizer, DataCollatorForSeq2Seq from trl import DataCollatorForCompletionOnlyLM -import datasets # Local from tuning.config import configs -# In future we may make the fields configurable -DEFAULT_JSON_INPUT_KEY = "input" -DEFAULT_JSON_OUTPUT_KEY = "output" - - -def validate_data_args( - data_args: configs.DataArguments, - packing: bool, - is_traindataset_tokenized: bool, - is_evaldataset_tokenized: bool, -): - - assert isinstance( - data_args.training_data_path, str - ), "Training data path has to be set and str" - - ### Data format 1 - # if the provided train dataset is pretokenized - # however user provides formatting flags, error out - if is_traindataset_tokenized: - if ( - data_args.response_template - or data_args.data_formatter_template - or data_args.dataset_text_field - ): - raise ValueError( - "fields response_template, data_formatter_template, and dataset_text_field \ - are not applicable for pretokenized datasets" - ) - - # if the train dataset is pretokenized - # ensure validation dataset is pretokenized otherwise error out - if data_args.validation_data_path and not is_evaldataset_tokenized: - raise ValueError( - "validation data should be pretokenized to be used \ - along with pretokenized train data" - ) - - # packing wont be available for pretokenized datasets in trl library - # see: https://github.com/huggingface/trl/issues/1848 - if packing: - raise ValueError("packing will not be used when datasets are pretokenized") - return - - ### Data format 2 - # Dataset containing single sequence needs a response template for masking - if data_args.dataset_text_field or data_args.data_formatter_template: - if data_args.response_template is None: - if packing is False: - raise ValueError( - "Since dataset_text_field or data_formatter_template \ - is provided and packing is disabled, \ - needs a corresponding response template for masking" - ) - - if data_args.response_template: - # To use Response template, pass datasets with single sequence instances \ - # or a formatter template to create single sequence on the fly. - if not (data_args.dataset_text_field or data_args.data_formatter_template): - raise ValueError( - "dataset_text_field and data_formatter_template are None. \ - One of them needs to be set to use response_template" - ) - # Only one of dataset_text_field or data_formatter_template should be set. - if data_args.dataset_text_field and data_args.data_formatter_template: - raise ValueError( - "dataset_text_field and data_formatter_template are both set,\ - but are mutually exclusive options" - ) - - ### Data format 3 - # If not single sequence, JSON should contain input/output fields - if not (data_args.dataset_text_field or data_args.data_formatter_template): - json_dataset = datasets.load_dataset( - "json", data_files=data_args.training_data_path - ) - if DEFAULT_JSON_INPUT_KEY not in json_dataset["train"].column_names: - raise ValueError( - "JSON should contain input field if no dataset_text_field or \ - data_formatter_template specified" - ) - if DEFAULT_JSON_OUTPUT_KEY not in json_dataset["train"].column_names: - raise ValueError( - "JSON should contain output field if no dataset_text_field or \ - data_formatter_template specified" - ) - - -### Utils for custom masking / manipulating input / output strs, etc -def combine_sequence(input_element: str, output_element: str, eos_token: str = ""): - """Combines / concatenates input & output element. 
- - Args: - input_element: str - Input component of the combined sequence. - output_element: str - Output component of the combined sequence. - eos_token: str - EOS token associated with the tokenizer. \ - If passed, it will be concatenated at end - - Returns: - str - Sequence combined with whitespace. - """ - if not input_element.endswith((" ", "\n", "\t")) and not output_element.startswith( - (" ", "\n", "\t") - ): - return input_element + " " + output_element + eos_token - return input_element + output_element + eos_token - def get_data_collator( packing: bool, @@ -185,24 +72,3 @@ def get_data_collator( raise ValueError( "Could not pick a data collator. Please refer to supported data formats" ) - - -def custom_data_formatter(element, template, formatted_dataset_field): - def replace_text(match_obj): - captured_groups = match_obj.groups() - if len(captured_groups) != 1: - raise ValueError( - "Unexpectedly captured multiple groups in template formatting" - ) - - index_object = captured_groups[0] - if index_object not in element: - raise KeyError("Requested template string is not a valid key in dict") - - return element[index_object] - - return { - formatted_dataset_field: re.sub( - r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template - ) - } diff --git a/tuning/data/data_processors.py b/tuning/data/data_processors.py index 8547e4eca..f6f3b0ec9 100644 --- a/tuning/data/data_processors.py +++ b/tuning/data/data_processors.py @@ -13,7 +13,6 @@ # limitations under the License. # Standard -from abc import ABC, abstractmethod from typing import Dict, List, Union import logging import os @@ -31,7 +30,7 @@ from tuning.utils.utils import get_extension, get_loader_for_filepath -class DataPreProcessor(ABC): +class DataPreProcessor: tokenizer = None data_config: DataConfig = None @@ -47,33 +46,9 @@ def __init__( # Initialize other objects self.registered_handlers = {} - def load_dataset( - self, - datasetconfig: DataSetConfig, - splitName: str, - datafile: str = None, - **kwargs, - ): - raise NotImplementedError("Needs to be implemented") - def register_data_handler(self, name: str, func: callable): self.registered_handlers[name] = func - @abstractmethod - def process_dataset_configs( - self, dataset_configs: List[DataSetConfig], **extra_kwargs - ) -> Union[Dataset, IterableDataset]: - raise NotImplementedError("Needs to be implemented") - - -class HFBasedDataPreProcessor(DataPreProcessor): - def __init__( - self, - processor_config: DataPreProcessorConfig, - tokenizer: AutoTokenizer, - ): - super().__init__(processor_config=processor_config, tokenizer=tokenizer) - def load_dataset( self, datasetconfig: DataSetConfig, @@ -122,7 +97,7 @@ def _process_dataset_configs( final_datasets = None splitName = "train" # default - logging.info("Starting HFBasedDataPreProcessor...") + logging.info("Starting DataPreProcessor...") # Iterate over the multiple datasets provided to us for d in dataset_configs: logging.info("Loading %s", d.name) @@ -208,9 +183,11 @@ def process_dataset_configs( # Use broadcast_object_list to share the dataset object across ranks # TODO: Check if torch.distributed.barrier() is called in broadcast_object_list() - obj_list = [train_dataset] - torch.distributed.broadcast_object_list(obj_list, src=0) - train_dataset = obj_list[0] + # See https://github.com/pytorch/pytorch/issues/56142 + # for why the list is shared like this + to_share = [train_dataset] + torch.distributed.broadcast_object_list(to_share, src=0) + train_dataset = to_share[0] else: logging.info("Processing data...") 
train_dataset = self._process_dataset_configs(dataset_configs, **kwargs) @@ -228,13 +205,9 @@ def autoregister_available_handlers(processor: DataPreProcessor): def get_datapreprocessor( processor_config: DataPreProcessorConfig, tokenizer: AutoTokenizer ) -> DataPreProcessor: - processor = processor_config.type - if processor == "default": - processor = HFBasedDataPreProcessor( - processor_config=processor_config, - tokenizer=tokenizer, - ) - else: - processor = None + processor = DataPreProcessor( + processor_config=processor_config, + tokenizer=tokenizer, + ) autoregister_available_handlers(processor) return processor diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index c83eeb1f0..5db8e0aee 100644 --- a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -30,14 +30,12 @@ DataSetConfig, load_and_validate_data_config, ) -from tuning.data.data_preprocessing_utils import ( - DEFAULT_JSON_INPUT_KEY, - DEFAULT_JSON_OUTPUT_KEY, - get_data_collator, - validate_data_args, -) +from tuning.data.data_preprocessing_utils import get_data_collator from tuning.data.data_processors import get_datapreprocessor +# In future we may make the fields configurable +DEFAULT_JSON_INPUT_KEY = "input" +DEFAULT_JSON_OUTPUT_KEY = "output" # check if the provided dataset is pretokenized or not # the check is taken from trl @@ -55,74 +53,133 @@ def is_pretokenized_dataset(data: Union[str, Dataset, IterableDataset]): return ("input_ids" in data.column_names) and ("labels" in data.column_names) -# For now assume only training dataset is passed via data config file. +# TODO: For now assume only training dataset is passed via data config file. # This is very limited but is done to keep first implementation minimal -def _process_dataconfig_file( - data_args: DataArguments, tokenizer: AutoTokenizer, packing: bool, max_seq_len: int -): +def _process_dataconfig_file(data_args: DataArguments, tokenizer: AutoTokenizer): data_config = load_and_validate_data_config(data_args.data_config_path) processor = get_datapreprocessor( processor_config=data_config.dataprocessor, tokenizer=tokenizer ) train_dataset = processor.process_dataset_configs(data_config.datasets) - data_collator = get_data_collator( - packing, - data_args.response_template, - tokenizer, - # Note: Its important to recompute this post handling to - # check if we already tokenized the dataset or not. 
- is_pretokenized_dataset(train_dataset), - max_seq_len, - ) + return (train_dataset, None, data_args.dataset_text_field) - dataset_kwargs = {} - if is_pretokenized_dataset(train_dataset): - dataset_kwargs["skip_prepare_dataset"] = True - - ## HACK: For now just assume we take train_dataset via data config - return ( - train_dataset, - None, - data_args.dataset_text_field, - data_collator, - max_seq_len, - dataset_kwargs, - ) +# Data Format 1: Pretokenized Data +def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized): -def process_dataargs( - data_args: DataArguments, tokenizer: AutoTokenizer, train_args: TrainingArguments -): - """ - Args: - data_args: tuning.config.configs.DataArguments - tokenizer: AutoTokenizer - train_args: TrainingArguments - Training arguments passed to the library - Used for packing and max_seq_length - Returns: - Tuple(Dataset, Dataset, str, DataCollator, int, Dict) - tuple containing train_dataset, eval_dataset, dataset_text_field, - data_collator, max_seq_length and dataset_kwargs - """ + # if the provided train dataset is pretokenized + # however user provides formatting flags, error out + if ( + data_args.response_template + or data_args.data_formatter_template + or data_args.dataset_text_field + ): + raise ValueError( + "fields response_template, data_formatter_template, and dataset_text_field \ + are not applicable for pretokenized datasets" + ) - max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length) - logging.info("Max sequence length is %s", max_seq_length) - if train_args.max_seq_length > tokenizer.model_max_length: - logging.warning( - "max_seq_length %s exceeds tokenizer.model_max_length \ - %s, using tokenizer.model_max_length %s", - train_args.max_seq_length, - tokenizer.model_max_length, - tokenizer.model_max_length, + # if the train dataset is pretokenized + # ensure validation dataset is pretokenized otherwise error out + if is_eval_tokenized: + raise ValueError( + "validation data should be pretokenized to be used \ + along with pretokenized train data" ) - if data_args.data_config_path: - # Data config is specified so our processing path is diverging - return _process_dataconfig_file( - data_args, tokenizer, train_args.packing, max_seq_length + # Support for packing pretokenized datasets has been merged in trl library + # see: https://github.com/huggingface/trl/pull/2011 + # but we wait till a new transformers version is released to remove this check. + if packing: + raise ValueError("packing will not be used when datasets are pretokenized") + + # We do not need a handler here as this is tokenized dataset + return [], None + + +### Data format 2 +def _get_dataset_formatting_handlers(data_args, packing): + + if data_args.response_template is None: + if packing is False: + raise ValueError( + "Since dataset_text_field or data_formatter_template \ + is provided and packing is disabled, \ + needs a corresponding response template for masking" + ) + + if data_args.response_template: + # To use Response template, pass datasets with single sequence instances \ + # or a formatter template to create single sequence on the fly. + if not (data_args.dataset_text_field or data_args.data_formatter_template): + raise ValueError( + "dataset_text_field and data_formatter_template are None. \ + One of them needs to be set to use response_template" + ) + # Only one of dataset_text_field or data_formatter_template should be set. 
+ if data_args.dataset_text_field and data_args.data_formatter_template: + raise ValueError( + "dataset_text_field and data_formatter_template are both set,\ + but are mutually exclusive options" + ) + + fn_kwargs = {} + dataset_text_field = data_args.dataset_text_field + + if dataset_text_field is None: + dataset_text_field = "new_formatted_field" + + fn_kwargs["dataset_text_field"] = dataset_text_field + if data_args.data_formatter_template is None: + handler = DataHandlerConfig( + "apply_dataset_formatting", + arguments={"fn_kwargs": fn_kwargs, "batched": False}, ) + else: + fn_kwargs["template"] = data_args.data_formatter_template + handler = DataHandlerConfig( + "apply_custom_data_formatting_template", + arguments={"fn_kwargs": fn_kwargs, "batched": False}, + ) + return [handler], dataset_text_field + + +### Data format 3 +def _get_default_json_dataset_handlers(data_args, tokenizer_kwargs): + + fn_kwargs = {} + fn_kwargs["input_field_name"] = DEFAULT_JSON_INPUT_KEY + fn_kwargs["output_field_name"] = DEFAULT_JSON_OUTPUT_KEY + fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs + + kwargs = { + "fn_kwargs": fn_kwargs, + "batched": False, + "remove_columns": "all", + } + + handler = DataHandlerConfig("tokenize_and_apply_input_masking", arguments=kwargs) + return [handler], data_args.dataset_text_field + + +# Process raw dataargs for various usecases. +# Data Format 1: Pretokenized Data +# Use pretokenized data as-is without preprocessing. +# No handlers are needed for this format. +# Data Format 2: Single Sequence Dataset +# If a text field is specified, append the tokenizer's EOS token to it. +# If a formatter template is provided, apply it and save the result. +# Data remains un-tokenized. +# Data Format 3: JSON Dataset with Input/Output Fields +# Combine input and output fields, tokenize the data, and apply input attention masking. +# Requires both input and output fields; throws an error if missing. +def _process_raw_data_args( + data_args: DataArguments, + tokenizer: AutoTokenizer, + packing: bool, + max_seq_length: int, +): # Create a data processor with default processor config default_processor_config = DataPreProcessorConfig() @@ -130,20 +187,19 @@ def process_dataargs( processor_config=default_processor_config, tokenizer=tokenizer ) - # TODO: This check loads first slice of the dataset to view its columns - # Since this load is not done via processor it is redundant - is_traindata_tokenized = is_pretokenized_dataset(data_args.training_data_path) - is_evaldata_tokenized = is_pretokenized_dataset(data_args.validation_data_path) - - # Validate if data args are set properly - validate_data_args( - data_args, train_args.packing, is_traindata_tokenized, is_evaldata_tokenized - ) + assert isinstance( + data_args.training_data_path, str + ), "Training data path has to be set and str" is_eval_dataset_present = False if data_args.validation_data_path: is_eval_dataset_present = True + # TODO: This check loads first slice of the dataset to view its columns + # Since this load is not done via processor it is redundant + is_traindata_tokenized = is_pretokenized_dataset(data_args.training_data_path) + is_evaldata_tokenized = is_pretokenized_dataset(data_args.validation_data_path) + train_dataset_config = DataSetConfig( name="training_data", data_paths=[data_args.training_data_path], @@ -156,9 +212,6 @@ def process_dataargs( data_handlers=None, ) - fn_kwargs = {} - handlers = None - # Setup some tokenizer kwargs for when we need a tokenizer # TODO: Figure out a way to not hardcode this. 
tokenizer_kwargs = {} @@ -166,48 +219,23 @@ def process_dataargs( tokenizer_kwargs["truncation"] = True tokenizer_kwargs["padding"] = False - dataset_text_field = data_args.dataset_text_field - - # Use case specific handlers + handlers = None + dataset_text_field = None if is_traindata_tokenized: - # dataset_text_field is irrelevant to pretokenized datasets - dataset_text_field = None - elif data_args.data_formatter_template or dataset_text_field: - if dataset_text_field is None: - dataset_text_field = "new_formatted_field" - - if data_args.data_formatter_template is None: - fn_kwargs["dataset_text_field"] = dataset_text_field - handler = DataHandlerConfig( - "apply_dataset_formatting", - arguments={"fn_kwargs": fn_kwargs, "batched": False}, - ) - handlers = [handler] - else: - fn_kwargs["dataset_text_field"] = dataset_text_field - fn_kwargs["template"] = data_args.data_formatter_template - handler = DataHandlerConfig( - "apply_custom_data_formatting_template", - arguments={"fn_kwargs": fn_kwargs, "batched": False}, - ) - handlers = [handler] + # Data Format 1: Pretokenized Data + handlers, dataset_text_field = _get_pretokenized_dataset_handlers( + data_args, packing, (is_eval_dataset_present and not is_evaldata_tokenized) + ) + elif data_args.data_formatter_template or data_args.dataset_text_field: + # Data Format 2: Single Sequence Dataset + handlers, dataset_text_field = _get_dataset_formatting_handlers( + data_args, packing + ) else: - # TODO: These should be called DEFAULT in the name as they are hardcoded. - fn_kwargs["input_field_name"] = DEFAULT_JSON_INPUT_KEY - fn_kwargs["output_field_name"] = DEFAULT_JSON_OUTPUT_KEY - - fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs - - kwargs = { - "fn_kwargs": fn_kwargs, - "batched": False, - "remove_columns": "all", - } - - handler = DataHandlerConfig( - "tokenize_and_apply_input_masking", arguments=kwargs + # Data Format 3: JSON Dataset with Input/Output Fields + handlers, dataset_text_field = _get_default_json_dataset_handlers( + data_args, tokenizer_kwargs ) - handlers = [handler] # Now set handlers in the dataset configs train_dataset_config.data_handlers = handlers @@ -217,18 +245,64 @@ def process_dataargs( # And let processor handle the logic train_dataset = data_processor.process_dataset_configs([train_dataset_config]) - logging.info("Training dataset length is %s", len(train_dataset)) - eval_dataset = None if is_eval_dataset_present: eval_dataset = data_processor.process_dataset_configs([eval_dataset_config]) - logging.info("Validation dataset length is %s", len(eval_dataset)) + + return (train_dataset, eval_dataset, dataset_text_field) + + +# If a data config file is provided, load it to get the training dataset. +# - Assumes only the training dataset is specified in the config file. +# - Expects a complete and valid data config file from the user. +# +# If no data config file is specified, process the remaining data arguments +# to determine the use case based on their presence, as explained in _process_raw_data_args. 
+def process_dataargs( + data_args: DataArguments, tokenizer: AutoTokenizer, train_args: TrainingArguments +): + """ + Args: + data_args: tuning.config.configs.DataArguments + tokenizer: AutoTokenizer + train_args: TrainingArguments + Training arguments passed to the library + Used for packing and max_seq_length + Returns: + Tuple(Dataset, Dataset, str, DataCollator, int, Dict) + tuple containing train_dataset, eval_dataset, dataset_text_field, + data_collator, max_seq_length and dataset_kwargs + + """ + + max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length) + logging.info("Max sequence length is %s", max_seq_length) + if train_args.max_seq_length > tokenizer.model_max_length: + logging.warning( + "max_seq_length %s exceeds tokenizer.model_max_length \ + %s, using tokenizer.model_max_length %s", + train_args.max_seq_length, + tokenizer.model_max_length, + tokenizer.model_max_length, + ) + + train_dataset = eval_dataset = dataset_text_field = None + + if data_args.data_config_path: + train_dataset, eval_dataset, dataset_text_field = _process_dataconfig_file( + data_args, tokenizer + ) + else: + train_dataset, eval_dataset, dataset_text_field = _process_raw_data_args( + data_args, tokenizer, train_args.packing, max_seq_length + ) data_collator = get_data_collator( train_args.packing, data_args.response_template, tokenizer, - # Note: Its important to recompute this post handling to + # Note: This check should not be removed. + # Its important to recompute this post handling to # check if we already tokenized the dataset or not. is_pretokenized_dataset(train_dataset), max_seq_length,
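
For reference, the whitespace and EOS handling implemented by the relocated combine_sequence handler (and asserted by the new parametrized tests in tests/data/test_data_handlers.py), together with the {{column}} substitution now inlined into apply_custom_data_formatting_template, can be exercised in isolation. The sketch below is illustrative only: the combine_sequence body is copied from the diff above, while the "</s>" EOS string and the "Tweet text" / "text_label" column names are placeholder values chosen for the example rather than anything mandated by the library.

# Illustrative sketch of the helpers introduced/moved in tuning/data/data_handlers.py.
import re

def combine_sequence(input_element: str, output_element: str, eos_token: str = ""):
    # A single space is inserted only when neither side already supplies whitespace.
    if not input_element.endswith((" ", "\n", "\t")) and not output_element.startswith(
        (" ", "\n", "\t")
    ):
        return input_element + " " + output_element + eos_token
    return input_element + output_element + eos_token

assert combine_sequence("foo", "bar") == "foo bar"               # space added
assert combine_sequence("foo\n", "bar") == "foo\nbar"            # existing newline preserved
assert combine_sequence("foo ", "bar", "</s>") == "foo bar</s>"  # placeholder EOS appended as-is

# The same regex used by apply_custom_data_formatting_template resolves
# {{column}} placeholders against a dataset element; column names here are
# hypothetical, chosen only for the example.
element = {"Tweet text": "hello", "text_label": "complaint"}
template = "Input: {{Tweet text}} Label: {{text_label}}"
formatted = re.sub(
    r"{{([\s0-9a-zA-Z_\-\.]+)}}",
    lambda m: element[m.groups()[0]],
    template,
)
assert formatted == "Input: hello Label: complaint"

Checking whitespace on both sides of the join keeps combine_sequence from inserting a stray space when a formatter template or raw input already ends with a space, newline, or tab, which is exactly what the parametrized test cases cover.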