add initial implementation of dataloader v2
Signed-off-by: Dushyant Behl <dushyantbehl@users.noreply.github.com>
dushyantbehl committed Oct 27, 2024
1 parent 7ba3434 commit 5e948f1
Showing 10 changed files with 705 additions and 22 deletions.
14 changes: 14 additions & 0 deletions examples/predefined_configs/apply_custom_template.yaml
@@ -0,0 +1,14 @@
dataloader:
  type: default
datasets:
  - name: apply_custom_data_template
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: apply_custom_data_formatting_template
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            dataset_text_field: "dataset_text_field"
            template: "dataset_template"
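
For orientation, a sketch of the parsed object this YAML is meant to produce, built from the dataclasses added in tuning/data/data_config.py later in this commit; the values simply mirror the file above and nothing here is generated by the commit itself:

# Sketch: the DataConfig this YAML should deserialize into.
from tuning.data.data_config import (
    DataConfig,
    DataHandlerConfig,
    DataLoaderConfig,
    DataSetConfig,
)

config = DataConfig(
    dataloader=DataLoaderConfig(type="default"),
    datasets=[
        DataSetConfig(
            name="apply_custom_data_template",
            data_paths=["FILE_PATH"],
            data_handlers=[
                DataHandlerConfig(
                    name="apply_custom_data_formatting_template",
                    arguments={
                        "remove_columns": "all",
                        "batched": False,
                        "fn_kwargs": {
                            "dataset_text_field": "dataset_text_field",
                            "template": "dataset_template",
                        },
                    },
                )
            ],
        )
    ],
)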
6 changes: 6 additions & 0 deletions examples/predefined_configs/pretokenized_json_data.yaml
@@ -0,0 +1,6 @@
dataloader:
  type: default
datasets:
  - name: pretokenized_dataset
    data_paths:
      - "FILE_PATH"
14 changes: 14 additions & 0 deletions examples/predefined_configs/tokenize_and_instruction_masking.yaml
@@ -0,0 +1,14 @@
dataloader:
  type: default
datasets:
  - name: text_dataset_input_output_masking
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_instruction_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_field_name: "INPUT"
            output_field_name: "OUTPUT"
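
The fn_kwargs above line up with the parameters of tokenize_and_apply_instruction_masking in tuning/data/data_handlers.py, which reads the named input and output columns off each record. A hypothetical record the handler would accept:

# Hypothetical record matching input_field_name="INPUT" and
# output_field_name="OUTPUT"; the prompt tokens end up masked with -100.
record = {"INPUT": "Translate to French: cheese", "OUTPUT": "fromage"}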
149 changes: 149 additions & 0 deletions tuning/data/data_config.py
@@ -0,0 +1,149 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import logging
import os

# Local
from tuning.utils.utils import load_yaml_or_json


@dataclass
class DataHandlerConfig:
    name: str
    arguments: Optional[Dict]


@dataclass
class DataSetConfig:
    name: str = ""
    data_paths: List[str] = field(default_factory=list)
    sampling: Optional[Dict] = None
    splitter_arguments: Optional[Dict] = None
    data_handlers: Optional[List[DataHandlerConfig]] = None


@dataclass
class DataLoaderConfig:
    type: Optional[str] = "default"
    streaming: Optional[bool] = None


@dataclass
class DataConfig:
    dataloader: DataLoaderConfig = field(default_factory=DataLoaderConfig)
    datasets: List[DataSetConfig] = field(default_factory=list)


def _validate_data_handler_config(data_handler) -> DataHandlerConfig:
    kwargs = data_handler
    assert isinstance(kwargs, dict), "data_handlers in data_config needs to be a dict"
    assert "name" in kwargs and isinstance(
        kwargs["name"], str
    ), "data_handlers need to have a name with type str"
    assert "arguments" in kwargs, "data handlers need to have arguments"
    assert isinstance(
        kwargs["arguments"], dict
    ), "data handler arguments should be of the type dict"
    return DataHandlerConfig(**kwargs)


def _validate_dataset_config(dataset_config) -> DataSetConfig:
    c = DataSetConfig()
    kwargs = dataset_config
    assert isinstance(kwargs, dict), "dataset_config in data_config needs to be a dict"
    if "name" in kwargs:
        assert isinstance(kwargs["name"], str), "dataset name should be string"
        c.name = kwargs["name"]
    if "data_paths" not in kwargs:
        raise ValueError("data_paths should be specified for each dataset")
    data_paths = kwargs["data_paths"]
    # TODO: Support that data_paths can be a directory or directories
    assert isinstance(data_paths, list), "data_paths should be an array of files"
    c.data_paths = []
    for p in data_paths:
        assert isinstance(p, str), f"path {p} should be of the type string"
        assert os.path.exists(p), f"data_paths {p} does not exist"
        if not os.path.isabs(p):
            _p = os.path.abspath(p)
            logging.warning(
                "Provided path %s is not absolute; changing it to %s", p, _p
            )
            p = _p
        c.data_paths.append(p)
    if "sampling" in kwargs:
        sampling_kwargs = kwargs["sampling"]
        assert isinstance(
            sampling_kwargs, dict
        ), "sampling arguments should be of the type dict"
        if "ratio" in sampling_kwargs:
            ratio = sampling_kwargs["ratio"]
            assert isinstance(ratio, float) and (
                0 <= ratio <= 1.0
            ), f"sampling ratio: {ratio} should be float and in range [0.0,1.0]"
        c.sampling = sampling_kwargs
    if "splitter_arguments" in kwargs:
        splitter_kwargs = kwargs["splitter_arguments"]
        assert isinstance(
            splitter_kwargs, dict
        ), "splitter_arguments should be of the type dict"
        c.splitter_arguments = splitter_kwargs
    if "data_handlers" in kwargs:
        c.data_handlers = []
        for handler in kwargs["data_handlers"]:
            c.data_handlers.append(_validate_data_handler_config(handler))
    return c


def _validate_dataloader_config(dataloader_config) -> DataLoaderConfig:
    kwargs = dataloader_config
    c = DataLoaderConfig()
    assert isinstance(kwargs, dict), "dataloader in data_config needs to be a dict"
    if "streaming" in kwargs:
        assert isinstance(
            kwargs["streaming"], bool
        ), "streaming should be a boolean true or false"
        c.streaming = kwargs["streaming"]
    return c


def validate_data_config(dataconfig: DataConfig):
    _validate_dataloader_config(dataconfig.dataloader)
    for d in dataconfig.datasets:
        _validate_dataset_config(d)


def load_and_validate_data_config(data_config_file: str) -> DataConfig:
    raw_data = load_yaml_or_json(data_config_file)
    assert isinstance(
        raw_data, dict
    ), f"The provided data_config file is invalid: {data_config_file}"
    data_config = DataConfig()
    assert "datasets" in raw_data, "datasets should be provided in data config"
    assert isinstance(
        raw_data["datasets"], list
    ), "datasets should be provided as a list"
    data_config.datasets = []
    for d in raw_data["datasets"]:
        data_config.datasets.append(_validate_dataset_config(d))
    if "dataloader" in raw_data:
        data_config.dataloader = _validate_dataloader_config(raw_data["dataloader"])
    return data_config
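
A minimal usage sketch for this loader; "my_data_config.yaml" is a hypothetical file in the shape of the examples above, with FILE_PATH replaced by real files, since _validate_dataset_config checks that each path exists:

from tuning.data.data_config import load_and_validate_data_config

config = load_and_validate_data_config("my_data_config.yaml")
print(config.dataloader.type, config.dataloader.streaming)
for dataset in config.datasets:
    print(dataset.name, dataset.data_paths)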
83 changes: 83 additions & 0 deletions tuning/data/data_handlers.py
@@ -0,0 +1,83 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Definitions of some predefined data preprocessing functions.

# Standard
from typing import Dict

# Third Party
from transformers import AutoTokenizer

# Local
from tuning.utils.data_utils import custom_data_formatter
from tuning.utils.preprocessing_utils import combine_sequence


def tokenize_and_apply_instruction_masking(
    element: Dict[str, str],
    tokenizer: AutoTokenizer,
    input_field_name: str,
    output_field_name: str,
    **tokenizer_kwargs,
):
    input_text = element[input_field_name]
    output_text = element[output_field_name]

    # TODO: Eventually move the code here
    combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token)

    tokenized_comb_seqs = tokenizer(combined, **tokenizer_kwargs)
    tokenized_input = tokenizer(input_text, **tokenizer_kwargs)

    # Mask the prompt tokens with -100 so the loss is computed only on the
    # output portion; -100 is ignored by PyTorch's cross-entropy loss.
    masked_labels = [-100] * len(
        tokenized_input.input_ids
    ) + tokenized_comb_seqs.input_ids[len(tokenized_input.input_ids) :]

    # Any benefit of retaining the old columns?
    return {
        "input_ids": tokenized_comb_seqs.input_ids,
        "labels": masked_labels,
        "attention_mask": tokenized_comb_seqs.attention_mask,
    }


def apply_dataset_formatting(
    element: Dict[str, str], tokenizer: AutoTokenizer, dataset_text_field: str, **kwargs
):
    return {dataset_text_field: element[dataset_text_field] + tokenizer.eos_token}


def apply_custom_data_formatting_template(
    element: Dict[str, str],
    tokenizer: AutoTokenizer,
    dataset_text_field: str,
    template: str,
    **kwargs,
):
    template += tokenizer.eos_token

    # TODO: Eventually move the code here.
    # Return the formatted element so this handler works under Dataset.map.
    return custom_data_formatter(
        element=element, formatted_dataset_field=dataset_text_field, template=template
    )


AVAILABLE_DATA_HANDLERS = {
    "tokenize_and_apply_instruction_masking": tokenize_and_apply_instruction_masking,
    "apply_dataset_formatting": apply_dataset_formatting,
    "apply_custom_data_formatting_template": apply_custom_data_formatting_template,
}
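
The commit stops short of the dispatch that ties the two files together. A sketch of how a validated DataConfig could drive these handlers through datasets.Dataset.map; the loop, the tokenizer injection, and the file name are assumptions, not shipped code:

# Sketch only: glue between DataConfig and AVAILABLE_DATA_HANDLERS.
# Assumes JSON/JSONL data files and that the caller supplies the tokenizer.
from datasets import load_dataset
from transformers import AutoTokenizer

from tuning.data.data_config import load_and_validate_data_config
from tuning.data.data_handlers import AVAILABLE_DATA_HANDLERS

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model
config = load_and_validate_data_config("my_data_config.yaml")  # hypothetical file

for dataset_config in config.datasets:
    dataset = load_dataset("json", data_files=dataset_config.data_paths, split="train")
    for handler_config in dataset_config.data_handlers or []:
        fn = AVAILABLE_DATA_HANDLERS[handler_config.name]
        map_kwargs = dict(handler_config.arguments or {})
        fn_kwargs = map_kwargs.pop("fn_kwargs", {})
        fn_kwargs["tokenizer"] = tokenizer  # injected, never read from YAML
        if map_kwargs.get("remove_columns") == "all":
            map_kwargs["remove_columns"] = dataset.column_names
        dataset = dataset.map(fn, fn_kwargs=fn_kwargs, **map_kwargs)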