Do everything except one operation in float64 #100

Open
AlexanderMath opened this issue Sep 20, 2023 · 3 comments

@AlexanderMath
Contributor

Transform the JAX graph to perform everything in float64 except a set of user-specified operations. This may not be possible; we need to think about what that would look like as a JAX graph transformation.
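
One way this could look is a custom jaxpr interpreter: trace the function with `jax.make_jaxpr`, then re-evaluate its equations one by one, casting floating-point inputs to float64 for every primitive except those whose names are in a user-supplied set, which run in float32. This is a minimal sketch under stated assumptions: x64 mode is enabled, `eval_mixed` and `run_mixed` are hypothetical helpers, operations are selected by primitive name (e.g. `dot_general`), and primitives inside control-flow bodies (`scan`, `while`, `cond`) are not transformed.

```python
import numpy as np
import jax
import jax.numpy as jnp

# float64 arrays require x64 mode; enable it before creating any arrays.
jax.config.update("jax_enable_x64", True)

# Call-like primitives whose body we evaluate recursively instead of binding.
# The jit primitive's name has varied across JAX versions, so both are checked.
_CALL_PRIMITIVES = {"pjit", "jit", "closed_call"}


def _cast_floats(vals, dtype):
    """Cast floating-point values to `dtype`; leave ints, bools and complex as-is."""
    out = []
    for v in vals:
        a = jnp.asarray(v)
        out.append(a.astype(dtype) if jnp.issubdtype(a.dtype, jnp.floating) else a)
    return out


def eval_mixed(jaxpr, consts, *args, low=frozenset()):
    """Evaluate `jaxpr` in float64, except primitives whose name is in `low` (float32)."""
    env = {}

    def read(v):
        # Literals carry their value inline; variables are looked up in `env`.
        return v.val if hasattr(v, "val") else env[v]

    env.update(zip(jaxpr.constvars, consts))
    env.update(zip(jaxpr.invars, args))

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        prim = eqn.primitive
        if prim.name in _CALL_PRIMITIVES:
            # Recurse so that primitives inside jitted jnp helpers are visible too.
            sub = eqn.params.get("jaxpr", eqn.params.get("call_jaxpr"))
            outvals = eval_mixed(sub.jaxpr, sub.consts, *invals, low=low)
        else:
            if prim.name in low:
                params = dict(eqn.params)
                pet = params.get("preferred_element_type")
                if pet is not None and jnp.issubdtype(pet, jnp.floating):
                    # Otherwise e.g. dot_general would still produce float64 output.
                    params["preferred_element_type"] = np.dtype(np.float32)
                outvals = prim.bind(*_cast_floats(invals, jnp.float32), **params)
            else:
                outvals = prim.bind(*_cast_floats(invals, jnp.float64), **eqn.params)
            if not prim.multiple_results:
                outvals = [outvals]
        env.update(zip(eqn.outvars, outvals))
    return [read(v) for v in jaxpr.outvars]


def run_mixed(f, *args, low=frozenset()):
    """Trace `f` with float64 inputs, then re-evaluate it with `low` ops in float32."""
    closed = jax.make_jaxpr(f)(*args)
    return eval_mixed(closed.jaxpr, closed.consts, *args, low=low)


def f(x, y):
    return jnp.sum(jnp.sin(x @ y))


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 8), dtype=jnp.float64)
y = jax.random.normal(key_y, (8, 8), dtype=jnp.float64)

(ref,) = run_mixed(f, x, y)                         # everything in float64
(mixed,) = run_mixed(f, x, y, low={"dot_general"})  # matmul demoted to float32
print("abs difference:", abs(float(ref) - float(mixed)))
```

Selecting by primitive name is coarse: every `dot_general` in the traced program is demoted, not one particular einsum, so singling out a specific call would need a finer-grained marker.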

@hatemhelal
Contributor

hatemhelal commented Sep 21, 2023

I had the idea to introduce a function decorator that will run an operation twice:

  1. a baseline run with all the floating point inputs promoted to fp64
  2. a second run with fp32

Then report the difference (possibly to stdout, or saved to an npz file?). What I'm not sure about is how to dynamically annotate a graph to do this, but perhaps others have some insight into the JAX way to attempt that.

If that sounds like a useful step, I could draft a PR with the decorator.
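
A minimal sketch of such a decorator, assuming x64 mode is enabled and that the wrapped function returns a pytree of arrays. `compare_precision`, its `npz_path` argument, and `gaussian_sum` are hypothetical names; the decorator prints the maximum absolute difference over all outputs, optionally saves both runs with `np.savez`, and returns the float64 result.

```python
import functools
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def _cast(tree, dtype):
    """Cast every floating-point array leaf of a pytree to `dtype`."""
    def cast_leaf(x):
        if hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.floating):
            return jnp.asarray(x, dtype=dtype)
        return x
    return jax.tree_util.tree_map(cast_leaf, tree)


def compare_precision(fn=None, *, npz_path=None):
    """Run `fn` with fp64 inputs (baseline) and with fp32 inputs; report the difference."""
    if fn is None:  # support both @compare_precision and @compare_precision(npz_path=...)
        return functools.partial(compare_precision, npz_path=npz_path)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        ref = fn(*_cast(args, jnp.float64), **_cast(kwargs, jnp.float64))  # 1. baseline
        low = fn(*_cast(args, jnp.float32), **_cast(kwargs, jnp.float32))  # 2. fp32 run
        ref_leaves = jax.tree_util.tree_leaves(ref)
        low_leaves = jax.tree_util.tree_leaves(low)
        errors = [
            float(np.max(np.abs(np.asarray(r, np.float64) - np.asarray(l, np.float64))))
            for r, l in zip(ref_leaves, low_leaves)
        ]
        print(f"{fn.__name__}: max |fp64 - fp32| = {max(errors, default=0.0):.3e}")
        if npz_path is not None:
            arrays = {f"fp64_{i}": np.asarray(r) for i, r in enumerate(ref_leaves)}
            arrays.update({f"fp32_{i}": np.asarray(l) for i, l in enumerate(low_leaves)})
            np.savez(npz_path, **arrays)
        wrapper.last_errors = errors  # kept for inspection after the call
        return ref
    return wrapper


@compare_precision
def gaussian_sum(x):
    return jnp.sum(jnp.exp(-x ** 2))


x = jnp.linspace(-3.0, 3.0, 101)
total = gaussian_sum(x)
```

Note that this demotes the whole decorated function to float32, which answers "how much does this function lose in fp32?" but not yet the question of placing it inside a larger float64 graph.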

@AlexanderMath
Contributor Author

> I had the idea to introduce a function decorator that will run an operation twice:

Do you mean "operation=nanoDFT" or, e.g., "operation=einsum(eri, dm)"? I was thinking that we'd run nanoDFT twice: first with everything in float64, then a second time where a single decorated operation is in float32. Is this also what you're considering?

@hatemhelal
Contributor

I put together #110, which contains just the function decorator idea. The problem I haven't solved is how to inject it at the desired place within a larger compute graph. I think for that we may need a syntax to say "decorate function foo" and a compiler pass that does a find and replace on foo in the compute graph.

Or maybe this isn't the right way to approach the problem in JAX?
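
A sketch of the "find and replace on foo" pass, assuming `foo` is wrapped in `jax.jit` so that it shows up in the traced jaxpr as a single call equation that records the function's name in its `name` parameter. The interpreter walks the jaxpr; when it meets a jitted call whose recorded name is in `replacements`, it calls the replacement (here a float32 version of `foo`) instead, and evaluates every other equation unchanged. `replace_calls`, `foo`, `foo_fp32`, and `program` are hypothetical, and relying on the call primitive's name (`pjit` or `jit`, depending on the JAX version) and its `name` parameter is an assumption about JAX internals, not a public API.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Call-like primitives; the jit primitive's name depends on the JAX version.
_CALL_PRIMITIVES = {"pjit", "jit", "closed_call"}


def replace_calls(jaxpr, consts, *args, replacements):
    """Evaluate `jaxpr`, swapping jitted calls named in `replacements` for new functions."""
    env = {}

    def read(v):
        # Literals carry their value inline; variables are looked up in `env`.
        return v.val if hasattr(v, "val") else env[v]

    env.update(zip(jaxpr.constvars, consts))
    env.update(zip(jaxpr.invars, args))

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        prim = eqn.primitive
        if prim.name in _CALL_PRIMITIVES:
            name = eqn.params.get("name")
            if name in replacements:
                # "Replace": call the substitute function on the same inputs.
                out = replacements[name](*invals)
                outvals = list(out) if isinstance(out, (tuple, list)) else [out]
            else:
                # "Find": keep searching inside nested calls.
                sub = eqn.params.get("jaxpr", eqn.params.get("call_jaxpr"))
                outvals = replace_calls(sub.jaxpr, sub.consts, *invals,
                                        replacements=replacements)
        else:
            outvals = prim.bind(*invals, **eqn.params)
            if not prim.multiple_results:
                outvals = [outvals]
        env.update(zip(eqn.outvars, outvals))
    return [read(v) for v in jaxpr.outvars]


@jax.jit
def foo(eri, dm):
    return jnp.einsum("ijkl,kl->ij", eri, dm)


def program(eri, dm):
    return jnp.sum(foo(eri, dm) * dm)


def foo_fp32(eri, dm):
    # Same computation in float32, cast back so the caller stays in float64.
    return foo(eri.astype(jnp.float32), dm.astype(jnp.float32)).astype(jnp.float64)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
eri = jax.random.normal(k1, (4, 4, 4, 4), dtype=jnp.float64)
dm = jax.random.normal(k2, (4, 4), dtype=jnp.float64)

closed = jax.make_jaxpr(program)(eri, dm)
(e64,) = replace_calls(closed.jaxpr, closed.consts, eri, dm, replacements={})
(e_mixed,) = replace_calls(closed.jaxpr, closed.consts, eri, dm,
                           replacements={"foo": foo_fp32})
print("abs difference:", abs(float(e64) - float(e_mixed)))
```

The appeal of matching on jitted calls rather than primitives is that the user's own function boundary (`foo`) becomes the unit of replacement, so one specific einsum can be demoted without touching every other `dot_general` in the program.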

4 participants