Provide PyTorch implementations by wrapping JAX functions #277


Merged
matt-graham merged 52 commits into main from mmg/pytorch-wrapper on Apr 23, 2025

Conversation

matt-graham (Collaborator) commented Mar 11, 2025

This PR removes the current manual reimplementation, for use with PyTorch, of the precompute transforms and some of the utility functions provided by s2fft, in favour of wrapping the JAX implementations using JAX's and PyTorch's mutual support for the DLPack standard, as outlined by Matt Johnson in this Gist.
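For context, a minimal sketch of the DLPack-based wrapping approach (the function names here are illustrative, not the actual s2fft.utils.torch_wrapper API):

```python
import jax
import jax.dlpack
import torch
import torch.utils.dlpack


def torch_to_jax(x: torch.Tensor) -> jax.Array:
    # Zero-copy handoff of the underlying buffer via the DLPack protocol
    return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x.contiguous()))


def jax_to_torch(x: jax.Array) -> torch.Tensor:
    return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))


def wrap_as_torch_function(jax_func):
    # Expose a JAX function to PyTorch autograd: the forward pass calls the
    # JAX implementation, the backward pass routes cotangents through jax.vjp
    class Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            jax_args = tuple(torch_to_jax(a) for a in args)
            output, ctx.vjp_func = jax.vjp(jax_func, *jax_args)
            return jax_to_torch(output)

        @staticmethod
        def backward(ctx, grad_output):
            grads = ctx.vjp_func(torch_to_jax(grad_output))
            return tuple(jax_to_torch(g) for g in grads)

    return Wrapped.apply
```

Calling wrap_as_torch_function(some_jax_transform) then gives a function that accepts and returns torch tensors while reusing the JAX implementation for both the forward computation and its gradients.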

Some local benchmarking suggests there is no performance degradation with this wrapping approach compared to the 'native' implementations, beyond the very smallest bandlimits L, and potentially a small constant-factor speedup for larger L - see the benchmark results in the files below:

precompute-spherical-torch-benchmarks.json
precompute-spherical-torch-wrapper-benchmarks.json

After the changes in this PR, all imports from torch are confined to the s2fft.utils.torch_wrapper module, and the import there is guarded in a try: ... except ImportError block. This PR therefore also removes torch from the required dependencies for the project, with an informative error message being raised when the user tries to use the wrapper functionality without torch being installed.
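Roughly, the guard follows this pattern (a sketch; the exact structure and message in s2fft.utils.torch_wrapper may differ):

```python
try:
    import torch
except ImportError:
    torch = None


def _require_torch():
    # Raise an informative error when wrapper functionality is used
    # without torch installed
    if torch is None:
        raise ImportError(
            "PyTorch is required to use the torch wrapper functionality. "
            "Install it, for example, with `pip install torch`."
        )
```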

Todo

  • Add tests for functions introduced in torch_wrapper module
  • Decide whether to keep wrapped utility modules s2fft.utils.quadrature_torch and s2fft.utils.resampling_torch
  • Add wrappers for non-precompute transforms
  • Update documentation to reflect new wider support for PyTorch

matt-graham added the enhancement (New feature or request) label Mar 11, 2025
```diff
@@ -518,7 +442,9 @@ def forward_transform_jax(

     flm = jnp.zeros(samples.flm_shape(L), dtype=jnp.complex128)
     flm = flm.at[:, m_start_ind:].set(
-        jnp.einsum("...tlm, ...tm -> ...lm", kernel, ftm, optimize=True)
+        jnp.einsum(
+            "...tlm, ...tm -> ...lm", kernel.astype(flm.dtype), ftm, optimize=True
+        )
```
matt-graham (Collaborator, Author) commented:

The explicit dtype cast here was needed to avoid warnings, raised when computing gradients through this function, that were causing test failures. The warnings indicated an implicit cast from complex to real types was happening, which might be losing information. I think this was due to the kernel argument here being of real type: as ftm is complex and the flm output of the einsum is complex, kernel is implicitly cast to complex as part of the einsum, so in the reverse pass the derivatives with respect to kernel are complex, even though we should only retain the real-valued component.
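A small runnable illustration of the promotion behaviour (hypothetical shapes, not the actual transform code):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical small shapes, for illustration only
kernel = jnp.ones((4, 3, 5))                  # real-valued, (theta, l, m)
ftm = jnp.ones((4, 5), dtype=jnp.complex128)  # complex, (theta, m)


def loss_implicit(kernel):
    # kernel is silently promoted to complex inside einsum because ftm is
    # complex; the reverse-mode cotangent with respect to kernel is then
    # complex and has to be cast back to a real type
    flm = jnp.einsum("tlm, tm -> lm", kernel, ftm, optimize=True)
    return jnp.abs(flm).sum()


def loss_explicit(kernel):
    # Casting explicitly, as in the diff above, makes the promotion
    # intentional and keeps dtypes consistent through the einsum
    flm = jnp.einsum(
        "tlm, tm -> lm", kernel.astype(ftm.dtype), ftm, optimize=True
    )
    return jnp.abs(flm).sum()


print(jax.grad(loss_implicit)(kernel).dtype)  # float64
print(jax.grad(loss_explicit)(kernel).dtype)  # float64
```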

matt-graham marked this pull request as ready for review April 11, 2025 10:29
matt-graham requested a review from kmulderdas April 11, 2025 10:29
matt-graham (Collaborator, Author) commented:

As discussed in the meeting today, marking this as ready for review and requesting a review from @kmulderdas. The remaining todos on exposing torch versions of the on-the-fly transforms, and at that point documenting the wider torch support, can be dealt with in a separate PR.

```diff
@@ -23,7 +23,7 @@
 reality_to_test = [True, False]
 methods_to_test = ["numpy", "jax", "torch"]
 recursions_to_test = ["price-mcewen", "risbo", "auto"]
-iter_to_test = [0, 3]
+iter_to_test = [0, 1]
```
matt-graham (Collaborator, Author) commented:

I reduced the maximum number of iterations we test for as I noticed the tests with iter=3 were particularly slow, and from a testing perspective just checking the code works with a non-zero number of iterations is sufficient.

```diff
@@ -8,6 +9,8 @@
 from s2fft.precompute_transforms.wigner import forward, inverse
 from s2fft.sampling import so3_samples as samples

+jax.config.update("jax_enable_x64", True)
```
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was needed to ensure we don't get numerical issues due to using single precision when checking gradients in the method="torch" tests, as these now use JAX under the hood.
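To illustrate why (with a hypothetical stand-in for a wrapped transform, not the actual test code):

```python
import torch


def wrapped_transform(x):
    # Stand-in for a JAX-backed wrapped transform, illustrative only
    return (x ** 2).sum()


# gradcheck perturbs inputs with tiny finite-difference steps and compares
# against analytic gradients at double precision, so float32 round-off in
# the underlying JAX computation (its default without jax_enable_x64)
# would cause spurious failures
x = torch.randn(8, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(wrapped_transform, (x,))
```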

matt-graham (Collaborator, Author) commented:

I've now added PyTorch wrappers for the on-the-fly versions of spherical / Wigner-d transforms and updated docs and tests accordingly.
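As a usage sketch (assuming top-level s2fft.forward / s2fft.inverse entry points taking a method argument and default MW sampling; the exact signatures may differ):

```python
import torch
import s2fft

L = 16
# Real-valued signal sampled on the sphere; shape (L, 2L - 1) for MW sampling
f = torch.randn(L, 2 * L - 1, dtype=torch.float64)

flm = s2fft.forward(f, L, method="torch")            # harmonic coefficients
f_round_trip = s2fft.inverse(flm, L, method="torch")
```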

kmulderdas left a comment:

I've gone over the changes and they look great. I don't see any reason for not merging this PR, although it might be good to have a chat at a later time about the breaking JAX changes (v0.6.0).

matt-graham merged commit 83296b9 into main Apr 23, 2025
14 checks passed
matt-graham deleted the mmg/pytorch-wrapper branch April 23, 2025 14:27
Labels
enhancement New feature or request
Development

Successfully merging this pull request may close these issues.

  • [Suggestion]: Remove strict requirement on backends
  • Add PyTorch support for on-the-fly transforms
  • Add code-wide switch for PyTorch support
2 participants