-
Notifications
You must be signed in to change notification settings - Fork 41
Replace finufft's warning with ours #1944
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Memory benchmark result| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
test_objective_jac_w7x | 3.84 % | 3.885e+03 | 4.034e+03 | 149.03 | 40.07 | 36.16 |
test_proximal_jac_w7x_with_eq_update | -1.28 % | 6.664e+03 | 6.579e+03 | -85.13 | 163.50 | 162.20 |
test_proximal_freeb_jac | -0.21 % | 1.320e+04 | 1.317e+04 | -28.34 | 86.55 | 82.94 |
test_proximal_freeb_jac_blocked | 0.14 % | 7.487e+03 | 7.498e+03 | 10.80 | 75.21 | 75.32 |
test_proximal_freeb_jac_batched | 0.62 % | 7.449e+03 | 7.495e+03 | 45.83 | 73.82 | 74.30 |
test_proximal_jac_ripple | -1.28 % | 3.584e+03 | 3.538e+03 | -45.98 | 65.49 | 66.40 |
test_proximal_jac_ripple_bounce1d | -0.77 % | 3.620e+03 | 3.592e+03 | -27.98 | 79.37 | 74.94 |
test_eq_solve | 4.70 % | 2.040e+03 | 2.136e+03 | 95.86 | 94.88 | 93.46 |For the memory plots, go to the summary of |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #1944 +/- ##
==========================================
- Coverage 95.75% 95.73% -0.03%
==========================================
Files 102 103 +1
Lines 28344 28375 +31
==========================================
+ Hits 27142 27166 +24
- Misses 1202 1209 +7
🚀 New features to boost your workflow:
|
dpanici
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not approving yet, running into an error locally, will try to see the issue and post
desc/integrals/_interp_utils.py
Outdated
|
|
||
| def _test_gpu_jax_finufft(): | ||
| """Replacing jax-finufft's warning with ours.""" | ||
| from tests.test_interp_utils import TestFastInterp, _test_inputs_1D |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test folder isn't always included, ie if you install from pip. I don't think we need to run the full test anyways, we just need to test that it jits ok on gpu
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I prefer to do better than "if it compiles, ship it"
dpanici
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Running
from desc.objectives import *gives this error (on a clean install on cpu just doing pip install --editable ., I'm not sure if this is the way, in general, to safely grab things from the tests directory?
Traceback (most recent call last):
File "/Users/dpanici/Research/work/desc_work/random/flexible_eq/try_import.py", line 1, in <module>
from desc.objectives import *
File "/Users/dpanici/Research/DESC/desc/objectives/__init__.py", line 3, in <module>
from ._bootstrap import BootstrapRedlConsistency
File "/Users/dpanici/Research/DESC/desc/objectives/_bootstrap.py", line 6, in <module>
from desc.compute import get_profiles, get_transforms
File "/Users/dpanici/Research/DESC/desc/compute/__init__.py", line 30, in <module>
from . import (
...<15 lines>...
)
File "/Users/dpanici/Research/DESC/desc/compute/_bootstrap.py", line 16, in <module>
from ..integrals.surface_integral import surface_averages_map
File "/Users/dpanici/Research/DESC/desc/integrals/__init__.py", line 3, in <module>
from ._bounce_utils import fast_chebyshev, fast_cubic_spline, fourier_chebyshev
File "/Users/dpanici/Research/DESC/desc/integrals/_bounce_utils.py", line 8, in <module>
from desc.integrals._interp_utils import (
...<8 lines>...
)
File "/Users/dpanici/Research/DESC/desc/integrals/_interp_utils.py", line 998, in <module>
_test_gpu_jax_finufft()
~~~~~~~~~~~~~~~~~~~~~^^
File "/Users/dpanici/Research/DESC/desc/integrals/_interp_utils.py", line 979, in _test_gpu_jax_finufft
from tests.test_interp_utils import TestFastInterp, _test_inputs_1D
ModuleNotFoundError: No module named 'tests'There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I still think these kind of checks that include computation must be application specific. For example, on my laptop, I use GPU but the installation instruction doesn't work, so I just installed jax-finufft cpu to shut the warning down, but with this, I will always get this warning. The additional import time is also not necessary.
Also, we can either use this PR or some other one, but the installation instructions are not complete for jax-finufft.
- For example, the local GPU one doesn't work. You need to install CudaToolkit, we should at least mention that (I haven't tested it with Cuda Toolkit because it might cause my other environments to fail with some CUDA related errors, I don't want to take that risk).
- Della instructions only work on
della-gpulogin node, this again needs explanation. - Another thing is that if you try to use the environment with a H100, you cannot, you need to compile it again in a different environment with
CMAKE_CUDA_ARCHITECTURES=90or something like that. Because on Della and Perlmutter, login nodes have A100 with 8.0 architecture. - I shared the Perlmutter instructions in #1937. They are very very limiting but work. I tried making it more general, but looks like if you change one line of the instructions, everything fails.
dpanici
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If in our current install instructions, running this is enough to confirm what this warning is checking explicitly, then I think this PR may not be necessary.
If it is not enough, then just extending the docs to include the extra check on grad of nufft1d2r or whatever should be sufficient. At least in the opinions of myself, @ddudt and @YigitElma
This only really needs to be checked once upon install, not everytime the code is ran. @f0uriest if you think differently let us know since I think you had originally pointed out the desire for this warning?
from desc import set_device
set_device("gpu")
from desc.examples import get
from desc.objectives import ObjectiveFunction, GammaC
obj = ObjectiveFunction(GammaC(get("W7-X"), num_transit=1, num_pitch=1))
obj.build()
x = obj.x()
obj.compute_scaled_error(x).block_until_ready()|
Ideally actually, can this be done within the objectives/compute functions which use jax-finufft? instead of being upon import or in installation? |
0c6c480 to
ce63d30
Compare
Replaces jax-finufft's warning with the one desc developers want.