
Conversation

@rahulgaur104
Collaborator

@rahulgaur104 rahulgaur104 commented Sep 19, 2025

Resolves #1923
Resolves #1324

@github-actions
Contributor

github-actions bot commented Sep 19, 2025

Memory benchmark result

|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
| test_objective_jac_w7x                 |    1.07 %    |     3.992e+03      |     4.034e+03      |    42.69     |       35.06        |       31.65        |
| test_proximal_jac_w7x_with_eq_update   |   -1.25 %    |     6.834e+03      |     6.749e+03      |    -85.56    |       159.33       |       159.80       |
| test_proximal_freeb_jac                |   -0.27 %    |     1.321e+04      |     1.318e+04      |    -35.44    |       78.51        |       78.13        |
| test_proximal_freeb_jac_blocked        |   -1.05 %    |     7.668e+03      |     7.587e+03      |    -80.77    |       68.38        |       69.08        |
| test_proximal_freeb_jac_batched        |   -0.55 %    |     7.605e+03      |     7.563e+03      |    -41.64    |       69.20        |       69.78        |
| test_proximal_jac_ripple               |    1.04 %    |     7.555e+03      |     7.634e+03      |    78.50     |       69.71        |       69.71        |
| test_proximal_jac_ripple_spline        |    0.69 %    |     3.470e+03      |     3.493e+03      |    23.88     |       72.63        |       71.83        |
| test_eq_solve                          |   -3.09 %    |     2.058e+03      |     1.995e+03      |    -63.54    |       125.88       |       125.46       |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.

@unalmis
Collaborator

unalmis commented Sep 19, 2025

Does this resolve #1324 too?

@rahulgaur104 rahulgaur104 self-assigned this Sep 20, 2025
@rahulgaur104
Collaborator Author

Does this resolve #1324 too?

Yes

@rahulgaur104 rahulgaur104 added the run_benchmarks Run timing benchmarks on this PR against current master branch label Sep 20, 2025
@rahulgaur104 rahulgaur104 mentioned this pull request Sep 20, 2025
6 tasks
unalmis
unalmis previously approved these changes Sep 23, 2025
Collaborator

@unalmis unalmis left a comment

Merge at will

@rahulgaur104
Collaborator Author

Actually, this will likely change after #1923 (comment)

.. code-block:: sh

conda create -n desc-env python=3.12
CONDA_OVERRIDE_CUDA="12.4" conda create --name desc-env "jax==0.6.0" "jaxlib==0.6.0=cuda12*" -c conda-forge
Member

The pip method Yigit and I posted also seems to work fine, and I prefer it to conda-forge since that's what's better supported by the jax devs; the conda-forge packages also seem to be out of date (latest is 0.7.0, but 0.7.2 was released last week).

Collaborator Author

There are subtleties that we realized an hour ago while fixing the installation. I'm updating the PR now.

@rahulgaur104 rahulgaur104 requested review from a team, YigitElma, ddudt, dpanici, f0uriest and unalmis and removed request for a team September 24, 2025 21:37
YigitElma
YigitElma previously approved these changes Sep 24, 2025
@rahulgaur104
Collaborator Author

@f0uriest could you test this?

Member

@f0uriest f0uriest left a comment

tested these and they work for me but see notes below

pip install -r devtools/dev-requirements.txt
pip install --no-cache-dir -r devtools/dev-requirements.txt
pip install --no-cache-dir --editable .
pip install --no-cache-dir "jax[cuda12]"
Member

I think we might still need a version pin here for future-proofing, to avoid accidentally installing an incompatible version (or possibly one that doesn't play well with NERSC)?

Collaborator

I have used this command on many different systems and it seems to work consistently: even if there is a newer version of jax out (for example, 0.7.2 is out but I use 0.6.2), it grabs the version that was already installed for CPU in the environment. If people change the order, e.g. call pip install --no-cache-dir -r devtools/dev-requirements.txt after installing jax for GPU, it will definitely break and install the CPU build back (we experienced this with Rahul yesterday).

This is a guess, but I think what is happening is that the difference between CPU and GPU jax is not jax itself but jaxlib. So if you already have jax (installed during the CPU phase), the corresponding jaxlib is resolved against that version, and the two stay in sync. Of course, people can include other packages or other steps that mess this up. If there is a cleaner way to do this, I am fine with it (but let me say, I really don't like the sed command).
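
A quick way to confirm that the CUDA build actually won and that jax/jaxlib stayed in sync (just a sketch, not part of the install instructions):

import jax

# prints the jax and jaxlib versions plus the devices JAX can see
jax.print_environment_info()
# on a GPU node this should list CUDA devices rather than only CpuDevice
print(jax.devices())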

Collaborator Author

@rahulgaur104 rahulgaur104 Sep 25, 2025

Have you tried another version using this process?

Here's what I did. I set the max version of jax to <=0.6.1 in requirements.txt and repeated the exact same process, even making sure to use
pip install --no-cache-dir "jax[cuda12]==0.6.1" "jaxlib[cuda12]==0.6.1"
but now I get the following error:

E0925 15:05:46.285344  817786 pjrt_stream_executor_client.cc:2917] Execution of replica 0 failed: INTERNAL: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/equilibrium/__init__.py", line 3, in <module>
    from .equilibrium import EquilibriaFamily, Equilibrium
  File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/equilibrium/equilibrium.py", line 16, in <module>
    from desc.compute import compute as compute_fun
  File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/compute/__init__.py", line 30, in <module>
    from . import (
  File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/compute/_bootstrap.py", line 16, in <module>
    from ..integrals.surface_integral import surface_averages_map
  File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/integrals/__init__.py", line 3, in <module>
    from .bounce_integral import Bounce1D, Bounce2D
  File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/integrals/bounce_integral.py", line 122, in <module>
    leggauss(32),
    ^^^^^^^^^^^^
  File "/global/homes/r/rgaur/.conda/envs/desc-env2/lib/python3.12/site-packages/orthax/legendre.py", line 1509, in leggauss
    x = jnp.linalg.eigvalsh(m)
        ^^^^^^^^^^^^^^^^^^^^^^
jaxlib._jax.XlaRuntimeError: INTERNAL: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Collaborator

We already have the solution for that in docs.
(screenshot of the relevant section of the docs)

Collaborator Author

What about 0.6.0? Are you saying it should work if I repeat the same process with 0.6.0?

Collaborator

@YigitElma YigitElma Sep 26, 2025

As far as I remember this was specific to 0.6.1. So, yes, it should work without issue.

Note that you may also need to execute `unset LD_LIBRARY_PATH` before starting a python process (e.g. execute this as part of your slurm script, before calling python to run DESC) for the JAX/CUDA initialization to work properly.
The `--no-cache-dir` avoids conflicts with existing DESC environments or other software that use CUDA on your system.

Before running a DESC script, you MUST also execute `unset LD_LIBRARY_PATH` either in your interactive node (for interactive jobs) or in your SLURM script (for submitted jobs).
Member

I tried these instructions on both a login node and a Jupyter worker and didn't need this. Are you both sure it's necessary, and if so, do we know why?

Collaborator Author

I already had one DESC environment and I was trying to install another. To make sure the new environment works, I needed this command.

Collaborator

Maybe the MUST part is too strict. I tried the instructions on my fresh Perlmutter account, so my pip cache was completely empty, and I didn't need this when I was using a compute node. But for Rahul, the older environments seemed to cause problems, and the only way to fix it was --no-cache-dir and unset LD_LIBRARY_PATH. We can make the LD_LIBRARY_PATH step optional (although it doesn't hurt to type).

Member

I tried this both in a totally clean conda setup (no previous environments) and with other environments, and in both cases I don't need it, with any recent version of jax. It might be something specific to your setups? I.e., do you have anything in your bashrc, or other codes installed?

I think we should move this to an "if you see this issue, here's a possible fix" note, but I definitely think we shouldn't say it has to be done. Note that LD_LIBRARY_PATH exists outside of conda and is used by lots of codes and libraries (IIRC I had to mess with it when using sfincs, stride, dcon, etc.), so unsetting it may break other things people have installed, even if they create a fresh conda environment for DESC, and we really don't want that.
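
One way to tell whether the unset is actually needed on a given system (a small diagnostic sketch, not something from the docs) is to check whether LD_LIBRARY_PATH even points at CUDA libraries that could shadow the ones pip installs alongside jaxlib:

import glob
import os

# show what LD_LIBRARY_PATH contains and whether any entry ships CUDA libraries
paths = os.environ.get("LD_LIBRARY_PATH", "")
print("LD_LIBRARY_PATH =", paths or "(not set)")
for entry in filter(None, paths.split(":")):
    hits = glob.glob(os.path.join(entry, "libcusolver*")) + glob.glob(
        os.path.join(entry, "libcudart*")
    )
    if hits:
        print(entry, "->", ", ".join(sorted(os.path.basename(h) for h in hits)))

If nothing CUDA-related shows up, unsetting the variable is unlikely to change anything; if something does, that would at least explain why only some setups need it.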

Collaborator Author

Nothing in bashrc. No recent codes but I may have gx from an old installation.

Also, when you say you tested, what exactly are you doing? import desc and exit?
You should explain your test because I worry it's something trivial.

Collaborator Author

I am happy to change the wording, but my focus is to make it robust, or at least understand the problem, so that next time something like this happens we can get to a stable solution quickly!

Collaborator Author

I'll update the wording, I promise, but there are other problems with Perlmutter and/or jax.

Member

@f0uriest f0uriest Sep 29, 2025

My usual test is threefold:

import jax; jax.devices("gpu")

Do this before anything with DESC; if it fails, I know it's a jax issue, not a DESC issue.

from desc import set_device; set_device("gpu")

To see if DESC introduces any further issues.

python -m desc desc/examples/DSHAPE -vv -g

Runs the DSHAPE example on GPU, to make sure everything actually works.

You can then verify with nvidia-smi that jax is running on the GPU by looking at memory and usage.
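
For anyone repeating this later, the first two checks can also be pasted as one throwaway script (a sketch; it assumes a GPU node, and calls set_device first so DESC configures the GPU before JAX initializes):

from desc import set_device

set_device("gpu")  # DESC-level device selection; call before other desc/jax imports

import jax

print(jax.devices("gpu"))  # raises RuntimeError if no CUDA-capable jaxlib/GPU is visible

The DSHAPE run above remains the real end-to-end check.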

@rahulgaur104
Collaborator Author

Actually, I was doing an optimization and I am seeing other anomalies. I modified the function errorif in desc/utils.py as follows:

def errorif(cond, err=ValueError, msg=""):
    """Raise an error if condition is met.

    Similar to assert but allows wider range of Error types, rather than
    just AssertionError.
    """
    if cond:
        try:
            jax.debug.print(msg)
        except:
            pass
        raise err(colored(msg, "red"))

and it prints

Grid does not have unique indices assigned. It is not possible to do this automatically on grids made under JIT.

hundreds of times before and during optimization. Eventually, I get the following error

  File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/objectives/_omnigenity.py", line 679, in build
    errorif(
  File "/pscratch/sd/r/rgaur/DESC2/DESC/desc/utils.py", line 564, in errorif
    raise err(colored(msg, "red"))
          ^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'str' object is not callable

But this error only occurs sometimes.
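
For what it's worth, here is one way that TypeError can arise (just a guess, not checked against the actual call in _omnigenity.py): if a call site passes the message as the second positional argument, err ends up being a string, and raise err(...) fails exactly like this.

from termcolor import colored


def errorif(cond, err=ValueError, msg=""):
    """Raise an error if condition is met (trimmed version of desc.utils.errorif)."""
    if cond:
        raise err(colored(msg, "red"))


# hypothetical call with the message in the err slot
errorif(True, "Grid does not have unique indices assigned.")
# -> TypeError: 'str' object is not callable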

@dpanici
Collaborator

dpanici commented Sep 25, 2025

Actually, I was doing an optimization and I am seeing other anomalies. I modified the function errorif in desc/utils.py [...] But this error only occurs sometimes.

This seems weird; if you run the same script on your local laptop, do you get similar errors? I.e., is it a jax thing or just a thing with custom grids?

@rahulgaur104
Collaborator Author

The script that I am running on a GPU is memory intensive. Let me reduce it to a low-memory, minimal failing case and then run it on my local system.

@rahulgaur104
Collaborator Author

So if I repeat the exact same instructions but try to install "jax[cuda12]==0.6.1", DESC doesn't detect a GPU.
This whole thing is a goddamn mess!
Starting from scratch....

@YigitElma
Collaborator

Grid does not have unique indices assigned. It is not possible to do this automatically on grids made under JIT.

This doesn't look like an installation issue to me. Are you sure this only happens on Perlmutter?

@rahulgaur104
Collaborator Author

rahulgaur104 commented Sep 25, 2025

I am talking about two different problems here. The first one I can reproduce on my laptop with jax = jaxlib = 0.6.1.
Here's an MWE you can run, after changing desc/utils.py as described above, to reproduce the error.

import numpy as np
from scipy.constants import elementary_charge, mu_0
from desc.backend import jnp
from desc.equilibrium import Equilibrium
from desc.grid import LinearGrid, QuadratureGrid
from desc.integrals import Bounce2D
from desc.magnetic_fields import (
    OmnigenousField,
)
from desc.objectives import (
    ForceBalance,
    ObjectiveFunction,
    Omnigenity,
)
from desc.geometry import FourierRZToroidalSurface

surf = FourierRZToroidalSurface.from_qp_model(
            major_radius=1,
            aspect_ratio=20,
            elongation=6,
            mirror_ratio=0.2,
            torsion=0.1,
            NFP=1,
            sym=True,
        )
eq = Equilibrium(Psi=6e-3, M=4, N=4, surface=surf)
eq, _ = eq.solve(objective="force", verbose=3)
field = OmnigenousField(
    L_B=0,
    M_B=2,
    L_x=0,
    M_x=0,
    N_x=0,
    NFP=eq.NFP,
    helicity=(0, eq.NFP),
    B_lm=np.array([0.8, 1.2]),
)
res_array = np.array([2, 4])  # example resolutions; the original snippet referenced self.res_array
f = np.zeros_like(res_array, dtype=float)

for i, res in enumerate(res_array):
    grid = LinearGrid(M=int(eq.M * res), N=int(eq.N * res), NFP=eq.NFP)
    obj = ObjectiveFunction(
        Omnigenity(eq=eq, field=field, eq_grid=grid, field_grid=grid)
    )
    obj.build(verbose=0)
    f[i] = obj.compute_scalar(obj.x(eq, field))

np.testing.assert_allclose(f, f[-1], rtol=1e-3)

The second one is specific to Perlmutter.

I have created #1936 for the first one.

@rahulgaur104
Collaborator Author

rahulgaur104 commented Sep 25, 2025

Here's what I did. I set the max version of jax to <=0.6.1 in requirements.txt and repeated the exact same process, even making sure to use
pip install --no-cache-dir "jax[cuda12]==0.6.1" "jaxlib[cuda12]==0.6.1"
but now I get the following error: [...] jaxlib._jax.XlaRuntimeError: INTERNAL: jaxlib/gpu/solver_handle_pool.cc:37: operation gpusolverDnCreate(&handle) failed: cuSolver internal error

If I repeat the same process with jax==0.6.2 I get no errors, at least while loading things.

@f0uriest is orthax the culprit here or does jax suck this bad?
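
To separate the two suspects: the traceback dies inside jnp.linalg.eigvalsh, which is what orthax's leggauss calls, so the same cuSolver path can be hit without importing orthax at all (a quick sketch):

import jax.numpy as jnp

# same eigensolver call that orthax.legendre.leggauss makes internally
m = jnp.eye(32)
print(jnp.linalg.eigvalsh(m))

If this also fails with the gpusolverDnCreate / cuSolver error, the problem is in jaxlib/CUDA rather than in orthax.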

@YigitElma
Collaborator

Here's what I did. I set the max version of jax to <=0.6.1 in requirements.txt and repeated the exact same process, even making sure to use
pip install --no-cache-dir "jax[cuda12]==0.6.1" "jaxlib[cuda12]==0.6.1"
but now I get the following error:

See response above #1924 (comment)

@YigitElma YigitElma dismissed their stale review September 29, 2025 17:16

waiting for updates

Collaborator

@dpanici dpanici left a comment

@rahulgaur104 Pin the working version (0.6.2). Also, I think this part should be kept within the dropdown box for Perlmutter (formatting nitpick):

The --no-cache-dir avoids conflicts with existing DESC environments or other software that use CUDA on your system.

Before running a DESC script, you MUST also execute unset LD_LIBRARY_PATH either in your interactive node (for interactive jobs) or in your SLURM script (for submitted jobs).

@YigitElma
Collaborator

Can we wait to merge this until the CI PR is in?

Collaborator

@dpanici dpanici left a comment

Wait for CI before merging, to check that #1946 worked.

@YigitElma YigitElma added the only-docs-comments Don't run workflows if the changes are only on the comments label Oct 1, 2025
@YigitElma
Collaborator

Looks like the Codecov change in #1946 didn't work...

@YigitElma YigitElma merged commit 403baf1 into master Oct 1, 2025
26 checks passed
@YigitElma YigitElma deleted the rg/perlmutter branch October 1, 2025 17:58
DMCXE pushed a commit to DMCXE/DESC-OOPS that referenced this pull request Oct 14, 2025
Resolves PlasmaControl#1923 
Resolves PlasmaControl#1324

---------

Co-authored-by: Yigit Gunsur Elmacioglu <102380275+YigitElma@users.noreply.github.com>
Co-authored-by: Dario Panici <37969854+dpanici@users.noreply.github.com>
Successfully merging this pull request may close these issues.

DESC installation on Perlmutter doesn't work
Update the installation instructions for RAVEN and NERSC/Perlmutter
