Adding new train_step logic to make things less confusing for users (huggingface#15994)

* Adding new train_step logic to make things less confusing for users

* DO NOT ASK WHY WE NEED THAT SUBCLASS

* Metrics now working, at least for single-output models with type annotations!

* Updates and TODOs for the new train_step

* Make fixup

* Temporary test workaround until T5 has types

* Temporary test workaround until T5 has types

* I think this actually works! Needs a lot of tests though

* Make style/quality

* Revert changes to T5 tests

* Deleting the aforementioned unmentionable subclass

* Deleting the aforementioned unmentionable subclass

* Adding a Keras API test

* Style fixes

* Removing unneeded TODO and comments

* Update test_step too

* Stop trying to compute metrics with the dummy_loss, patch up test

* Make style

* make fixup

* Docstring cleanup

* make fixup

* make fixup

* Stop expanding 1D input tensors when using dummy loss

* Adjust T5 test given the new compile()

* make fixup

* Skipping test for convnext

* Removing old T5-specific Keras test now that we have a common one

* make fixup

* make fixup

* Only skip convnext test on CPU

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Avoiding TF import issues

* make fixup

* Update compile() to support TF 2.3

* Skipping model.fit() on template classes for now

* Skipping model.fit() on template class tests for now

* Replace ad-hoc solution with find_labels

* make fixup

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Rocketknight1 and sgugger authored Apr 5, 2022
1 parent 7ccacdf commit 4354005
Showing 5 changed files with 184 additions and 83 deletions.
171 changes: 118 additions & 53 deletions src/transformers/modeling_tf_utils.py
@@ -38,7 +38,6 @@
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation_tf_utils import TFGenerationMixin
from .modeling_tf_outputs import TFSeq2SeqLMOutput
from .tf_utils import shape_list
from .tokenization_utils_base import BatchEncoding
from .utils import (
@@ -53,6 +52,7 @@
RevisionNotFoundError,
cached_path,
copy_func,
find_labels,
has_file,
hf_bucket_url,
is_offline_mode,
@@ -715,6 +715,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
base_model_prefix = ""
main_input_name = "input_ids"
_auto_class = None
_using_dummy_loss = None

# a list of re pattern of tensor names to ignore from the model when loading the model weights
# (and avoid unnecessary warnings).
@@ -899,24 +900,46 @@ def compile(
function themselves.
"""
if loss == "passthrough":
if metrics is not None:
raise ValueError(
"Passing metrics as a dict is not supported when using the internal loss! "
"Please either compile the model with a loss, or remove the metrics argument. "
"Note that advanced metrics using the `KerasMetricCallback` can still be used with the internal "
"loss."
)
logger.warning(
"No loss specified in compile() - the model's internal loss computation will be used as the "
"loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
"Please ensure your labels are passed as keys in the input dict so that they are "
"accessible to the model during the forward pass. To disable this behaviour, please pass a "
"loss argument, or explicitly pass loss=None if you do not want your model to compute a loss."
"To disable this behaviour, please pass a loss argument, or explicitly pass "
"`loss=None` if you do not want your model to compute a loss."
)
loss = dummy_loss
self._using_dummy_loss = True
else:
self._using_dummy_loss = False
parent_args = list(inspect.signature(tf.keras.Model.compile).parameters.keys())
if "steps_per_execution" in parent_args:
super().compile(
optimizer=optimizer,
loss=loss,
metrics=metrics,
loss_weights=loss_weights,
weighted_metrics=weighted_metrics,
run_eagerly=run_eagerly,
steps_per_execution=steps_per_execution,
**kwargs,
)
else:
super().compile(
optimizer=optimizer,
loss=loss,
metrics=metrics,
loss_weights=loss_weights,
weighted_metrics=weighted_metrics,
run_eagerly=run_eagerly,
experimental_steps_per_execution=steps_per_execution,
**kwargs,
)
loss = {"loss": dummy_loss}
super().compile(
optimizer=optimizer,
loss=loss,
metrics=metrics,
loss_weights=loss_weights,
weighted_metrics=weighted_metrics,
run_eagerly=run_eagerly,
steps_per_execution=steps_per_execution,
**kwargs,
)

def compute_loss(self, *args, **kwargs):
if hasattr(tf.keras.Model, "compute_loss"):
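For reference, the `dummy_loss` used in the compile() hunk above is defined elsewhere in modeling_tf_utils.py and does not appear in this diff. Conceptually it ignores `y_true` and simply passes the loss the model already computed through Keras's loss machinery; a rough sketch of the idea (not the exact definition in the file):

import tensorflow as tf

def dummy_loss(y_true, y_pred):
    # y_pred here is already the loss computed inside the model's forward pass;
    # averaging it gives Keras the scalar it expects from a loss function.
    return tf.reduce_mean(y_pred)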
@@ -935,40 +958,54 @@ def compute_loss(self, *args, **kwargs):
def train_step(self, data):
"""
A modification of Keras's default `train_step` that cleans up the printed metrics when we use a dummy loss. If
a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`. In
this case, it expects the same `data` as the original function (i.e. `(inputs, labels)`).
However, when the model is compiled without specifying the loss AND the expected label columns are passed as
part of the input dictionary, the loss is computed internally (inside the model class) and is used in the
backwards pass. In this case, `data` is a singleton tuple containing `(inputs,)`.
a user specifies a loss at model compile time, this function behaves as the original Keras `train_step`.
This is possible under the aforementioned circumstances because our overridden compile function can set an
additional loss function that reduces a `loss` output, and the model will output a `loss` component (notice the
name matching) containing the loss that was used to train the pre-trained model.
When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
as keys in the input dictionary, or as normal Keras labels.
"""

# These are the only transformations `Model.fit` applies to user-input
# data when a `tf.data.Dataset` is provided.
data = data_adapter.expand_1d(data)
if not self._using_dummy_loss:
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
# These next two lines differ from the base method - they avoid issues when the labels are in
# the input dict (and loss is computed internally)
if y is None and "labels" in x:
y = x["labels"] # Stops confusion with metric computations
elif y is None and "input_ids" in x:
# Just make any kind of dummy array to make loss work
y = tf.zeros(tf.shape(x["input_ids"])[0], dtype=tf.int64)

# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
# if those keys are not already present in the input dict
if self._using_dummy_loss and y is not None:
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
label_kwargs = find_labels(self.__class__)
# If y is a tensor and the model only has one label-like input, map y to that input
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
if isinstance(x, tf.Tensor):
x = {arg_names[0]: x}
label_kwarg = next(iter(label_kwargs))
if label_kwarg not in x:
x[label_kwarg] = y
# Otherwise, copy keys from y to x as long as they weren't already present in x
elif isinstance(y, dict):
if isinstance(x, tf.Tensor):
x = {arg_names[0]: x}
for key, val in y.items():
if key in arg_names and key not in x:
x[key] = val

# Run forward pass.
with tf.GradientTape() as tape:
y_pred = self(x, training=True)
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
if self._using_dummy_loss:
loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
else:
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
# Run backwards pass.
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
# When y_pred is a ModelOutput and y is a tf.Tensor the metrics update
# should be done only with the relevant ModelOutput param that is
# considered by the loss.
if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
y_pred = y_pred["logits"]
self.compiled_metrics.update_state(y, y_pred, sample_weight)

# When using the dummy_loss we know metrics are not present, so we can skip a lot of this
if self._using_dummy_loss:
self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
else:
self.compiled_metrics.update_state(y, y_pred, sample_weight)
# Collect metrics to return
return_metrics = {}
for metric in self.metrics:
@@ -985,23 +1022,51 @@ def train_step(self, data):

def test_step(self, data):
"""
A modification of Keras's default test_step that cleans up the printed metrics when we use a dummy loss.
A modification of Keras's default `test_step` that cleans up the printed metrics when we use a dummy loss. If a
user specifies a loss at model compile time, this function behaves as the original Keras `test_step`.
When the model is compiled without specifying the loss, our overridden compile function can set a simple dummy
loss that just reads the loss output head of the model. When using this dummy loss, inputs can be passed either
as keys in the input dictionary, or as normal Keras labels.
"""
data = data_adapter.expand_1d(data)
# These are the only transformations `Model.fit` applies to user-input
# data when a `tf.data.Dataset` is provided.
if not self._using_dummy_loss:
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
# These next two lines differ from the base method - they avoid issues when the labels are in
# the input dict (and loss is computed internally)
if y is None and "labels" in x:
y = x["labels"] # Stops confusion with metric computations
elif y is None and "input_ids" in x:
# Just make any kind of dummy array to make loss work
y = tf.zeros(tf.shape(x["input_ids"])[0], dtype=tf.int64)

# When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
# if those keys are not already present in the input dict
if self._using_dummy_loss and y is not None:
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
label_kwargs = find_labels(self.__class__)
# If y is a tensor and the model only has one label-like input, map y to that input
if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
if isinstance(x, tf.Tensor):
x = {arg_names[0]: x}
label_kwarg = next(iter(label_kwargs))
if label_kwarg not in x:
x[label_kwarg] = y
# Otherwise, copy keys from y to x as long as they weren't already present in x
elif isinstance(y, dict):
if isinstance(x, tf.Tensor):
x = {arg_names[0]: x}
for key, val in y.items():
if key in arg_names and key not in x:
x[key] = val

# Run forward pass.
y_pred = self(x, training=False)
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
# Updates stateful loss metrics.
if isinstance(y_pred, TFSeq2SeqLMOutput) and isinstance(y, tf.Tensor):
y_pred = y_pred["logits"]
self.compiled_metrics.update_state(y, y_pred, sample_weight)
if self._using_dummy_loss:
self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
else:
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)

# When using the dummy_loss we know metrics are not present, so we can skip a lot of this
if self._using_dummy_loss:
self.compiled_metrics.update_state(y_pred.loss, y_pred.loss, sample_weight)
else:
self.compiled_metrics.update_state(y, y_pred, sample_weight)
# Collect metrics to return
return_metrics = {}
for metric in self.metrics:
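The compile()/train_step/test_step changes above mean that a model compiled without an explicit loss can now be fit with labels supplied either inside the input dict or as ordinary Keras labels. A minimal usage sketch of that behaviour; the checkpoint name and toy data are illustrative placeholders, not part of this commit:

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

# Illustrative checkpoint and inputs - any TF model with a "labels" argument behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

batch = dict(tokenizer(["great movie", "terrible movie"], padding=True, return_tensors="tf"))
labels = tf.constant([1, 0])

# No loss argument: compile() falls back to the model's internal loss computation.
model.compile(optimizer="adam")

# Option 1: labels passed as keys in the input dict (the pre-existing behaviour).
model.fit({**batch, "labels": labels})

# Option 2: labels passed as normal Keras labels - the new train_step copies them
# onto the matching model argument reported by find_labels().
model.fit(batch, labels)

Both calls should produce the same training loss, which is what the new common `test_keras_fit` test below checks by comparing the validation losses from the two label-passing styles.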
@@ -259,6 +259,7 @@ def create_and_check_causal_lm_model_as_decoder(
list(prediction_scores.numpy().shape), [self.batch_size, self.seq_length, self.vocab_size]
)


def create_and_check_causal_lm_model_past(
self,
config,
@@ -597,6 +598,10 @@ def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)

@unittest.skip(reason="Template classes interact badly with this test.")
def test_keras_fit(self):
pass

def test_causal_lm_base_model(self):
"""Test the base model of the causal LM model
@@ -947,6 +952,10 @@ def _get_word_embedding_weight(model, embedding_layer):
models_equal = False
self.assertTrue(models_equal)

@unittest.skip(reason="Template classes interact badly with this test.")
def test_keras_fit(self):
pass


def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
7 changes: 7 additions & 0 deletions tests/convnext/test_modeling_tf_convnext.py
@@ -143,6 +143,13 @@ def setUp(self):
def test_inputs_embeds(self):
pass

@unittest.skipIf(
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
)
def test_keras_fit(self):
pass

@unittest.skip(reason="ConvNext does not support input and output embeddings")
def test_model_common_attributes(self):
pass
30 changes: 0 additions & 30 deletions tests/t5/test_modeling_tf_t5.py
@@ -804,33 +804,3 @@ def test_translation_en_to_ro(self):
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)

self.assertEqual(translation, expected_translation)

def test_finetune_keras_trainer(self):
"""Ensure that the model can be fine-tuned via the keras API and
that metrics work as expected.
"""

# This metric expects to be called with the logits output
def _accuracy(y_true, y_pred):
return tf.keras.metrics.sparse_categorical_crossentropy(y_true[:, 0], y_pred[:, 0])

# measure the accuracy of the first token
class FirstTokenAccuracy(tf.keras.metrics.MeanMetricWrapper):
def __init__(self, name="accuracy", **kwargs):
super().__init__(_accuracy, name=name, **kwargs)

model = self.model
model.compile("adam", metrics=FirstTokenAccuracy())
tokenizer = T5Tokenizer.from_pretrained("t5-small")

examples = [
("sentiment: Everything is awesome!", "positive"),
("sentiment: Tensorflow datasets are hard to use", "negative"),
]

inputs = dict(tokenizer([x[0] for x in examples], padding=True, return_tensors="tf"))
inputs["labels"] = tokenizer([x[1] for x in examples], return_tensors="tf").input_ids

model.fit(inputs)
m = model.evaluate(inputs)
self.assertEqual(len(m), 2)
50 changes: 50 additions & 0 deletions tests/test_modeling_tf_common.py
@@ -1302,6 +1302,56 @@ def test_loss_computation(self):

self.assertEqual(loss.shape, [loss_size])

def test_keras_fit(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
if getattr(model, "hf_compute_loss", None):
# Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
# Is there a better way to remove these decoder inputs?
prepared_for_class = {
key: val
for key, val in prepared_for_class.items()
if key not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids")
}

possible_label_cols = {
"labels",
"label",
"label_ids",
"start_positions",
"start_position",
"end_positions",
"end_position",
"next_sentence_label",
}
label_names = possible_label_cols.intersection(set(prepared_for_class))
self.assertGreater(len(label_names), 0, msg="No matching label names found!")
labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
self.assertGreater(len(inputs_minus_labels), 0)
model.compile(optimizer=tf.keras.optimizers.SGD(0.0), run_eagerly=True)
# Make sure the model fits without crashing regardless of where we pass the labels
history1 = model.fit(
prepared_for_class,
validation_data=prepared_for_class,
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
)
val_loss1 = history1.history["val_loss"][0]
history2 = model.fit(
inputs_minus_labels,
labels,
validation_data=(inputs_minus_labels, labels),
steps_per_epoch=1,
validation_steps=1,
shuffle=False,
)
val_loss2 = history2.history["val_loss"][0]
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))

def test_generate_with_headmasking(self):
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
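The common `test_keras_fit` above fits every TF model with labels passed both ways; the routing of plain Keras labels onto the right model argument is done by the `find_labels` helper imported in modeling_tf_utils.py. A minimal sketch of inspecting it directly, assuming `transformers` (with this change) and TensorFlow are installed; the concrete model classes are just examples:

from transformers import TFBertForQuestionAnswering, TFBertForSequenceClassification
from transformers.utils import find_labels

# find_labels() inspects a model class's call() signature and returns its label-like
# argument names, which train_step/test_step use to map separately-passed labels.
print(find_labels(TFBertForSequenceClassification))  # expected: ["labels"]
print(find_labels(TFBertForQuestionAnswering))  # expected: ["start_positions", "end_positions"]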
