Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
159 commits
Select commit Hold shift + click to select a range
45693b9
Merge changes from 1360
unalmis Jul 29, 2025
479c497
Add NFP warning to eq.compute
unalmis Jul 30, 2025
7181f79
Merge branch 'ku/NFP' into ku/partialsum
unalmis Jul 30, 2025
a85895a
first pass at partial sum
unalmis Jul 30, 2025
c390d07
Merge branch 'master' into ku/partialsum
unalmis Jul 30, 2025
c2a0e09
working commit
unalmis Jul 31, 2025
c12c884
Merge branch 'master' into ku/partialsum
unalmis Jul 31, 2025
9e87fd1
Remove old static attributes
unalmis Jul 31, 2025
3151bc7
partial sum pass two
unalmis Jul 31, 2025
c1cab79
Reduce resolution
unalmis Jul 31, 2025
7dc7916
Updated notebook
unalmis Jul 31, 2025
1bef805
Dummy wrapper to avoid circular import
unalmis Jul 31, 2025
f22e7d4
Update _fast_ion.py
unalmis Jul 31, 2025
39d1912
Cast to array first
unalmis Jul 31, 2025
8558e89
Remove deprecated code
unalmis Jul 31, 2025
b3a872c
Merge branch 'master' into ku/partialsum
unalmis Aug 2, 2025
c95b6ab
Add tests for NUFFTS
unalmis Aug 3, 2025
0062d6a
Add nuft2 vectorized
unalmis Aug 4, 2025
b961570
Merge branch 'master' into ku/partialsum
unalmis Aug 4, 2025
48fe226
Add tests for rfft2 to nufft transform
unalmis Aug 5, 2025
4344e6f
Merge branch 'ku/partialsum' into ku/nufft
unalmis Aug 5, 2025
66a48c2
Update changelog
unalmis Aug 5, 2025
b52ad96
Add failing test to show nufft developers
unalmis Aug 5, 2025
e9cad5d
Resolve some todos
unalmis Aug 6, 2025
f2853be
Fix typo in comment
unalmis Aug 6, 2025
546aee6
Update root-finding
unalmis Aug 6, 2025
c985023
Pull changes from ku/nufft
unalmis Aug 6, 2025
e9e4cd3
Merge branch 'ku/partialsum' into ku/nufft
unalmis Aug 6, 2025
64511a7
Remove comments
unalmis Aug 6, 2025
5543fe9
use partial summation approach in bounce2d.compute_theta
unalmis Aug 6, 2025
bdd174a
Merge branch 'master' into ku/partialsum
unalmis Aug 7, 2025
ad505ad
Merge branch 'ku/partialsum' into ku/nufft
unalmis Aug 7, 2025
39611e4
small optimization
unalmis Aug 7, 2025
a97b38e
small edits
unalmis Aug 7, 2025
8e3702a
working attempt at nuffts
unalmis Aug 8, 2025
11a3bf3
clean up
unalmis Aug 8, 2025
7c59869
Add comment
unalmis Aug 8, 2025
7ad892a
clean up logic
unalmis Aug 8, 2025
5d743f9
update comment
unalmis Aug 8, 2025
270f4e3
upload notebook to see if diff
unalmis Aug 8, 2025
32583cc
Adding tests
unalmis Aug 9, 2025
4ff85e8
Updating eps setting
unalmis Aug 9, 2025
996b58d
updating tests
unalmis Aug 10, 2025
90dd6ea
adding warnings for AD bug
unalmis Aug 10, 2025
d70a06b
running notebook with nufft compute
unalmis Aug 10, 2025
7348dff
add back important comment
unalmis Aug 10, 2025
15e8fc6
clarify code
unalmis Aug 10, 2025
9c44d07
clean up coordinate map
unalmis Aug 10, 2025
b4e9b9a
progress
unalmis Aug 10, 2025
6d3246e
Updating
unalmis Aug 11, 2025
74f2b12
Merge branch 'master' into ku/partialsum
unalmis Aug 11, 2025
a7ba43a
update plots
unalmis Aug 11, 2025
a83de07
add assert statement
unalmis Aug 11, 2025
c6f4523
change comment
unalmis Aug 11, 2025
9065461
Reduce tolerance
unalmis Aug 12, 2025
31a399d
Add links to issues
unalmis Aug 12, 2025
44d68b6
Update comment
unalmis Aug 12, 2025
8a9a482
Resolves #1574
unalmis Aug 13, 2025
03174ba
Merge branch 'master' into ku/partialsum
unalmis Aug 13, 2025
4c5fd6b
Merge branch 'ku/partialsum' into ku/nufft
unalmis Aug 13, 2025
dc1d21c
last commit
unalmis Aug 13, 2025
9341e97
Remove temp variables to help garbarge collection
unalmis Aug 13, 2025
7be33eb
dario comment suggestion
unalmis Aug 13, 2025
775bdf1
Pulling changes down from #1834 which are necessary to address @f0uri…
unalmis Aug 13, 2025
b919f01
Merge branch 'ku/partialsum' into ku/nufft
unalmis Aug 13, 2025
d5ff809
Add comment to address Rory comment
unalmis Aug 13, 2025
29dc77e
Merge branch 'ku/partialsum' into ku/nufft
unalmis Aug 14, 2025
2dc4bda
Add more warnings
unalmis Aug 14, 2025
4d9962e
clean up warnings
unalmis Aug 14, 2025
16dff27
fixing doc build
unalmis Aug 14, 2025
2b20f8d
Changing variable name for Rahul
unalmis Aug 14, 2025
4156e12
Merge branch 'ku/partialsum' into ku/nufft
unalmis Aug 14, 2025
fd2bc47
Add comments
unalmis Aug 14, 2025
b58e34e
Merge branch 'master' into ku/nufft
unalmis Aug 14, 2025
bd12d2e
Fixing stuff from 1529
unalmis Aug 15, 2025
01ee90e
Fixing stuff from 1529
unalmis Aug 15, 2025
41e1f7e
Removing NUFFTs from previous commit
unalmis Aug 15, 2025
eae25c9
Ensure gamma_c in [-1, 1]
unalmis Aug 15, 2025
0f46db3
Ku/real nufft (#1857)
unalmis Aug 15, 2025
fcba0a4
Reducing tolerance
unalmis Aug 15, 2025
79b6d3d
clean conditional
unalmis Aug 15, 2025
86a38c0
Add important fftfreq test
unalmis Aug 15, 2025
dbc22bf
Change code to increase code coverage
unalmis Aug 17, 2025
afb552b
Preparing for merge
unalmis Aug 17, 2025
15f4d96
Merge branch 'ku/fix_1529' into ku/nufft
unalmis Aug 17, 2025
51cb599
Clean up verbosity
unalmis Aug 17, 2025
314f852
Update _interp_utils.py
unalmis Aug 17, 2025
f628399
Making requested change to assume nufft2 is always used on real input
unalmis Aug 18, 2025
2921171
fix docstring
unalmis Aug 18, 2025
7d88e81
fix benign bug
unalmis Aug 18, 2025
aafc2f4
Update default resolutions
unalmis Aug 18, 2025
30fb2a9
Merge branch 'ku/fix_1529' into ku/nufft
unalmis Aug 18, 2025
340bde2
Update plot
unalmis Aug 18, 2025
d67269c
fix comment
unalmis Aug 18, 2025
cf1dba1
update label
unalmis Aug 18, 2025
8fcf3cd
Review requests
unalmis Aug 18, 2025
48bbc27
Merge branch 'master' into ku/nufft
unalmis Aug 20, 2025
9e1b9e7
Merge branch 'master' into ku/nufft
unalmis Aug 20, 2025
6a5dece
Add derivative tests
unalmis Aug 20, 2025
4e7b964
Make names consistent with recently officialized finufft tutorial
unalmis Aug 20, 2025
4c02276
Reduce memory in test
unalmis Aug 20, 2025
17a4ff7
fix typo
unalmis Aug 20, 2025
bde9479
Merge branch 'master' into ku/nufft
unalmis Aug 22, 2025
b6596f9
Merge
unalmis Aug 22, 2025
5044d04
Add comment to docstring
unalmis Aug 23, 2025
01de804
patching jax-finufft bug
unalmis Aug 23, 2025
3b182a3
Upload nufft optimization notebook
unalmis Aug 23, 2025
c052d4e
Reduce number of reshape/transpose to prepare for any unjitting requi…
unalmis Aug 24, 2025
aed8897
Add another test
unalmis Aug 24, 2025
946dcac
Update benchmarks
unalmis Aug 24, 2025
741e044
Fix benchmarks
unalmis Aug 24, 2025
154ac4e
Fixing benchmarks
unalmis Aug 24, 2025
935d747
STill trying to fix these dumb and useless benchmarks
unalmis Aug 24, 2025
8fd0106
Merge branch 'master' into ku/nufft
unalmis Aug 25, 2025
6f770b0
Merge branch 'master' into ku/nufft
unalmis Aug 25, 2025
14dca06
Ku/finufft gpu (#1881)
unalmis Aug 27, 2025
1090f0b
Update CHANGELOG.md
unalmis Aug 27, 2025
7347a90
Merge branch 'master' into ku/nufft
unalmis Aug 28, 2025
d4d5d24
Dancing to appease other developers
unalmis Aug 28, 2025
66fc063
Update benchmarks
unalmis Aug 28, 2025
b0b4678
fix link in docs
unalmis Aug 28, 2025
eaa6e4f
More dancing to appease developers
unalmis Aug 28, 2025
99a1342
Increasing codecov
unalmis Aug 29, 2025
2e3e964
Merge branch 'master' into ku/nufft
unalmis Sep 2, 2025
329e8f2
Fix import warning (thanks to @jlabbate15 for reporting issue)
unalmis Sep 3, 2025
cef94ee
fix thing from last commit
unalmis Sep 3, 2025
8ea6daf
Making @dpanici requested change
unalmis Sep 3, 2025
71f9238
Merge branch 'master' into ku/nufft
unalmis Sep 3, 2025
4ab768e
Merge branch 'master' into ku/nufft
unalmis Sep 4, 2025
46201ff
Merge branch 'master' into ku/nufft
unalmis Sep 4, 2025
10a5a0d
Fix typo in commen
unalmis Sep 6, 2025
e0c2749
Merge branch 'master' into ku/nufft
unalmis Sep 6, 2025
fdb28a1
Merge branch 'master' into ku/nufft
unalmis Sep 9, 2025
aa8a30b
;sadfj;ladsf;af
unalmis Sep 9, 2025
1a8ecd0
;kafjs
unalmis Sep 9, 2025
d1d34c3
Upgrade minimum jax-finufft version to Sep 9 pypi release
unalmis Sep 10, 2025
664f3f4
Same as previous commit
unalmis Sep 10, 2025
bd8b131
Revert previous commit
unalmis Sep 10, 2025
315b457
Update warning for rory
unalmis Sep 10, 2025
a6dd47b
Merge branch 'master' into ku/nufft
unalmis Sep 10, 2025
b7169a5
Add installation test for jax-finufft
unalmis Sep 10, 2025
9f936e6
Resolve https://github.com/PlasmaControl/DESC/pull/1834#discussion_r2…
unalmis Sep 10, 2025
19a1f3c
Ku/nufft notebook (#1906)
unalmis Sep 10, 2025
8f6070e
Remove recommendation to set jac chunk size to 1.
unalmis Sep 11, 2025
49b4cdc
review dog line length
unalmis Sep 11, 2025
5bb3c6d
Merge branch 'master' into ku/nufft
unalmis Sep 11, 2025
af09d7d
Merge branch 'master' into ku/nufft
unalmis Sep 18, 2025
5327079
Merge branch 'master' into ku/nufft
unalmis Sep 23, 2025
8df9c5a
Merge branch 'master' into ku/nufft
unalmis Sep 24, 2025
a003709
Merge branch 'master' into ku/nufft
unalmis Sep 25, 2025
dba0d4f
fix jax test accidentally removing jax-finnuft
dpanici Sep 27, 2025
ee491f0
Proper method to resolve #1938 and address @dpanici review request
unalmis Sep 27, 2025
9f47fd9
Add info for perlmutter as well
unalmis Sep 27, 2025
792ff42
Merge branch 'master' into ku/nufft
dpanici Sep 30, 2025
7c06848
Merge branch 'master' into ku/nufft
unalmis Oct 2, 2025
499e6b0
Merge branch 'master' into ku/nufft
unalmis Oct 2, 2025
488dd91
Merge branch 'master' into ku/nufft
unalmis Oct 6, 2025
35c4481
Merge branch 'master' into ku/nufft
dpanici Oct 7, 2025
291a1bb
Merge branch 'master' into ku/nufft
unalmis Oct 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions .github/workflows/jax_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@ jobs:
strategy:
fail-fast: false
matrix:
jax-version: [0.4.29, 0.4.30, 0.4.31, 0.4.33, 0.4.34, 0.4.35, 0.4.37,
0.4.38, 0.5.0, 0.5.3, 0.6.0, 0.6.1, 0.6.2, 0.7.2]
# 0.4.32 is not available on PyPI
# 0.4.36 has a bug that causes tests to fail
jax-version: [0.5.0, 0.5.3, 0.6.0, 0.6.1, 0.6.2, 0.7.2]
# 0.5.1 and 0.5.2 installations are broken, see jax#26781
# 0.7.0 and 0.7.1 have performance issues, see diffrax#680
group: [1, 2]
Expand All @@ -30,7 +27,7 @@ jobs:
python -m pip install --upgrade pip
- name: Install dependencies with given JAX version
run: |
sed -i '/jax/d' ./requirements.txt
sed -i '1{/^jax/d}' requirements.txt
sed -i '1i\jax[cpu] == ${{ matrix.jax-version }}' ./requirements.txt
cat ./requirements.txt
pip install -r ./devtools/dev-requirements.txt
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ Backend
- ``desc.equilibrium.Equilibrium.set_initial_guess`` now sets lambda to zero for most use cases, and the docstring has been updated to be more explicit on what is done in each case.


Performance Improvements

- [Partial summation in coordinate mapping](https://github.com/PlasmaControl/DESC/pull/1826).
- [NUFFTS](https://github.com/PlasmaControl/DESC/pull/1834) are now used by default for computing bounce integrals.


v0.15.0
-------

Expand Down
15 changes: 5 additions & 10 deletions desc/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)

from desc.backend import jax, jnp, scan, vmap
from desc.utils import errorif
from desc.utils import errorif, identity

try:
from jax.extend import linear_util as lu
Expand Down Expand Up @@ -70,11 +70,6 @@ def _batch_and_remainder(x, batch_size: int):
return scan_tree, remainder_tree


def _identity(y):
"""Returns the input."""
return y


_unchunk = partial(tree_map, lambda y: y.reshape(-1, *y.shape[2:]))
_concat = partial(tree_map, lambda y1, y2: jnp.concatenate((y1, y2)))
_get_first_chunk = partial(tree_map, lambda x: x[0])
Expand Down Expand Up @@ -106,7 +101,7 @@ def body(carry, x):
return result


def _scanmap(fun, argnums=0, reduction=None, chunk_reduction=_identity):
def _scanmap(fun, argnums=0, reduction=None, chunk_reduction=identity):
"""A helper function to wrap f with a scan_fun.

Refrences
Expand Down Expand Up @@ -141,7 +136,7 @@ def _evaluate_in_chunks(
chunk_size,
argnums,
reduction=None,
chunk_reduction=_identity,
chunk_reduction=identity,
*args,
**kwargs,
):
Expand Down Expand Up @@ -192,7 +187,7 @@ def vmap_chunked(
*,
chunk_size=None,
reduction=None,
chunk_reduction=_identity,
chunk_reduction=identity,
):
"""Behaves like ``vmap`` but uses scan to chunk the computations in smaller chunks.

Expand Down Expand Up @@ -232,7 +227,7 @@ def vmap_chunked(


def batch_map(
fun, fun_input, /, batch_size=None, *, reduction=None, chunk_reduction=_identity
fun, fun_input, /, batch_size=None, *, reduction=None, chunk_reduction=identity
):
"""Compute ``chunk_reduction(fun(fun_input))`` in batches.

Expand Down
1 change: 1 addition & 0 deletions desc/compute/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
_core,
_curve,
_equil,
_fast_ion,
_field,
_geometry,
_metric,
Expand Down
42 changes: 32 additions & 10 deletions desc/compute/_fast_ion.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _drift2(data, B, pitch):
"num_pitch",
"pitch_batch_size",
"surf_batch_size",
"nufft_eps",
"spline",
],
)
Expand Down Expand Up @@ -135,7 +136,6 @@ def _Gamma_c(params, transforms, profiles, data, **kwargs):
assert (
surf_batch_size == 1 or pitch_batch_size is None
), f"Expected pitch_batch_size to be None, got {pitch_batch_size}."
spline = kwargs.get("spline", True)
fl_quad = (
kwargs["fieldline_quad"] if "fieldline_quad" in kwargs else leggauss(Y_B // 2)
)
Expand All @@ -147,6 +147,9 @@ def _Gamma_c(params, transforms, profiles, data, **kwargs):
(automorphism_sin, grad_automorphism_sin),
)
)
nufft_eps = kwargs.get("nufft_eps", 1e-7)
spline = kwargs.get("spline", True)
vander = kwargs.get("_vander", None)

def Gamma_c(data):
bounce = Bounce2D(
Expand All @@ -157,8 +160,10 @@ def Gamma_c(data):
alpha,
num_transit,
quad,
nufft_eps=nufft_eps,
is_fourier=True,
spline=spline,
vander=vander,
)

def fun(pitch_inv):
Expand All @@ -169,6 +174,7 @@ def fun(pitch_inv):
data,
["|grad(psi)|*kappa_g", "|B|_r|v,p", "K"],
points,
nufft_eps=nufft_eps,
is_fourier=True,
)
# This is γ_c π/2.
Expand All @@ -177,7 +183,10 @@ def fun(pitch_inv):
drift1,
drift2
* bounce.interp_to_argmin(
data["|grad(rho)|*|e_alpha|r,p|"], points, is_fourier=True
data["|grad(rho)|*|e_alpha|r,p|"],
points,
nufft_eps=nufft_eps,
is_fourier=True,
),
)
)
Expand All @@ -188,14 +197,14 @@ def fun(pitch_inv):
* data["pitch_inv weight"]
/ data["pitch_inv"] ** 2,
axis=-1,
) / (bounce.compute_fieldline_length(fl_quad) * 2**1.5 * jnp.pi)
) / (bounce.compute_fieldline_length(fl_quad, vander) * 2**1.5 * jnp.pi)

# It is assumed the grid is sufficiently dense to reconstruct |B|,
# so anything smoother than |B| may be captured accurately as a single
# Fourier series rather than transforming each component.
# Last term in K behaves as ∂log(|B|²/B^ϕ)/∂ρ |B| if one ignores the issue
# of a log argument with units. Smoothness determined by positive lower bound
# of log argument, and hence behaves as ∂log(|B|)/∂ρ |B| = ∂|B|/∂ρ.
# Fourier series rather than transforming each component. Last term in K
# behaves as ∂log(|B|²/(R₀B₀B^ϕ))/∂ρ |B| where R₀B₀ is a constant with
# units Tesla meters. Smoothness is determined by positive lower bound of
# log argument, and hence behaves as ∂log(|B|/B₀)/∂ρ |B| = ∂|B|/∂ρ.
fun_data = {
"|grad(psi)|*kappa_g": data["|grad(psi)|"] * data["kappa_g"],
"|grad(rho)|*|e_alpha|r,p|": data["|grad(rho)|"] * data["|e_alpha|r,p|"],
Expand Down Expand Up @@ -266,6 +275,7 @@ def _poloidal_drift(data, B, pitch):
"num_pitch",
"pitch_batch_size",
"surf_batch_size",
"nufft_eps",
"spline",
],
)
Expand Down Expand Up @@ -298,7 +308,9 @@ def _little_gamma_c_Nemov(params, transforms, profiles, data, **kwargs):
(automorphism_sin, grad_automorphism_sin),
)
)
nufft_eps = kwargs.get("nufft_eps", 1e-7)
spline = kwargs.get("spline", True)
vander = kwargs.get("_vander", None)

def gamma_c0(data):
bounce = Bounce2D(
Expand All @@ -309,8 +321,10 @@ def gamma_c0(data):
alpha,
num_transit,
quad,
nufft_eps=nufft_eps,
is_fourier=True,
spline=spline,
vander=vander,
)

def fun(pitch_inv):
Expand All @@ -321,6 +335,7 @@ def fun(pitch_inv):
data,
["|grad(psi)|*kappa_g", "|B|_r|v,p", "K"],
points,
nufft_eps=nufft_eps,
is_fourier=True,
)
return (2 / jnp.pi) * jnp.arctan(
Expand All @@ -330,6 +345,7 @@ def fun(pitch_inv):
* bounce.interp_to_argmin(
data["|grad(rho)|*|e_alpha|r,p|"],
points,
nufft_eps=nufft_eps,
is_fourier=True,
),
)
Expand Down Expand Up @@ -397,6 +413,7 @@ def fun(pitch_inv):
"num_pitch",
"pitch_batch_size",
"surf_batch_size",
"nufft_eps",
"spline",
],
)
Expand Down Expand Up @@ -427,7 +444,6 @@ def _Gamma_c_Velasco(params, transforms, profiles, data, **kwargs):
assert (
surf_batch_size == 1 or pitch_batch_size is None
), f"Expected pitch_batch_size to be None, got {pitch_batch_size}."
spline = kwargs.get("spline", True)
fl_quad = (
kwargs["fieldline_quad"] if "fieldline_quad" in kwargs else leggauss(Y_B // 2)
)
Expand All @@ -439,6 +455,9 @@ def _Gamma_c_Velasco(params, transforms, profiles, data, **kwargs):
(automorphism_sin, grad_automorphism_sin),
)
)
nufft_eps = kwargs.get("nufft_eps", 1e-7)
spline = kwargs.get("spline", True)
vander = kwargs.get("_vander", None)

def Gamma_c(data):
bounce = Bounce2D(
Expand All @@ -449,8 +468,10 @@ def Gamma_c(data):
alpha,
num_transit,
quad,
nufft_eps=nufft_eps,
is_fourier=True,
spline=spline,
vander=vander,
)

def fun(pitch_inv):
Expand All @@ -459,7 +480,8 @@ def fun(pitch_inv):
pitch_inv,
data,
["cvdrift0", "gbdrift (periodic)", "gbdrift (secular)/phi"],
bounce.points(pitch_inv, num_well),
num_well=num_well,
nufft_eps=nufft_eps,
is_fourier=True,
)
# This is γ_c π/2.
Expand All @@ -471,7 +493,7 @@ def fun(pitch_inv):
* data["pitch_inv weight"]
/ data["pitch_inv"] ** 2,
axis=-1,
) / (bounce.compute_fieldline_length(fl_quad) * 2**1.5 * jnp.pi)
) / (bounce.compute_fieldline_length(fl_quad, vander) * 2**1.5 * jnp.pi)

grid = transforms["grid"]

Expand Down
2 changes: 1 addition & 1 deletion desc/compute/_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2626,7 +2626,7 @@ def _B_dot_grad_grad_rho(params, transforms, profiles, data, **kwargs):
@register_compute_fun(
name="finite-n instability drive",
label="(\\mathbf{J} \\times (\\nabla \\rho))/{(g^{\\rho \\rho})}^2"
+ " \\mathbf{B} \\cdot \\cdot \\mathbf{\\nabla} (\\mathbf{\\nabla} \\rho)",
+ " \\mathbf{B} \\cdot \\mathbf{\\nabla} (\\mathbf{\\nabla} \\rho)",
units="T A \\cdot m^{-1}",
units_long="Tesla Amperes / meter",
description="finite-n instability drive term",
Expand Down
Loading
Loading