Rename compute_loss in TF models #15207
Conversation
LGTM, thanks for working on this!
To add some backward compatibility where we can, maybe we could implement a compute_loss method that:
- calls the super for TF 2.8
- issues a deprecation warning and calls hf_compute_loss for older versions (see the sketch below)
Wdyt?
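A minimal sketch of the suggested shape, assuming a hasattr check on tf.keras.Model is an acceptable way to detect TF 2.8+; the class name here is hypothetical and hf_compute_loss is assumed to be provided by the model's loss mixin (the real implementation lands in the diff further down):

import warnings

import tensorflow as tf


class ShimmedModel(tf.keras.Model):  # hypothetical stand-in for a TF model class
    def compute_loss(self, *args, **kwargs):
        if hasattr(tf.keras.Model, "compute_loss"):
            # TF 2.8+: defer to the Keras method of the same name.
            return super().compute_loss(*args, **kwargs)
        # Older TF: warn and fall back to the renamed HF loss method.
        warnings.warn(
            "compute_loss is deprecated; call hf_compute_loss() instead.",
            FutureWarning,
        )
        return self.hf_compute_loss(*args, **kwargs)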
👍
Great, let's just wait for @LysandreJik to review as well before merging.
logger.warning(
    "The old compute_loss method is deprecated as it conflicts with the Keras compute_loss "
    "method added in TF 2.8. If you want the original HF compute_loss, please call "
    "hf_compute_loss() instead."
)
We use warnings.warn(xxx, FutureWarning) for deprecation warnings.
Also mention that it will be fully removed in v5 of Transformers.
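For reference, a minimal example of the requested pattern; the exact message wording here is illustrative:

import warnings

warnings.warn(
    "The old compute_loss method is deprecated and will be removed in v5 of Transformers; "
    "call hf_compute_loss() instead.",
    FutureWarning,
)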
Done!
Yes, this looks good to me. Thank you for putting in a backward compatibility statement!
def compute_loss(self, *args, **kwargs):
    if hasattr(tf.keras.Model, "compute_loss"):
        # This will be true in TF 2.8 or greater
        return super().compute_loss(*args, **kwargs)
    else:
        warnings.warn(
            "The old compute_loss method is deprecated as it conflicts with the Keras compute_loss "
            "method added in TF 2.8. If you want the original HF compute_loss, please call "
            "hf_compute_loss() instead. From TF versions >= 2.8, or Transformers versions >= 5, "
            "calling compute_loss() will get the Keras method instead.",
            FutureWarning,
        )
        return self.hf_compute_loss(*args, **kwargs)
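A hedged usage sketch of what the shim means for callers; the checkpoint name and the hf_compute_loss(labels, logits) call shape are assumptions for illustration, not taken from this PR:

import warnings

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

# Assumed checkpoint and inputs, purely for illustration.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
batch = tokenizer("A short example sentence", return_tensors="tf")
labels = tf.constant([1])
logits = model(batch).logits

# Version-independent way to get the Hugging Face loss after this PR:
loss = model.hf_compute_loss(labels, logits)

# The old name keeps working on TF < 2.8, but emits a FutureWarning and
# forwards to hf_compute_loss; on TF >= 2.8 it resolves to Keras' own
# Model.compute_loss instead.
if not hasattr(tf.keras.Model, "compute_loss"):
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        legacy_loss = model.compute_loss(labels, logits)
    assert any(issubclass(w.category, FutureWarning) for w in caught)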
Love this, super nice.
This PR renames the compute_loss method on our models to hf_compute_loss, as Keras has just added a compute_loss method to its base Model class that causes lots of conflicts. Draft PR for now, since this will probably break something!