
Commit

Add code comments and make the code path clearer.
Remove the packing check, as packing support for pretokenized data has been
merged into trl. See huggingface/trl#2011

Signed-off-by: Dushyant Behl <dushyantbehl@users.noreply.github.com>
dushyantbehl committed Dec 3, 2024
1 parent befb5e7 commit e629228
Showing 6 changed files with 302 additions and 376 deletions.
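For context on the removed check, here is a minimal sketch of what the commit message refers to, assuming a trl version that includes packing support for pretokenized data (huggingface/trl#2011). The model name, toy dataset, and arguments below are illustrative and are not taken from this repository.

from datasets import Dataset
from trl import SFTConfig, SFTTrainer

# Toy pretokenized dataset: rows already carry input_ids, so no text column is needed.
ds = Dataset.from_dict({"input_ids": [[1, 2, 3, 4], [5, 6, 7]]})

# packing=True can now be combined with an already-tokenized train_dataset,
# which is why the packing check downstream no longer needs to reject this case.
config = SFTConfig(output_dir="tmp", packing=True)
trainer = SFTTrainer(model="facebook/opt-125m", args=config, train_dataset=ds)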
39 changes: 38 additions & 1 deletion tests/data/test_data_handlers.py
@@ -24,7 +24,10 @@
from tests.artifacts.testdata import MODEL_NAME, TWITTER_COMPLAINTS_DATA_JSONL

# Local
from tuning.data.data_handlers import apply_custom_data_formatting_template
from tuning.data.data_handlers import (
apply_custom_data_formatting_template,
combine_sequence,
)


def test_apply_custom_formatting_template():
@@ -71,3 +74,37 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
"template": template,
},
)


@pytest.mark.parametrize(
"input_element,output_element,expected_res",
[
("foo ", "bar", "foo bar"),
("foo\n", "bar", "foo\nbar"),
("foo\t", "bar", "foo\tbar"),
("foo", "bar", "foo bar"),
],
)
def test_combine_sequence(input_element, output_element, expected_res):
"""Ensure that input / output elements are combined with correct whitespace handling."""
comb_seq = combine_sequence(input_element, output_element)
assert isinstance(comb_seq, str)
assert comb_seq == expected_res


@pytest.mark.parametrize(
"input_element,output_element,expected_res",
[
("foo ", "bar", "foo bar"),
("foo\n", "bar", "foo\nbar"),
("foo\t", "bar", "foo\tbar"),
("foo", "bar", "foo bar"),
],
)
def test_combine_sequence_adds_eos(input_element, output_element, expected_res):
"""Ensure that input / output elements are combined with correct whitespace handling."""
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token)
expected_res += tokenizer.eos_token
assert isinstance(comb_seq, str)
assert comb_seq == expected_res
92 changes: 11 additions & 81 deletions tests/data/test_data_preprocessing_utils.py
@@ -43,53 +43,15 @@
# Local
from tuning.config import configs
from tuning.data.data_config import DataPreProcessorConfig, DataSetConfig
from tuning.data.data_preprocessing_utils import (
combine_sequence,
get_data_collator,
validate_data_args,
)
from tuning.data.data_processors import HFBasedDataPreProcessor, get_datapreprocessor
from tuning.data.data_preprocessing_utils import get_data_collator
from tuning.data.data_processors import DataPreProcessor, get_datapreprocessor
from tuning.data.setup_dataprocessor import (
_process_dataconfig_file,
is_pretokenized_dataset,
process_dataargs,
)


@pytest.mark.parametrize(
"input_element,output_element,expected_res",
[
("foo ", "bar", "foo bar"),
("foo\n", "bar", "foo\nbar"),
("foo\t", "bar", "foo\tbar"),
("foo", "bar", "foo bar"),
],
)
def test_combine_sequence(input_element, output_element, expected_res):
"""Ensure that input / output elements are combined with correct whitespace handling."""
comb_seq = combine_sequence(input_element, output_element)
assert isinstance(comb_seq, str)
assert comb_seq == expected_res


@pytest.mark.parametrize(
"input_element,output_element,expected_res",
[
("foo ", "bar", "foo bar"),
("foo\n", "bar", "foo\nbar"),
("foo\t", "bar", "foo\tbar"),
("foo", "bar", "foo bar"),
],
)
def test_combine_sequence_adds_eos(input_element, output_element, expected_res):
"""Ensure that input / output elements are combined with correct whitespace handling."""
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
comb_seq = combine_sequence(input_element, output_element, tokenizer.eos_token)
expected_res += tokenizer.eos_token
assert isinstance(comb_seq, str)
assert comb_seq == expected_res


@pytest.mark.parametrize(
"datafile, column_names",
[
@@ -222,7 +184,6 @@ def test_load_dataset_without_dataconfig_and_datafile():
)
def test_is_pretokenized_data(data, result):
"""Ensure that the correct collator type is fetched based on the data args"""

assert is_pretokenized_dataset(data=data) == result


@@ -361,43 +322,16 @@ def test_get_data_collator(
),
],
)
def test_validate_args(data_args, packing):
def test_process_data_args_throws_error_where_needed(data_args, packing):
"""Ensure that respective errors are thrown for incorrect data arguments"""
with pytest.raises(ValueError):
is_traindata_tokenized = is_pretokenized_dataset(data_args.training_data_path)
is_evaldata_tokenized = is_pretokenized_dataset(data_args.validation_data_path)
validate_data_args(
data_args, packing, is_traindata_tokenized, is_evaldata_tokenized
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
TRAIN_ARGS = configs.TrainingArguments(
packing=packing,
max_seq_length=1024,
output_dir="tmp", # Not needed but positional
)


@pytest.mark.parametrize(
"data_args, packing",
[
# pretokenized train dataset and no validation dataset passed
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
),
False,
),
# pretokenized train and validation datasets
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
validation_data_path=TWITTER_COMPLAINTS_TOKENIZED_JSONL,
),
False,
),
],
)
def test_validate_args_pretokenized(data_args, packing):
"""Ensure that supported data args do not error out when passing pretokenized datasets"""
is_traindata_tokenized = is_pretokenized_dataset(data_args.training_data_path)
is_evaldata_tokenized = is_pretokenized_dataset(data_args.validation_data_path)
validate_data_args(
data_args, packing, is_traindata_tokenized, is_evaldata_tokenized
)
(_, _, _, _, _, _) = process_dataargs(data_args, tokenizer, TRAIN_ARGS)


@pytest.mark.parametrize(
@@ -448,11 +382,7 @@ def test_process_dataconfig_file(data_config_path, data_path):
data_args = configs.DataArguments(data_config_path=temp_yaml_file_path)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
packing = (False,)
max_seq_length = 1024
(train_set, _, _, _, _, _) = _process_dataconfig_file(
data_args, tokenizer, packing, max_seq_length
)
(train_set, _, _) = _process_dataconfig_file(data_args, tokenizer)
assert isinstance(train_set, Dataset)
if datasets_name == "text_dataset_input_output_masking":
column_names = set(["input_ids", "attention_mask", "labels"])
@@ -625,7 +555,7 @@ def test_process_dataset_configs(datafile, column_names, datasetconfigname):
"""Test process_dataset_configs for expected output."""
dataprocessor_config = DataPreProcessorConfig()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
processor = HFBasedDataPreProcessor(
processor = DataPreProcessor(
processor_config=dataprocessor_config,
tokenizer=tokenizer,
)
60 changes: 53 additions & 7 deletions tuning/data/data_handlers.py
@@ -15,26 +15,55 @@
# Definition of some predefined data preprocessing functions that we need.

# Standard
from typing import Dict
from typing import Dict, List
import re

# Third Party
from transformers import AutoTokenizer

# Local
from tuning.data.data_preprocessing_utils import combine_sequence, custom_data_formatter

### Utils for custom masking / manipulating input / output strs, etc
def combine_sequence(input_element: str, output_element: str, eos_token: str = ""):
"""Combines / concatenates input & output element.
Args:
input_element: str
Input component of the combined sequence.
output_element: str
Output component of the combined sequence.
eos_token: str
EOS token associated with the tokenizer. \
If passed, it will be concatenated at the end.
Returns:
str
Sequence combined with whitespace.
"""
if not input_element.endswith((" ", "\n", "\t")) and not output_element.startswith(
(" ", "\n", "\t")
):
return input_element + " " + output_element + eos_token
return input_element + output_element + eos_token


def tokenize_and_apply_input_masking(
element: Dict[str, str],
tokenizer: AutoTokenizer,
column_names: List[str],
input_field_name: str,
output_field_name: str,
**tokenizer_kwargs,
):
if input_field_name not in column_names or output_field_name not in column_names:
raise ValueError(
f"Dataset should contain {input_field_name} \
and {output_field_name} fields if \
no dataset_text_field or data_formatter_template is specified"
)

input_text = element[input_field_name]
output_text = element[output_field_name]

# TODO: Eventually move the code here
combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token)

fn_kwargs = tokenizer_kwargs.get("fn_kwargs", {})
@@ -56,7 +85,10 @@ def tokenize_and_apply_input_masking(


def apply_dataset_formatting(
element: Dict[str, str], tokenizer: AutoTokenizer, dataset_text_field: str, **kwargs
element: Dict[str, str],
tokenizer: AutoTokenizer,
dataset_text_field: str,
**kwargs,
):
return {
f"{dataset_text_field}": element[f"{dataset_text_field}"] + tokenizer.eos_token
@@ -85,8 +117,22 @@ def apply_custom_data_formatting_template(

template += tokenizer.eos_token

# TODO: Eventually move the code here.
return custom_data_formatter(element, template, dataset_text_field)
def replace_text(match_obj):
captured_groups = match_obj.groups()
if len(captured_groups) != 1:
raise ValueError(
"Unexpectedly captured multiple groups in template formatting"
)

index_object = captured_groups[0]
if index_object not in element:
raise KeyError("Requested template string is not a valid key in dict")

return element[index_object]

return {
dataset_text_field: re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template)
}


AVAILABLE_DATA_HANDLERS = {
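For reference, a minimal sketch of how the new combine_sequence helper behaves (illustrative only; the strings and EOS token below are invented for this example, not taken from the repository's tests).

from tuning.data.data_handlers import combine_sequence

# A single space is inserted only when neither side already supplies boundary whitespace.
assert combine_sequence("foo", "bar") == "foo bar"
assert combine_sequence("foo\n", "bar") == "foo\nbar"

# If an EOS token is passed, it is appended to the combined sequence.
assert combine_sequence("foo", "bar", eos_token="</s>") == "foo bar</s>"

The placeholder substitution that apply_custom_data_formatting_template now performs inline can also be exercised on its own with the same regex; the element values and template here are likewise invented for illustration.

import re

element = {"Tweet text": "my package never arrived", "text_label": "complaint"}
template = "### Input: {{Tweet text}}\n\n### Response: {{text_label}}"

def replace_text(match_obj):
    # The single captured group is the column name to look up in the element.
    (key,) = match_obj.groups()
    return element[key]

print(re.sub(r"{{([\s0-9a-zA-Z_\-\.]+)}}", replace_text, template))
# ### Input: my package never arrived
#
# ### Response: complaint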
