Add magic method to our TF models to convert datasets with column inference #17160

Merged 35 commits on Jun 6, 2022
Commits
1ccd726
Add method to call to_tf_dataset() with column inference
Rocketknight1 May 10, 2022
2c71f84
Add test for dataset creation
Rocketknight1 May 11, 2022
919cf82
Add a default arg for data collator
Rocketknight1 May 11, 2022
0066598
Fix test
Rocketknight1 May 11, 2022
b40fa6e
Fix call with non-dev version of datasets
Rocketknight1 May 11, 2022
f5f667d
Test correct column removal too
Rocketknight1 May 11, 2022
258392b
make fixup
Rocketknight1 May 11, 2022
ae4be4a
More tests to make sure we remove unwanted columns
Rocketknight1 May 11, 2022
673e23d
Fix test to avoid predicting on unbuilt models
Rocketknight1 May 11, 2022
0ee6e1d
Fix test to avoid predicting on unbuilt models
Rocketknight1 May 11, 2022
2313b3a
Fix test to remove unwanted head mask columns from inputs
Rocketknight1 May 11, 2022
221ae78
Stop pushing your debug breakpoints to the main repo of the $2bn comp…
Rocketknight1 May 12, 2022
1506182
Skip the test in convnext because no grouped conv support
Rocketknight1 May 12, 2022
a9010b1
Drop bools from the dataset dict
Rocketknight1 May 12, 2022
0a29747
Make style
Rocketknight1 May 12, 2022
5b35ff4
Skip the training test for models whose input dicts don't give us labels
Rocketknight1 May 12, 2022
a1b6e92
Skip transformerXL in the test because it doesn't return a simple loss
Rocketknight1 May 12, 2022
33812ea
Skip TFTapas because of some odd NaN losses
Rocketknight1 May 12, 2022
fb19ea9
make style
Rocketknight1 May 12, 2022
a7f6a85
make fixup
Rocketknight1 May 12, 2022
0787a45
Add docstring
Rocketknight1 May 17, 2022
24c0a66
fixup
Rocketknight1 May 17, 2022
fc89e11
Update src/transformers/modeling_tf_utils.py
Rocketknight1 May 19, 2022
0d4553d
Update src/transformers/modeling_tf_utils.py
Rocketknight1 May 19, 2022
0311172
Update src/transformers/modeling_tf_utils.py
Rocketknight1 May 19, 2022
a7b1f60
Update src/transformers/modeling_tf_utils.py
Rocketknight1 May 19, 2022
f01f8c2
Update src/transformers/modeling_tf_utils.py
Rocketknight1 May 19, 2022
24cc6a6
Remove breakpoint from tests
Rocketknight1 May 19, 2022
e381daf
Fix assert, add requires_backends
Rocketknight1 May 19, 2022
0f473db
Protect tokenizer import with if TYPE_CHECKING
Rocketknight1 May 19, 2022
753c0c5
make fixup
Rocketknight1 May 19, 2022
b23efa1
Add noqa, more fixup
Rocketknight1 May 19, 2022
163b4f7
More rearranging for ~* aesthetics *~
Rocketknight1 May 19, 2022
e29f98f
Adding defaults for shuffle and batch_size to match to_tf_dataset()
Rocketknight1 May 19, 2022
726cb39
Update src/transformers/modeling_tf_utils.py
Rocketknight1 May 19, 2022
More tests to make sure we remove unwanted columns
Rocketknight1 committed May 19, 2022
commit ae4be4afb3502a6092dc6cc5b96764c7fcb4f79c
9 changes: 7 additions & 2 deletions tests/test_modeling_tf_common.py
@@ -1535,6 +1535,7 @@ def test_dataset_conversion(self):

             if "labels" in inspect.signature(model_class.call).parameters.keys():
                 tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+                tf_inputs_dict["extra_unwanted_column"] = list(tf_inputs_dict.values())[0]  # Use a random other tensor
                 input_dataset = Dataset.from_dict(tf_inputs_dict)
                 tf_dataset = model.prepare_tf_dataset(
                     input_dataset, batch_size=len(input_dataset), drop_remainder=False, shuffle=False
@@ -1543,8 +1544,12 @@ def test_dataset_conversion(self):
                 self.assertGreater(len(test_batch_labels), 0)  # Assert the labels are present
                 feature_columns = 1 if isinstance(test_batch, tf.Tensor) else len(test_batch)
                 label_columns = 1 if isinstance(test_batch_labels, tf.Tensor) else len(test_batch_labels)
-                # Assert we didn't lose any columns
-                self.assertEqual(feature_columns + label_columns, len(input_dataset.features))
+                # Assert we discarded the unwanted extra column but kept everything else
+                self.assertEqual(feature_columns + label_columns, len(input_dataset.features) - 1)
+                if isinstance(test_batch, dict):
+                    self.assertNotIn("extra_unwanted_column", test_batch)
+                if isinstance(test_batch_labels, dict):
+                    self.assertNotIn("extra_unwanted_column", test_batch_labels)
                 model.compile(optimizer="sgd", run_eagerly=True)
                 model.train_on_batch(test_batch, test_batch_labels)
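
For readers skimming this hunk, the behaviour the new assertions check can be sketched without TensorFlow. This is a rough illustration only, not the code in `prepare_tf_dataset`: the `call` function below is a hypothetical stand-in for a model's `call()` signature, `split_columns` is a made-up helper, and treating only `labels` as a label column is an assumption.

```python
import inspect


def call(input_ids=None, attention_mask=None, labels=None, training=False):
    """Hypothetical stand-in for a TF model's call() signature."""


def split_columns(columns, call_fn, label_names=("labels",)):
    """Split dataset columns into features, labels, and dropped columns.

    Assumption: a column is kept only if its name matches a parameter
    of call_fn; everything else is dropped.
    """
    accepted = set(inspect.signature(call_fn).parameters)
    features = [c for c in columns if c in accepted and c not in label_names]
    labels = [c for c in columns if c in accepted and c in label_names]
    dropped = [c for c in columns if c not in accepted]
    return features, labels, dropped


# Same shape as the test: the real columns plus one the model cannot accept.
columns = ["input_ids", "attention_mask", "labels", "extra_unwanted_column"]
features, labels, dropped = split_columns(columns, call)
```

Under these assumptions, `len(features) + len(labels)` equals `len(columns) - 1`, which mirrors the updated `assertEqual` in the test, and the extra column appears in neither the features nor the labels.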
