Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename compute_loss in TF models #15207

Merged
merged 6 commits into from
Jan 19, 2022
Merged

Rename compute_loss in TF models #15207

merged 6 commits into from
Jan 19, 2022

Conversation

Rocketknight1
Copy link
Member

This PR renames the compute_loss method on our models to hf_compute_loss, as Keras has just added a compute_loss method to its base Model class that causes lots of conflicts. Draft PR for now, since this will probably break something!

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for working on this!

To add some backward compatibility where we can, maybe we could implement a compute_loss method that:

  • calls the super for TF 2.8
  • issues a deprecation wraning and calls hf_compute_loss for older versions
    Wdyt?

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, let's just wait for @LysandreJik to review as well before merging.

Comment on lines 877 to 881
logger.warning(
"The old compute_loss method is deprecated as it conflicts with the Keras compute_loss "
"method added in TF 2.8. If you want the original HF compute_loss, please call "
"hf_compute_loss() instead."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use warnings.warn(xxx, FutureWarning) for deprecation warning.

Also mention it will be fully removed in v5 of Transformers.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this looks good to me. Thank you for putting a backward compatibility statement!

Comment on lines +872 to +884
def compute_loss(self, *args, **kwargs):
if hasattr(tf.keras.Model, "compute_loss"):
# This will be true in TF 2.8 or greater
return super().compute_loss(*args, **kwargs)
else:
warnings.warn(
"The old compute_loss method is deprecated as it conflicts with the Keras compute_loss "
"method added in TF 2.8. If you want the original HF compute_loss, please call "
"hf_compute_loss() instead. From TF versions >= 2.8, or Transformers versions >= 5, "
"calling compute_loss() will get the Keras method instead.",
FutureWarning,
)
return self.hf_compute_loss(*args, **kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love this, super nice.

@Rocketknight1 Rocketknight1 merged commit 2708bfa into master Jan 19, 2022
@Rocketknight1 Rocketknight1 deleted the hf_compute_loss branch January 19, 2022 13:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Error when running TFT5ForConditionalGeneration with tensorflow-cpu==2.8.0-rc0
4 participants