Transform the JAX graph so that everything runs in float64 except a set of user-specified operations. This may not be possible; we need to think about what that would look like as a JAX graph transformation.
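One possible starting point, sketched under assumptions rather than as a worked-out design: trace the function to a jaxpr with `jax.make_jaxpr`, then re-evaluate it with a small custom interpreter, in the style of JAX's custom-interpreter tutorial. The interpreter runs every equation in float64 except primitives whose names are in a user-supplied set; those get float32 inputs and have their outputs cast back to float64. The names `run_mixed` and `eval_mixed` are hypothetical. Nested jitted calls are recursed into (the call primitive is named `pjit` in many JAX releases and may be `jit` in newer ones, so both are checked), while control flow (`scan`, `while`, `cond`) and custom-derivative calls are not handled.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 must be enabled globally

try:  # newer JAX releases expose Literal under jax.extend.core
    from jax.extend.core import Literal
except ImportError:
    from jax.core import Literal


def _cast(x, dtype):
    """Cast floating-point values to `dtype`; leave integer/bool values alone."""
    x = jnp.asarray(x)
    return x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x


def _f32_params(params):
    """Rewrite float64 dtype params (e.g. dot_general's preferred_element_type)."""
    return {k: np.dtype(np.float32) if isinstance(v, np.dtype) and v == np.float64 else v
            for k, v in params.items()}


def eval_mixed(jaxpr, consts, low_prims, *args):
    """Evaluate `jaxpr` in float64, except primitives named in `low_prims`."""
    env = dict(zip(jaxpr.constvars, consts))
    env.update(zip(jaxpr.invars, args))

    def read(v):
        return v.val if isinstance(v, Literal) else env[v]

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        name = eqn.primitive.name
        if name in ("pjit", "jit"):
            # Nested jitted call: recurse so its primitives are also inspected.
            sub = eqn.params["jaxpr"]
            outvals = eval_mixed(sub.jaxpr, sub.consts, low_prims, *invals)
        elif name in low_prims:
            # Selected primitive: float32 inputs, then cast the result back up.
            outs = eqn.primitive.bind(*[_cast(x, jnp.float32) for x in invals],
                                      **_f32_params(eqn.params))
            outs = outs if eqn.primitive.multiple_results else [outs]
            outvals = [_cast(x, jnp.float64) for x in outs]
        else:
            outs = eqn.primitive.bind(*invals, **eqn.params)
            outvals = outs if eqn.primitive.multiple_results else [outs]
        env.update(zip(eqn.outvars, outvals))
    return [read(v) for v in jaxpr.outvars]


def run_mixed(f, low_prims, *args):
    """Trace `f` on float64 inputs and evaluate it with `eval_mixed`."""
    args = [_cast(a, jnp.float64) for a in args]
    closed = jax.make_jaxpr(f)(*args)
    return eval_mixed(closed.jaxpr, closed.consts, set(low_prims), *args)
```

For example, `run_mixed(f, {"dot_general"}, a, b)` evaluates every matrix product in `f` in float32 and everything else in float64; comparing against `f(a, b)` in plain float64 isolates the error introduced by that one class of operation.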
I had the idea to introduce a function decorator that runs an operation twice:
- a baseline run with all floating-point inputs promoted to fp64
- a second run with fp32
Then report the difference (perhaps printed to stdout, or saved to an .npz file?). What I'm not sure about is how to dynamically annotate a graph to do this, but perhaps others have some insight into the JAX way to attempt that.
If that sounds like a useful step, I could draft a PR with the decorator.
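A minimal sketch of such a decorator, assuming inputs and outputs are pytrees of arrays and that `jax_enable_x64` is switched on globally. The name `compare_fp32_fp64` and its `report`/`npz_path` parameters are hypothetical. It works at the Python level only, so it measures the whole decorated function rather than a single operation inside a larger graph.

```python
import functools

import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # needed for real float64 arrays


def _cast_floats(tree, dtype):
    """Cast every floating-point array leaf in a pytree to `dtype`."""
    def cast(x):
        if isinstance(x, (jax.Array, np.ndarray)) and jnp.issubdtype(x.dtype, jnp.floating):
            return jnp.asarray(x, dtype=dtype)
        return x  # Python scalars, ints, etc. pass through unchanged
    return jax.tree_util.tree_map(cast, tree)


def compare_fp32_fp64(report=print, npz_path=None):
    """Hypothetical decorator: run `fn` in fp64 and fp32 and report the gap."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            out64 = fn(*_cast_floats(args, jnp.float64), **_cast_floats(kwargs, jnp.float64))
            out32 = fn(*_cast_floats(args, jnp.float32), **_cast_floats(kwargs, jnp.float32))
            leaves64 = [np.asarray(x) for x in jax.tree_util.tree_leaves(out64)]
            leaves32 = [np.asarray(x) for x in jax.tree_util.tree_leaves(out32)]
            # Max absolute difference per output leaf, computed in float64.
            max_abs_err = {
                f"out{i}": float(np.max(np.abs(a.astype(np.float64) - b.astype(np.float64))))
                for i, (a, b) in enumerate(zip(leaves64, leaves32))
            }
            report(fn.__name__, max_abs_err)
            if npz_path is not None:  # keep both results for offline inspection
                arrays = {f"out{i}_fp64": a for i, a in enumerate(leaves64)}
                arrays.update({f"out{i}_fp32": b for i, b in enumerate(leaves32)})
                np.savez(npz_path, **arrays)
            return out64  # callers continue with the fp64 baseline
        return wrapper
    return decorator
```

Returning the fp64 result keeps downstream code on the baseline path, so if several functions are decorated, each one's fp32 error is measured from float64 inputs rather than compounding through the chain.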
Do you mean "operation=nanoDFT" or, e.g., "operation=einsum(eri, dm)"? I was thinking that we'd run nanoDFT twice: first with everything in float64, then a second time where a single decorated operation is in float32. Is this also what you're considering?
I put together #110, which contains just the function decorator idea. The problem I haven't solved is how to inject it into the desired place within a larger compute graph. I think for that we may need a syntax to say "decorate function foo" and a compiler pass that does a find-and-replace on foo in the compute graph.
Or maybe this isn't the right way to approach the problem in JAX?
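One way the "find-and-replace on foo" step could look, sketched under assumptions: if `foo` is wrapped in `jax.jit`, it appears in the traced jaxpr as a single nested-call equation whose `name` parameter is the function's name, so a custom interpreter can match on that name and evaluate only that sub-graph in float32 while the rest stays in float64. The helpers `eval_regions` and `run_with_fp32_regions`, and the `coulomb`/`energy` functions standing in for a nanoDFT sub-step, are hypothetical. Only jitted functions are visible as named calls in this sketch, and control flow and custom-derivative calls are not handled.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 must be enabled globally

try:  # newer JAX releases expose Literal under jax.extend.core
    from jax.extend.core import Literal
except ImportError:
    from jax.core import Literal


def _cast(x, dtype):
    """Cast floating-point values to `dtype`; leave integer/bool values alone."""
    x = jnp.asarray(x)
    return x.astype(dtype) if jnp.issubdtype(x.dtype, jnp.floating) else x


def _f32_params(params):
    """Rewrite float64 dtype params (e.g. preferred_element_type, new_dtype)."""
    return {k: np.dtype(np.float32) if isinstance(v, np.dtype) and v == np.float64 else v
            for k, v in params.items()}


def eval_regions(jaxpr, consts, targets, in_f32, *args):
    """Evaluate `jaxpr`; jitted calls named in `targets` run entirely in float32."""
    env = dict(zip(jaxpr.constvars, consts))
    env.update(zip(jaxpr.invars, args))

    def read(v):
        return v.val if isinstance(v, Literal) else env[v]

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        if eqn.primitive.name in ("pjit", "jit"):
            sub = eqn.params["jaxpr"]
            enter = eqn.params.get("name") in targets  # the "find" step
            outvals = eval_regions(sub.jaxpr, sub.consts, targets, in_f32 or enter, *invals)
            if enter and not in_f32:
                outvals = [_cast(x, jnp.float64) for x in outvals]  # leave the region
        elif in_f32:
            # Inside a targeted region: the "replace" step, rerun in float32.
            outs = eqn.primitive.bind(*[_cast(x, jnp.float32) for x in invals],
                                      **_f32_params(eqn.params))
            outvals = outs if eqn.primitive.multiple_results else [outs]
        else:
            outs = eqn.primitive.bind(*invals, **eqn.params)
            outvals = outs if eqn.primitive.multiple_results else [outs]
        env.update(zip(eqn.outvars, outvals))
    return [read(v) for v in jaxpr.outvars]


def run_with_fp32_regions(f, targets, *args):
    """Trace `f` on float64 inputs; run jitted calls named in `targets` in float32."""
    args = [_cast(a, jnp.float64) for a in args]
    closed = jax.make_jaxpr(f)(*args)
    return eval_regions(closed.jaxpr, closed.consts, set(targets), False, *args)


@jax.jit
def coulomb(eri, dm):  # hypothetical stand-in for one nanoDFT sub-step
    return jnp.einsum("ijkl,kl->ij", eri, dm)


def energy(eri, dm):  # hypothetical outer computation, left in float64
    return jnp.sum(coulomb(eri, dm) * dm)
```

Calling `run_with_fp32_regions(energy, {"coulomb"}, eri, dm)` gives the "run everything in float64 except one operation" experiment without editing `energy` itself; the function being targeted only needs to be wrapped in `jax.jit` so its name survives tracing.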