Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite TensorFlow train_step and test_step #17057

Merged
merged 8 commits into from
May 17, 2022
Prev Previous commit
Extract label name remapping to a method
  • Loading branch information
Rocketknight1 committed May 17, 2022
commit a9fdf074359a33d35a988aa884d2b373ffb47035
41 changes: 17 additions & 24 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,21 @@ def compute_loss(self, *args, **kwargs):
)
return self.hf_compute_loss(*args, **kwargs)

def get_label_to_output_name_mapping(self):
    """
    Return the mapping from label argument names to the output names they are matched against.

    An explicit `self._label_to_output_map` takes precedence and is returned as-is. Otherwise the mapping is picked
    from the hardcoded renamings below, based on which label arguments `self.call` accepts; the checks run in the
    order listed and the first match wins. When nothing matches, an empty dict is returned.

    Returns:
        `dict`: Maps each label argument name to the name of the corresponding output.
    """
    # Check the override first: it must not depend on `self.call` being introspectable, and there is no point
    # paying for signature inspection when the result would be discarded.
    if self._label_to_output_map is not None:
        return self._label_to_output_map

    # `parameters` is a mapping keyed by argument name, so membership tests work on it directly.
    arg_names = inspect.signature(self.call).parameters
    if "start_positions" in arg_names:
        return {"start_positions": "start_logits", "end_positions": "end_logits"}
    if "sentence_order_label" in arg_names:
        return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
    if "next_sentence_label" in arg_names:
        return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
    if "mc_labels" in arg_names:
        return {"labels": "logits", "mc_labels": "mc_logits"}
    return {}

def train_step(self, data):
"""
A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
Expand All @@ -961,18 +976,7 @@ def train_step(self, data):
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
label_kwargs = find_labels(self.__class__)
if self._label_to_output_map is not None:
label_to_output = self._label_to_output_map
elif "start_positions" in arg_names:
label_to_output = {"start_positions": "start_logits", "end_positions": "end_logits"}
elif "sentence_order_label" in arg_names:
label_to_output = {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
elif "next_sentence_label" in arg_names:
label_to_output = {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
elif "mc_labels" in arg_names:
label_to_output = {"labels": "logits", "mc_labels": "mc_logits"}
else:
label_to_output = dict()
label_to_output = self.get_label_to_output_name_mapping()
output_to_label = {val: key for key, val in label_to_output.items()}
if not self._using_dummy_loss:
data = data_adapter.expand_1d(data)
Expand Down Expand Up @@ -1069,18 +1073,7 @@ def test_step(self, data):
# We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
arg_names = list(dict(inspect.signature(self.call).parameters).keys())
label_kwargs = find_labels(self.__class__)
if self._label_to_output_map is not None:
label_to_output = self._label_to_output_map
elif "start_positions" in arg_names:
label_to_output = {"start_positions": "start_logits", "end_positions": "end_logits"}
elif "sentence_order_label" in arg_names:
label_to_output = {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
elif "next_sentence_label" in arg_names:
label_to_output = {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
elif "mc_labels" in arg_names:
label_to_output = {"labels": "logits", "mc_labels": "mc_logits"}
else:
label_to_output = dict()
label_to_output = self.get_label_to_output_name_mapping()
output_to_label = {val: key for key, val in label_to_output.items()}
if not self._using_dummy_loss:
data = data_adapter.expand_1d(data)
Expand Down