Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite TensorFlow train_step and test_step #17057

Merged: 8 commits merged on May 17, 2022
Previous commit
Next commit
Remove breakpoint before pushing (this is your job)
  • Loading branch information
Rocketknight1 committed May 9, 2022
commit a3dcca3975f07a2f5b7817717e8e3b729622d5c4
11 changes: 6 additions & 5 deletions src/transformers/modeling_tf_utils.py
Original file line number | Diff line number | Diff line change
@@ -958,8 +958,7 @@ def train_step(self, data):
that they are available to the model during the forward pass.
"""

# For now, we hardcode the most common renamings - this will hopefully be expanded to model-specific
# attributes in future.
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
label_kwargs = find_labels(self.__class__)
if self._label_to_output_map is not None:
@@ -970,6 +969,8 @@ def train_step(self, data):
label_to_output = {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
elif "next_sentence_label" in arg_names:
label_to_output = {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
elif "mc_labels" in arg_names:
label_to_output = {"labels": "logits", "mc_labels": "mc_logits"}
else:
label_to_output = dict()
output_to_label = {val: key for key, val in label_to_output.items()}
@@ -1039,7 +1040,6 @@ def train_step(self, data):
if loss is None:
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)

breakpoint()
# Run backwards pass.
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)

@@ -1061,8 +1061,7 @@ def test_step(self, data):
labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
that they are available to the model during the forward pass.
"""
# For now, we hardcode the most common renamings - this will hopefully be expanded to model-specific
# attributes in future.
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
label_kwargs = find_labels(self.__class__)
if self._label_to_output_map is not None:
@@ -1073,6 +1072,8 @@ def test_step(self, data):
label_to_output = {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
elif "next_sentence_label" in arg_names:
label_to_output = {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
elif "mc_labels" in arg_names:
label_to_output = {"labels": "logits", "mc_labels": "mc_logits"}
else:
label_to_output = dict()
output_to_label = {val: key for key, val in label_to_output.items()}
Expand Down