Updating Perlmutter installation instructions #1924
Memory benchmark result

| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x | 1.07 % | 3.992e+03 | 4.034e+03 | 42.69 | 35.06 | 31.65 |
| test_proximal_jac_w7x_with_eq_update | -1.25 % | 6.834e+03 | 6.749e+03 | -85.56 | 159.33 | 159.80 |
| test_proximal_freeb_jac | -0.27 % | 1.321e+04 | 1.318e+04 | -35.44 | 78.51 | 78.13 |
| test_proximal_freeb_jac_blocked | -1.05 % | 7.668e+03 | 7.587e+03 | -80.77 | 68.38 | 69.08 |
| test_proximal_freeb_jac_batched | -0.55 % | 7.605e+03 | 7.563e+03 | -41.64 | 69.20 | 69.78 |
| test_proximal_jac_ripple | 1.04 % | 7.555e+03 | 7.634e+03 | 78.50 | 69.71 | 69.71 |
| test_proximal_jac_ripple_spline | 0.69 % | 3.470e+03 | 3.493e+03 | 23.88 | 72.63 | 71.83 |
| test_eq_solve | -3.09 % | 2.058e+03 | 1.995e+03 | -63.54 | 125.88 | 125.46 |

For the memory plots, go to the summary of
Does this resolve #1324 too?
Yes
unalmis
left a comment
Merge at will
Actually, this will likely change after #1923 (comment)
docs/installation.rst
Outdated
.. code-block:: sh

    conda create -n desc-env python=3.12
    CONDA_OVERRIDE_CUDA="12.4" conda create --name desc-env "jax==0.6.0" "jaxlib==0.6.0=cuda12*" -c conda-forge
The pip method yigit and I posted also seems to work fine, and I prefer that to conda-forge as it's what's more supported by the jax devs, and the conda-forge packages seem to be out of date (latest is 0.7.0, but 0.7.2 was released last week).
There are subtleties that we realized an hour ago while fixing the installation. I'm updating the PR now.
@f0uriest could you test this?
f0uriest
left a comment
tested these and they work for me but see notes below
docs/installation.rst
Outdated
    pip install -r devtools/dev-requirements.txt
    pip install --no-cache-dir -r devtools/dev-requirements.txt
    pip install --no-cache-dir --editable .
    pip install --no-cache-dir "jax[cuda12]"
I think we might still need a version pin here for future-proofing, to avoid accidentally installing an incompatible version (or possibly one that doesn't play well with NERSC)?
I have used this command on many different systems, and it seems to work consistently. Even if there is a newer version of jax (for example, 0.7.2 is out but I use 0.6.2), it picks the version that was installed earlier for CPU in the environment. If people change the order, e.g. call pip install --no-cache-dir -r devtools/dev-requirements.txt after installing jax for GPU, it will definitely break and reinstall the CPU version (we experienced this with Rahul yesterday).
This is a guess, but I think the difference between CPU and GPU jax is not in jax itself but in jaxlib. So, if you already have jax (installed during the CPU phase), the corresponding jaxlib checks that version, and they stay in sync. Of course, people can include other packages or other steps that mess this up. If there is a clean way to do this, I am fine with it (but let me say, I really don't like the sed command).
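One way to check the guess above on an existing environment is to look at which distributions pip actually installed. This is a minimal sketch using only the standard library. The plugin distribution names `jax-cuda12-plugin` and `jax-cuda12-pjrt` are an assumption about what `jax[cuda12]` pulls in (they are not taken from this PR), and comparing jax and jaxlib for exact equality is a simplification.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def jax_install_report(plugin_names=("jax-cuda12-plugin", "jax-cuda12-pjrt")):
    """Summarize jax/jaxlib versions and whether a CUDA plugin is present.

    The default plugin names are an assumption; adjust them for your setup.
    """
    jax_v = installed_version("jax")
    jaxlib_v = installed_version("jaxlib")
    plugins = {name: installed_version(name) for name in plugin_names}
    return {
        "jax": jax_v,
        "jaxlib": jaxlib_v,
        # Exact equality is a simplification of "in sync".
        "same_version": jax_v is not None and jax_v == jaxlib_v,
        "cuda_plugins": plugins,
        "has_cuda_plugin": any(v is not None for v in plugins.values()),
    }


if __name__ == "__main__":
    for key, value in jax_install_report().items():
        print(f"{key}: {value}")
```

Running this before and after the dev-requirements step would show whether that step changed the jax or jaxlib version, or removed the GPU plugin.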
Have you tried another version using this process?
Here's what I did. I set the max version of jax to <=0.6.1 in requirements.txt and repeated the exact same process, even making sure to use
pip install --no-cache-dir "jax[cuda12]==0.6.1" "jaxlib[cuda12]==0.6.1"
but now I get the following error:
E0925 15:05:46.285344 817786 pjrt_stream_executor_client.cc:2917] Execution of replica 0 failed: INTERNAL: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/equilibrium/__init__.py", line 3, in <module>
from .equilibrium import EquilibriaFamily, Equilibrium
File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/equilibrium/equilibrium.py", line 16, in <module>
from desc.compute import compute as compute_fun
File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/compute/__init__.py", line 30, in <module>
from . import (
File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/compute/_bootstrap.py", line 16, in <module>
from ..integrals.surface_integral import surface_averages_map
File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/integrals/__init__.py", line 3, in <module>
from .bounce_integral import Bounce1D, Bounce2D
File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/integrals/bounce_integral.py", line 122, in <module>
leggauss(32),
^^^^^^^^^^^^
File "/global/homes/r/rgaur/.conda/envs/desc-env2/lib/python3.12/site-packages/orthax/legendre.py", line 1509, in leggauss
x = jnp.linalg.eigvalsh(m)
^^^^^^^^^^^^^^^^^^^^^^
jaxlib._jax.XlaRuntimeError: INTERNAL: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
What about 0.6.0? Are you saying it should work if I repeat the same process with 0.6.0?
As far as I remember this was specific to 0.6.1. So, yes, it should work without issue.
docs/installation.rst
Outdated
Note that you may also need to execute `unset LD_LIBRARY_PATH` before starting a python process (e.g. execute this as part of your slurm script, before calling python to run DESC) for the JAX/CUDA initialization to work properly.
The `--no-cache-dir` avoids conflicts with existing DESC environments or other software that use CUDA on your system.

Before running a DESC script, you MUST also execute `unset LD_LIBRARY_PATH` either in your interactive node (for interactive jobs)
I tried these instructions on both a login node and jupyter worker and didn't need this. Are you both sure it's necessary, and if so do we know why?
I already had one DESC environment and I was trying to install another. To make sure the new environment works, I needed this command.
Maybe the MUST part is too strict. I tried the instructions on my fresh Perlmutter account, so my pip cache was completely empty, and I didn't need this when I was using a compute node. But for Rahul, the older environments seemed to cause problems, and the only way to fix it was --no-cache-dir and unset LD_LIBRARY_PATH. We can make the LD_LIBRARY_PATH step optional (although it doesn't hurt to type).
I tried this both in a totally clean conda setup (no previous environments) and with other environments, and in both cases I don't need it with any recent version of jax. It might be something specific to your situations? I.e., do you have any stuff in your bashrc or other codes installed?
I think we should move this to a "if you see this issue here's a possible fix" but I definitely think we shouldn't say this has to be done. Note that LD_LIBRARY_PATH exists outside of conda, and is used by lots of codes and libraries (iirc i had to mess with it when using sfincs, stride, dcon etc), so unsetting it may break other things people have installed, even if they create a fresh conda environment for desc, and we really don't want that.
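If the variable does turn out to matter for someone, one way to avoid touching the shell's own `LD_LIBRARY_PATH` is to drop it only for the process that runs DESC. The following is a minimal sketch of that idea, not something taken from the docs in this PR; the helper names are hypothetical.

```python
import os
import subprocess
import sys


def env_without(var_name, base_env=None):
    """Return a copy of the environment with one variable removed."""
    env = dict(os.environ if base_env is None else base_env)
    env.pop(var_name, None)
    return env


def run_python_without_ld_library_path(code):
    """Run a Python snippet in a child process with LD_LIBRARY_PATH unset.

    The parent shell and any other software relying on the variable
    are left untouched.
    """
    return subprocess.run(
        [sys.executable, "-c", code],
        env=env_without("LD_LIBRARY_PATH"),
        capture_output=True,
        text=True,
        check=False,
    )


if __name__ == "__main__":
    probe = "import os; print(os.environ.get('LD_LIBRARY_PATH'))"
    result = run_python_without_ld_library_path(probe)
    print(result.stdout.strip())
```

In a SLURM script the same effect can usually be had with `env -u LD_LIBRARY_PATH python script.py` (GNU coreutils), which scopes the change to that single command.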
Nothing in bashrc. No recent codes but I may have gx from an old installation.
Also, when you say you tested, what exactly are you doing? import desc and exit?
You should explain your test because I worry it's something trivial.
I am happy to change the wording, but my focus is to make it robust, or at least understand the problem, so next time something like this happens we can create a stable solution quickly!
I'll update the wording I promise but there are other problems either with Perlmutter and/or jax.
My usual test is threefold:
1. `import jax; jax.devices("gpu")`
   Do this before anything with desc; if it fails I know it's a jax issue, not a desc issue.
2. `from desc import set_device; set_device("gpu")`
   To see if desc introduces any further issues.
3. `python -m desc desc/examples/DSHAPE -vv -g`
   Runs the DSHAPE example on GPU, to make sure everything actually works.

You can then verify with nvidia-smi that jax is running on GPU by looking at memory and usage.
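The three checks above can be chained into one script. This is a rough sketch rather than an official DESC tool: the `STEPS` list and `run_steps` helper are hypothetical names, each command runs in its own process so a failure in one layer doesn't affect the next, and step 3 assumes the current directory is the root of a DESC checkout.

```python
import subprocess
import sys

# The three commands from the comment above, in order.
# Step 3 assumes the current directory is the root of a DESC checkout.
STEPS = [
    ("jax sees a GPU",
     [sys.executable, "-c", 'import jax; print(jax.devices("gpu"))']),
    ("desc can select the GPU",
     [sys.executable, "-c", 'from desc import set_device; set_device("gpu")']),
    ("DSHAPE example runs with -g",
     [sys.executable, "-m", "desc", "desc/examples/DSHAPE", "-vv", "-g"]),
]


def run_steps(steps):
    """Run each command in its own process; stop at the first failure."""
    results = []
    for label, cmd in steps:
        proc = subprocess.run(cmd, capture_output=True, text=True, check=False)
        ok = proc.returncode == 0
        results.append((label, ok))
        if not ok:
            # Show the tail of stderr so the failing layer is visible.
            print(f"FAILED: {label}\n{proc.stderr[-2000:]}", file=sys.stderr)
            break
    return results


if __name__ == "__main__":
    for label, ok in run_steps(STEPS):
        print(f"{'ok  ' if ok else 'FAIL'} {label}")
```

A zero exit code only shows that a command did not crash, so checking `nvidia-smi` while step 3 runs is still the way to confirm the GPU is actually in use.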
Actually, I was doing optimization and I see other anomalies. I modified the function errorif in desc/utils.py:

    def errorif(cond, err=ValueError, msg=""):
        """Raise an error if condition is met.

        Similar to assert but allows wider range of Error types, rather than
        just AssertionError.
        """
        if cond:
            # Debugging aid: also print the message through JAX before raising.
            try:
                jax.debug.print(msg)
            except Exception:
                pass
            raise err(colored(msg, "red"))

and it prints
hundreds of times before and during optimization. Eventually, I get the following error
But this error only occurs sometimes.
This seems weird. If you run the same script on your local laptop, do you get similar errors? I.e., is it a jax thing or just a thing with custom grids?
The script that I am running on a GPU is memory intensive. Let me reduce it to a low-memory minimally failing case and then run it on my local system.
So if I repeat the exact same instructions but try to install "jax[cuda12]==0.6.1", DESC doesn't detect a GPU.
This doesn't look like an installation issue to me. Are you sure this only happens on Perlmutter?
I am talking about two different problems here. The first one, I can reproduce on my laptop with jax = jaxlib = 0.6.1.

    import numpy as np
    from scipy.constants import elementary_charge, mu_0

    from desc.backend import jnp
    from desc.equilibrium import Equilibrium
    from desc.grid import LinearGrid, QuadratureGrid
    from desc.integrals import Bounce2D
    from desc.magnetic_fields import (
        OmnigenousField,
    )
    from desc.objectives import (
        ForceBalance,
        ObjectiveFunction,
        Omnigenity,
    )
    from desc.geometry import FourierRZToroidalSurface

    surf = FourierRZToroidalSurface.from_qp_model(
        major_radius=1,
        aspect_ratio=20,
        elongation=6,
        mirror_ratio=0.2,
        torsion=0.1,
        NFP=1,
        sym=True,
    )
    eq = Equilibrium(Psi=6e-3, M=4, N=4, surface=surf)
    eq, _ = eq.solve(objective="force", verbose=3)
    field = OmnigenousField(
        L_B=0,
        M_B=2,
        L_x=0,
        M_x=0,
        N_x=0,
        NFP=eq.NFP,
        helicity=(0, eq.NFP),
        B_lm=np.array([0.8, 1.2]),
    )

    # NOTE: as posted, `self.res_array` and `i` are undefined here; they appear
    # to be left over from a test that loops over several resolutions.
    f = np.zeros_like(self.res_array, dtype=float)
    res = 2
    grid = LinearGrid(M=int(eq.M * res), N=int(eq.N * res), NFP=eq.NFP)
    obj = ObjectiveFunction(
        Omnigenity(eq=eq, field=field, eq_grid=grid, field_grid=grid)
    )
    obj.build(verbose=0)
    f[i] = obj.compute_scalar(obj.x(eq, field))
    np.testing.assert_allclose(f, f[-1], rtol=1e-3)

The second one is specific to Perlmutter. I have created #1936 for the first one.
Here's what I did. I set the max version of jax to <=0.6.1 in requirements.txt and repeated the exact same process, even making sure to use
If I repeat the same process with jax==0.6.2 I get no errors, at least while loading things. @f0uriest is orthax the culprit here or does jax suck this bad?
See response above #1924 (comment)
dpanici
left a comment
@rahulgaur104 Pin the working version (0.6.2), also I think this part should be kept within the dropdown box for Perlmutter (formatting nitpick):
The --no-cache-dir avoids conflicts with existing DESC environments or other software that use CUDA on your system.
Before running a DESC script, you MUST also execute unset LD_LIBRARY_PATH either in your interactive node (for interactive jobs) or in your SLURM script (for submitted jobs).
Can we wait for the CI PR before merging this?
dpanici
left a comment
wait for CI before merging to check #1946 worked
Looks like the Codecov change in #1946 didn't work...
Resolves PlasmaControl#1923
Resolves PlasmaControl#1324
---------
Co-authored-by: Yigit Gunsur Elmacioglu <102380275+YigitElma@users.noreply.github.com>
Co-authored-by: Dario Panici <37969854+dpanici@users.noreply.github.com>

Resolves #1923
Resolves #1324