-
Notifications
You must be signed in to change notification settings - Fork 13
Updating Healpix CUDA primitive #290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Hello @matt-graham @jasonmcewen @CosmoMatt Just a quick PR to wrap up a few stuff
And finally I added cudastreamhandler which is used to split the XLA provided stream for the VMAP lowering (this header is my own work) There is an issue with building pyssht not sure that this is my fault I will check the failing worflows when I get the chance, but in the meantime a review is appreciated |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello @matt-graham @jasonmcewen @CosmoMatt
Just a quick PR to wrap up a few stuff
1. Updated the binding API to the newest [FFI](https://docs.jax.dev/en/latest/ffi.html) 2. Added a vmap implementation of the cuda primitive 3. Added a transpose rule which allows jacfwd and jacrev (consequently grad aswell) 4. added more tests https://github.com/astro-informatics/s2fft/blob/ASKabalan/tests/test_healpix_ffts.py#L100 5. Removed two files which are now no longer needed with the FFI API ([kernel helpers](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_helpers.h)) (so maybe they should be removed from the license section) 6. Constrained nanobind to be nanobind >=2.0,<2.6" because of a regression [[BUG]: Regression when using scikit build tools and nanobind wjakob/nanobind#982](https://github.com/wjakob/nanobind/issues/982)
And finally I added cudastreamhandler which is used to split the XLA provided stream for the VMAP lowering (this header is my own work)
There is an issue with building pyssht not sure that this is my fault
I will check the failing worflows when I get the chance, but in the meantime a review is appreciated
Hi @ASKabalan, sorry for the delay in getting back to you.
This all sounds great - thanks for picking up #237 in particular and for the updates to use the newer FFI interface.
With regards to the failing workflows - this was probably due to #292 which was fixed in #293. If you merge in latest main
here that should hopefully resolve the upstream dependency build problems that were causing the test workflows to fail.
I've added some initial review comments below. Will have a closer look next week and try testing this out, but don't have access to GPU machine atm.
tests/test_healpix_ffts.py
Outdated
flm_hp = samples.flm_2d_to_hp(flm, L) | ||
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we could use s2fft.inverse(flm, L=L, reality=False, method="jax", sampling="healpix")
here instead of going via healpy
? Rationale being that I would have a slight preference for minimising the number of additional tests that depend on healpy
as it we are no longer requiring it as direct dependency for package and in the long run it might be possible to also remove it as a test dependency.
Co-authored-by: Matt Graham <matthew.m.graham@gmail.com>
I've tried building, installing and running this on a system with CUDA 12.6 + a NVIDIA A100, and running the HEALPix FFT tests with
consistently the tests hang when trying to run the first Running just the IFFT tests with
the tests for both set of test parameters pass. Trying to dig into this a bit, running the following locally import healpy
import jax
import s2fft
import numpy
jax.config.update("jax_enable_x64", True)
seed = 20250416
nside = 4
L = 2 * nside
reality = False
rng = numpy.random.default_rng(seed)
flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=reality)
flm_hp = s2fft.sampling.s2_samples.flm_2d_to_hp(flm, L)
f = healpy.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
flm_cuda = s2fft.utils.healpix_ffts.healpix_fft_cuda(f=f, L=L, nside=nside, reality=reality).block_until_ready() raises an error
so it looks like there is some memory addressing issue somewhere in the |
Thank you I was able to reproduce with 12.4.1 but not locally with 12.4 I will take a look |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #290 +/- ##
==========================================
- Coverage 96.55% 96.07% -0.48%
==========================================
Files 32 32
Lines 3450 3469 +19
==========================================
+ Hits 3331 3333 +2
- Misses 119 136 +17 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
@matt-graham Hey I would suggest dropping python3.8 from the test suite since JAX no longer supports it anyway |
Hi @ASKabalan. Do you mean
Yes agreed we should drop Python 3.8 from test matrix - we have an open pull request #305 to update to only supporting Python 3.11+ but this is partially blocked by #212 as the tests currently exit with fatal errors when running on MacOS / Python 3.9+ due to an incompatibility between the OpenMP runtime's the MacOS wheels for |
Add comprehensive documentation and fix dependency issues for CUDA FFT integration. This commit introduces extensive docstrings and inline comments across the C++ and Python codebase, particularly for the CUDA FFT implementation. It also addresses a dependency issue in to ensure proper installation and functionality. Key changes include: - no more CUDA Malloc .. all memory is allocated in Python by XLA - Added detailed docstrings to C++ header files - Enhanced inline comments in C++ source files to explain complex logic and algorithms. - Updated to relax JAX version dependency, resolving installation issues. - Refined docstrings and comments in Python files for clarity and consistency. - Cleaned up debug print statements
Adding a few updates
A batching rule seems to be very important for two things
Being able to jacrev/ jacfwd
and because in most cases .. the size of a healpix map can fit on a single GPU but sometimes we want to batch the spherical transform
I will be doing that next