
Add magic method to our TF models to convert datasets with column inference #17160

Merged: 35 commits, merged on Jun 6, 2022

Commits (35):
1ccd726  Add method to call to_tf_dataset() with column inference (Rocketknight1, May 10, 2022)
2c71f84  Add test for dataset creation (Rocketknight1, May 11, 2022)
919cf82  Add a default arg for data collator (Rocketknight1, May 11, 2022)
0066598  Fix test (Rocketknight1, May 11, 2022)
b40fa6e  Fix call with non-dev version of datasets (Rocketknight1, May 11, 2022)
f5f667d  Test correct column removal too (Rocketknight1, May 11, 2022)
258392b  make fixup (Rocketknight1, May 11, 2022)
ae4be4a  More tests to make sure we remove unwanted columns (Rocketknight1, May 11, 2022)
673e23d  Fix test to avoid predicting on unbuilt models (Rocketknight1, May 11, 2022)
0ee6e1d  Fix test to avoid predicting on unbuilt models (Rocketknight1, May 11, 2022)
2313b3a  Fix test to remove unwanted head mask columns from inputs (Rocketknight1, May 11, 2022)
221ae78  Stop pushing your debug breakpoints to the main repo of the $2bn comp… (Rocketknight1, May 12, 2022)
1506182  Skip the test in convnext because no grouped conv support (Rocketknight1, May 12, 2022)
a9010b1  Drop bools from the dataset dict (Rocketknight1, May 12, 2022)
0a29747  Make style (Rocketknight1, May 12, 2022)
5b35ff4  Skip the training test for models whose input dicts don't give us labels (Rocketknight1, May 12, 2022)
a1b6e92  Skip transformerXL in the test because it doesn't return a simple loss (Rocketknight1, May 12, 2022)
33812ea  Skip TFTapas because of some odd NaN losses (Rocketknight1, May 12, 2022)
fb19ea9  make style (Rocketknight1, May 12, 2022)
a7f6a85  make fixup (Rocketknight1, May 12, 2022)
0787a45  Add docstring (Rocketknight1, May 17, 2022)
24c0a66  fixup (Rocketknight1, May 17, 2022)
fc89e11  Update src/transformers/modeling_tf_utils.py (Rocketknight1, May 19, 2022)
0d4553d  Update src/transformers/modeling_tf_utils.py (Rocketknight1, May 19, 2022)
0311172  Update src/transformers/modeling_tf_utils.py (Rocketknight1, May 19, 2022)
a7b1f60  Update src/transformers/modeling_tf_utils.py (Rocketknight1, May 19, 2022)
f01f8c2  Update src/transformers/modeling_tf_utils.py (Rocketknight1, May 19, 2022)
24cc6a6  Remove breakpoint from tests (Rocketknight1, May 19, 2022)
e381daf  Fix assert, add requires_backends (Rocketknight1, May 19, 2022)
0f473db  Protect tokenizer import with if TYPE_CHECKING (Rocketknight1, May 19, 2022)
753c0c5  make fixup (Rocketknight1, May 19, 2022)
b23efa1  Add noqa, more fixup (Rocketknight1, May 19, 2022)
163b4f7  More rearranging for ~* aesthetics *~ (Rocketknight1, May 19, 2022)
e29f98f  Adding defaults for shuffle and batch_size to match to_tf_dataset() (Rocketknight1, May 19, 2022)
726cb39  Update src/transformers/modeling_tf_utils.py (Rocketknight1, May 19, 2022)
96 changes: 95 additions & 1 deletion src/transformers/modeling_tf_utils.py
@@ -22,7 +22,7 @@
import re
import warnings
from collections.abc import Mapping
from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import h5py
import numpy as np
@@ -35,6 +35,7 @@
from huggingface_hub import Repository, list_repo_files
from requests import HTTPError

from . import DataCollatorWithPadding, DefaultDataCollator
from .activations_tf import get_tf_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
@@ -58,9 +59,14 @@
    is_offline_mode,
    is_remote_url,
    logging,
    requires_backends,
)


if TYPE_CHECKING:
    from . import PreTrainedTokenizerBase


logger = logging.get_logger(__name__)
tf_logger = tf.get_logger()

@@ -892,6 +898,94 @@ def load_repo_checkpoint(self, repo_path_or_name):
        # set it directly, but the user can pass it to fit().
        return {"epoch": extra_data["epoch"]}

    def prepare_tf_dataset(
        self,
        dataset: "datasets.Dataset",  # noqa:F821
        batch_size: int = 8,
        shuffle: bool = True,
        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
        collate_fn: Optional[Callable] = None,
        collate_fn_args: Optional[Dict[str, Any]] = None,
        drop_remainder: Optional[bool] = None,
        prefetch: bool = True,
    ):
        """
        Wraps a HuggingFace `datasets.Dataset` as a `tf.data.Dataset` with collation and batching. This method is
        designed to create a "ready-to-use" dataset that can be passed directly to Keras methods like `fit()` without
        further modification. The method will drop columns from the dataset if they don't match input names for the
        model. If you want to specify the column names to return rather than using the names that match this model, we
        recommend using `Dataset.to_tf_dataset()` instead.

        Args:
            dataset (`Any`):
                A `datasets.Dataset` to be wrapped as a `tf.data.Dataset`.
            batch_size (`int`, defaults to 8):
                The size of batches to return.
            shuffle (`bool`, defaults to `True`):
                Whether to return samples from the dataset in random order. Usually `True` for training datasets and
                `False` for validation/test datasets.
            tokenizer ([`PreTrainedTokenizerBase`], *optional*):
                A `PreTrainedTokenizer` that will be used to pad samples to create batches. Has no effect if a specific
                `collate_fn` is passed instead.
            collate_fn (`Callable`, *optional*):
                A function that collates samples from the dataset into a single batch. Defaults to
                `DefaultDataCollator` if no `tokenizer` is supplied or `DataCollatorWithPadding` if a `tokenizer` is
                passed.
            collate_fn_args (`Dict[str, Any]`, *optional*):
                A dict of arguments to pass to the `collate_fn` alongside the list of samples.
            drop_remainder (`bool`, *optional*):
                Whether to drop the final batch, if the batch_size does not evenly divide the dataset length. Defaults
                to the same setting as `shuffle`.
            prefetch (`bool`, defaults to `True`):
                Whether to add prefetching to the end of the `tf.data` pipeline. This is almost always beneficial for
                performance, but can be disabled in edge cases.

        Returns:
            `Dataset`: A `tf.data.Dataset` which is ready to pass to the Keras API.
        """
        requires_backends(self, ["datasets"])
        import datasets

        if collate_fn is None:
            if tokenizer is None:
                collate_fn = DefaultDataCollator(return_tensors="tf")
            else:
                collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
        if collate_fn_args is None:
            collate_fn_args = dict()

        if not isinstance(dataset, datasets.Dataset):
            raise TypeError("Dataset argument should be a datasets.Dataset!")
        model_inputs = list(dict(inspect.signature(self.call).parameters).keys())
        model_labels = find_labels(self.__class__)
        unwanted_columns = [
            feature
            for feature in dataset.features
            if feature not in model_inputs and feature not in ("label_ids", "label")
        ]
        dataset = dataset.remove_columns(unwanted_columns)
        output_signature, _ = dataset._get_output_signature(
            dataset,
            batch_size=None,
            collate_fn=collate_fn,
            collate_fn_args=collate_fn_args,
        )
        output_columns = list(output_signature.keys())
        feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels]
        label_cols = [col for col in output_columns if col in model_labels]
        tf_dataset = dataset.to_tf_dataset(
            columns=feature_cols,
            label_cols=label_cols,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_remainder=drop_remainder,
            collate_fn=collate_fn,
            collate_fn_args=collate_fn_args,
            prefetch=prefetch,
        )
        return tf_dataset

    def compile(
        self,
        optimizer="rmsprop",
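For context, here is a minimal usage sketch of the new prepare_tf_dataset() method. It is not part of the diff itself, and the checkpoint and dataset names are illustrative assumptions rather than anything referenced in this PR:

    # Hypothetical example: convert a tokenized datasets.Dataset into a Keras-ready
    # tf.data.Dataset with the new method, then fit the model on it.
    from datasets import load_dataset
    from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")

    raw_dataset = load_dataset("glue", "sst2", split="train")
    tokenized_dataset = raw_dataset.map(
        lambda batch: tokenizer(batch["sentence"], truncation=True), batched=True
    )

    # Columns that don't match the model's input signature (e.g. "sentence", "idx")
    # are dropped automatically; passing the tokenizer selects DataCollatorWithPadding
    # so batches are padded dynamically.
    train_set = model.prepare_tf_dataset(
        tokenized_dataset, batch_size=16, shuffle=True, tokenizer=tokenizer
    )

    model.compile(optimizer="adam")
    model.fit(train_set, epochs=1)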
7 changes: 7 additions & 0 deletions tests/models/convnext/test_modeling_tf_convnext.py
@@ -174,6 +174,13 @@ def test_model(self):
    def test_attention_outputs(self):
        pass

    @unittest.skipIf(
        not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
        reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
    )
    def test_dataset_conversion(self):
        super().test_dataset_conversion()

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
4 changes: 4 additions & 0 deletions tests/models/tapas/test_modeling_tf_tapas.py
@@ -498,6 +498,10 @@ def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    @unittest.skip(reason="The default test gets NaN losses with the test-generated inputs")
    def test_dataset_conversion(self):
        pass


def prepare_tapas_single_inputs_for_inference():
    # Here we prepare a single table-question pair to test TAPAS inference on:
4 changes: 4 additions & 0 deletions tests/models/transfo_xl/test_modeling_tf_transfo_xl.py
@@ -216,6 +216,10 @@ def test_model_from_pretrained(self):
            model = TFTransfoXLModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    @unittest.skip(reason="This model doesn't play well with fit() due to not returning a single loss.")
    def test_dataset_conversion(self):
        pass


@require_tf
class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
52 changes: 52 additions & 0 deletions tests/test_modeling_tf_common.py
@@ -25,6 +25,8 @@
from importlib import import_module
from typing import List, Tuple

from datasets import Dataset

from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError
from transformers import is_tf_available, is_torch_available
@@ -1509,6 +1511,56 @@ def test_model_main_input_name(self):
        observed_main_input_name = list(model_signature.parameters.keys())[1]
        self.assertEqual(model_class.main_input_name, observed_main_input_name)

    def test_dataset_conversion(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config)
            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=False)
            tf_inputs_dict = {
                key: val
                for key, val in tf_inputs_dict.items()
                if "head_mask" not in key and isinstance(val, tf.Tensor)
            }
            tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0]  # Use a random other tensor
            input_dataset = Dataset.from_dict(tf_inputs_dict)
            tf_dataset = model.prepare_tf_dataset(
                input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
            )
            test_batch = next(iter(tf_dataset))
            if isinstance(test_batch, tf.Tensor):
                self.assertEqual(len(test_batch), len(input_dataset))  # Assert we didn't lose any data
            else:
                # Assert we discarded the unwanted extra column but kept everything else
                self.assertEqual(len(test_batch), len(input_dataset.features) - 1)
                self.assertNotIn("extra_unwanted_column", test_batch)
                for tensor in test_batch.values():
                    self.assertTrue(isinstance(tensor, tf.Tensor))
                    self.assertEqual(len(tensor), len(input_dataset))  # Assert we didn't lose any data
            model(test_batch, training=False)

            if "labels" in inspect.signature(model_class.call).parameters.keys():
                tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
                if "labels" not in tf_inputs_dict:
                    return  # This model isn't giving us labels after all, don't try training with it
                tf_inputs_dict = {key: val for key, val in tf_inputs_dict.items() if "head_mask" not in key}
                tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0]  # Use a random other tensor
                input_dataset = Dataset.from_dict(tf_inputs_dict)
                tf_dataset = model.prepare_tf_dataset(
                    input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
                )
                test_batch, test_batch_labels = next(iter(tf_dataset))
                self.assertGreater(len(test_batch_labels), 0)  # Assert the labels are present
                feature_columns = 1 if isinstance(test_batch, tf.Tensor) else len(test_batch)
                label_columns = 1 if isinstance(test_batch_labels, tf.Tensor) else len(test_batch_labels)
                # Assert we discarded the unwanted extra column but kept everything else
                self.assertEqual(feature_columns + label_columns, len(input_dataset.features) - 1)
                if isinstance(test_batch, dict):
                    self.assertNotIn("extra_unwanted_column", test_batch)
                if isinstance(test_batch_labels, dict):
                    self.assertNotIn("extra_unwanted_column", test_batch_labels)
                model.compile(optimizer="sgd", run_eagerly=True)
                model.train_on_batch(test_batch, test_batch_labels)

    def _generate_random_bad_tokens(self, num_bad_tokens, model):
        # special tokens cannot be bad tokens
        special_tokens = []