Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion desc/integrals/_bounce_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@
Fourier coefficients.

"""
return (v * c[..., None, None, None, :, :]).real.sum((-2, -1))
return jnp.einsum("...pwqzt, ...zt -> ...pwq", v, c).real

Check warning on line 834 in desc/integrals/_bounce_utils.py

View check run for this annotation

Codecov / codecov/patch

desc/integrals/_bounce_utils.py#L834

Added line #L834 was not covered by tests


def broadcast_for_bounce(pitch_inv):
Expand Down
15 changes: 4 additions & 11 deletions desc/integrals/_interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def irfft_mmt(x, a, n, domain=(0, 2 * jnp.pi), axis=-1, *, _modes=None):
_modes = jnp.fft.rfftfreq(n, (domain[1] - domain[0]) / (2 * jnp.pi * n))
i = (0, -1) if (n % 2 == 0) else 0
a = jnp.moveaxis(a, axis, -1).at[..., i].divide(2) * 2
vander = jnp.exp(1j * _modes * (x - domain[0])[..., jnp.newaxis])
return (vander * a).real.sum(-1)
vander = jnp.exp(-1j * _modes * (x - domain[0])[..., jnp.newaxis])
return jnp.linalg.vecdot(vander, a).real


def ifft_mmt(x, a, domain=(0, 2 * jnp.pi), axis=-1, *, vander=None, modes=None):
Expand Down Expand Up @@ -269,7 +269,7 @@ def _irfft2_mmt(
f, r = np.argsort(axes)
modes_f, modes_r = rfft2_modes(n[f], n[r], d[f], d[r])
vander = rfft2_vander(x[f], x[r], modes_f, modes_r, d[f][0], d[r][0])
return (vander * a).real.sum((-2, -1))
return jnp.einsum("...mn, ...mn", vander, a).real


def rfft2_vander(
Expand All @@ -286,14 +286,7 @@ def rfft2_vander(

Warnings
--------
It is vital to not perform any operations on Vandermonde array and immediately
reduce it. For example, to transform from spectral to real space do
``a=jnp.fft.rfft2(f).at[...,i].divide(2)*2``

``(vander*a).real.sum((-2,-1))``

Performing the scaling on the Vandermonde array would triple the memory consumption.
Perhaps this is required for the compiler to fuse operations.
Reduce with einsum to save memory.

Notes
-----
Expand Down
5 changes: 3 additions & 2 deletions desc/integrals/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,9 +636,10 @@ def intersect2d(self, k=0.0, *, eps=_eps):
# ∂f/∂y = ∑ₙ₌₀ᴺ⁻¹ aₙ(x) n Uₙ₋₁(y)
# sign ∂f/∂y = sign ∑ₙ₌₀ᴺ⁻¹ aₙ(x) n sin(n arcos y)
df_dy = jnp.sign(
jnp.linalg.vecdot(
jnp.einsum(
"...yn, ...n",
n * jnp.sin(n * jnp.arccos(y)[..., None]),
self.cheb[..., None, :],
self.cheb,
)
)
y = bijection_from_disc(y, self.domain[0], self.domain[-1])
Expand Down
1 change: 0 additions & 1 deletion desc/objectives/_fast_ion.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ class GammaC(_Objective):
Notes
-----
Performance will improve significantly by resolving these GitHub issues.
* https://github.com/jax-ml/jax/issues/30627
* ``1303`` Patch for differentiable code with dynamic shapes
* ``1206`` Upsample data above midplane to full grid assuming stellarator symmetry
* ``1034`` Optimizers/objectives with auxiliary output
Expand Down
1 change: 0 additions & 1 deletion desc/objectives/_neoclassical.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class EffectiveRipple(_Objective):
Notes
-----
Performance will improve significantly by resolving these GitHub issues.
* https://github.com/jax-ml/jax/issues/30627
* ``1303`` Patch for differentiable code with dynamic shapes
* ``1206`` Upsample data above midplane to full grid assuming stellarator symmetry
* ``1034`` Optimizers/objectives with auxiliary output
Expand Down
9 changes: 1 addition & 8 deletions tests/test_neoclassical.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,7 @@
@pytest.mark.mpl_image_compare(remove_text=True, tolerance=tol_1d)
@pytest.mark.parametrize("nufft_eps", [0, 1e-6])
def test_effective_ripple_2D(nufft_eps):
"""Test effective ripple with W7-X against NEO.

If this test has a peak memory consumption of more than 2.7 GB on JAX version 0.5.0
or more than 5.7 GB on JAX versions 0.5.3+, then there is another memory regression.
These values are for the test where nufft_eps is zero.
https://github.com/jax-ml/jax/issues/30627.
For nufft_eps nonzero with surf_batch_size = 2, memory is 1 GB.
"""
"""Test effective ripple with W7-X against NEO."""
eq = get("W7-X")
rho = np.linspace(0, 1, 10)
grid = LinearGrid(rho=rho, M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, sym=False)
Expand Down
Loading