Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error with x64 loss #976

Conversation

stefanocortinovis
Copy link
Contributor

This PR fixes a bug for contrib.reduce_on_plateau() when 64-bit floats are enabled:

  • The current implementation of contrib.reduce_on_plateau() results in a TypeError: true_fun and false_fun output must have identical types due to a jax.lax.cond using functions with different return types when the value passed to transform.update() is a 64-bit scalar tensor.
  • The inconsistency in the types is due to this average value computation.
  • The bug was not picked up by current tests because, in the x64 tests, values were not explicitly cast to 64-bit tensors.
  • The new implementation corrects the issue with the tests and solves the bug by explicitly casting value passed to transform.update() to 32-bits where needed.

@fabianp
Copy link
Member

fabianp commented May 29, 2024

@stefanocortinovis please ping me once the changes have been implemented

@stefanocortinovis
Copy link
Contributor Author

@fabianp done! I will just need to add a couple of small changes when #975 gets merged.

@copybara-service copybara-service bot merged commit b622fc3 into google-deepmind:main Jun 3, 2024
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants