Provide PyTorch implementations by wrapping JAX functions #277
Conversation
```diff
@@ -518,7 +442,9 @@ def forward_transform_jax(
     flm = jnp.zeros(samples.flm_shape(L), dtype=jnp.complex128)
     flm = flm.at[:, m_start_ind:].set(
-        jnp.einsum("...tlm, ...tm -> ...lm", kernel, ftm, optimize=True)
+        jnp.einsum(
+            "...tlm, ...tm -> ...lm", kernel.astype(flm.dtype), ftm, optimize=True
+        )
     )
```
The explicit dtype cast here was needed to avoid warnings when computing gradients through this function, which were causing test failures. The warnings indicated an implicit cast from complex to real types was happening, which might be losing information. I think this was due to the `kernel` argument here being of real type: as `ftm` is complex and the `flm` output of the `einsum` is complex, `kernel` will be implicitly cast to complex as part of the `einsum`, and in the reverse pass the derivatives with respect to `kernel` will therefore be complex, even though we should only retain the real-valued component.
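As a rough illustration (the shapes and names below are made up, not the ones used in `s2fft`), making the real-to-complex promotion explicit at the point where `kernel` enters the `einsum` keeps the reverse pass in the complex domain by construction:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

kernel = jnp.ones((4, 3, 2))  # real-valued kernel (illustrative shape)
ftm = jnp.ones((4, 2), dtype=jnp.complex128)  # complex coefficients

def loss(kernel):
    # Explicit cast: without it, einsum would promote kernel to complex
    # implicitly, and the reverse pass would produce complex cotangents
    # for a real input, triggering a complex-to-real cast warning.
    flm = jnp.einsum("tlm, tm -> lm", kernel.astype(jnp.complex128), ftm)
    return jnp.sum(jnp.abs(flm) ** 2)

grad = jax.grad(loss)(kernel)  # real gradient w.r.t. the real kernel
```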
As discussed in the meeting today, marking this as ready for review and requesting review from @kmulderdas. The remaining todos on exposing torch versions of the on-the-fly transforms, and at that point documenting the wider torch support, can be dealt with in a separate PR.
```diff
@@ -23,7 +23,7 @@
 reality_to_test = [True, False]
 methods_to_test = ["numpy", "jax", "torch"]
 recursions_to_test = ["price-mcewen", "risbo", "auto"]
-iter_to_test = [0, 3]
+iter_to_test = [0, 1]
```
I reduced the maximum number of iterations we test for as I noticed the tests with `iter=3` were particularly slow, and from a testing perspective just checking the code works with a non-zero number of iterations is sufficient.
```diff
@@ -8,6 +9,8 @@
 from s2fft.precompute_transforms.wigner import forward, inverse
 from s2fft.sampling import so3_samples as samples
+
+jax.config.update("jax_enable_x64", True)
```
This was needed to ensure we don't get numerical issues due to using single precision when checking gradients in the `method="torch"` tests, as these now use JAX under the hood.
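For instance (a minimal sketch, not test code from this PR), gradient checks perturb inputs in double precision, so any JAX computation reached through the wrapper must also run in `float64` or the numerical comparison becomes too noisy:

```python
import jax

# Must be set before any arrays are created or functions traced.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.zeros(4)
assert x.dtype == jnp.float64  # without the flag this would be float32
```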
I've now added PyTorch wrappers for the on-the-fly versions of spherical / Wigner-d transforms and updated docs and tests accordingly.
I've gone over the changes and they look great. I don't see any reason for not merging this PR, although it might be good to have a chat at a later time about the breaking JAX changes (v0.6.0).
This PR removes the current manual reimplementation of the precompute transform and some of the utility functions provided by `s2fft` to allow use with PyTorch, in favour of wrapping the JAX implementations using JAX and PyTorch's mutual support for the DLPack standard, as outlined by Matt Johnson in this Gist. Some local benchmarking suggests there is no performance degradation with this wrapping approach compared to the 'native' implementations beyond the very smallest bandlimits `L`, and a potential small constant-factor speedup for larger `L` - see the benchmark results in the files below:

precompute-spherical-torch-benchmarks.json
precompute-spherical-torch-wrapper-benchmarks.json
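For context, the wrapping approach looks roughly like the sketch below, after the approach in the Gist (the names and the single-argument signature are illustrative, not the actual `s2fft.utils.torch_wrapper` API): tensors move between PyTorch and JAX via DLPack without copying, and the JAX VJP is exposed to PyTorch's autograd through a `torch.autograd.Function`:

```python
import jax
import jax.dlpack
import torch

def torch_to_jax(t: torch.Tensor) -> jax.Array:
    # Zero-copy hand-off via the DLPack protocol (older JAX versions may
    # instead need a capsule from torch.utils.dlpack.to_dlpack).
    return jax.dlpack.from_dlpack(t.contiguous())

def jax_to_torch(a: jax.Array) -> torch.Tensor:
    return torch.from_dlpack(a)

def wrap_jax_function(jax_fn):
    """Wrap a single-argument JAX function as a differentiable torch op."""

    class Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # jax.vjp evaluates the function and returns a pullback closure.
            y, vjp_fn = jax.vjp(jax_fn, torch_to_jax(x))
            ctx.vjp_fn = vjp_fn
            return jax_to_torch(y)

        @staticmethod
        def backward(ctx, grad_y):
            # Pull the torch cotangent back through the JAX computation.
            (grad_x,) = ctx.vjp_fn(torch_to_jax(grad_y))
            return jax_to_torch(grad_x)

    return Wrapped.apply
```

A function wrapped this way can be called on `torch` tensors with `requires_grad=True`, with gradients flowing back through JAX's VJP; the real wrapper would additionally need to handle multiple arguments and device placement.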
As all imports from `torch` are, after the changes in this PR, confined to the `s2fft.utils.torch_wrapper` module, and the import there is guarded in a `try: ... except ImportError` block, this PR also removes `torch` from the required dependencies for the project, with an informative error message being raised when the user tries to use the wrapper functionality without `torch` being installed.
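The guard follows the usual optional-dependency pattern, along these lines (a sketch of the assumed pattern, not the exact code in `s2fft.utils.torch_wrapper`; the helper name is hypothetical):

```python
try:
    import torch
except ImportError:
    torch = None

def _require_torch():
    # Hypothetical helper: raise an informative error when the wrapper
    # functionality is used without torch installed.
    if torch is None:
        raise ImportError(
            "PyTorch is required for the torch wrapper functionality but is "
            "not installed. Install it, for example, with `pip install torch`."
        )
```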
Todo

- `torch_wrapper` module
- `s2fft.utils.quadrature_torch` and `s2fft.utils.resampling_torch`