Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
45693b9
Merge changes from 1360
unalmis Jul 29, 2025
479c497
Add NFP warning to eq.compute
unalmis Jul 30, 2025
7181f79
Merge branch 'ku/NFP' into ku/partialsum
unalmis Jul 30, 2025
a85895a
first pass at partial sum
unalmis Jul 30, 2025
c390d07
Merge branch 'master' into ku/partialsum
unalmis Jul 30, 2025
c2a0e09
working commit
unalmis Jul 31, 2025
c12c884
Merge branch 'master' into ku/partialsum
unalmis Jul 31, 2025
9e87fd1
Remove old static attributes
unalmis Jul 31, 2025
3151bc7
partial sum pass two
unalmis Jul 31, 2025
c1cab79
Reduce resolution
unalmis Jul 31, 2025
7dc7916
Updated notebook
unalmis Jul 31, 2025
1bef805
Dummy wrapper to avoid circular import
unalmis Jul 31, 2025
f22e7d4
Update _fast_ion.py
unalmis Jul 31, 2025
39d1912
Cast to array first
unalmis Jul 31, 2025
8558e89
Remove deprecated code
unalmis Jul 31, 2025
b3a872c
Merge branch 'master' into ku/partialsum
unalmis Aug 2, 2025
b961570
Merge branch 'master' into ku/partialsum
unalmis Aug 4, 2025
c985023
Pull changes from ku/nufft
unalmis Aug 6, 2025
bdd174a
Merge branch 'master' into ku/partialsum
unalmis Aug 7, 2025
74f2b12
Merge branch 'master' into ku/partialsum
unalmis Aug 11, 2025
03174ba
Merge branch 'master' into ku/partialsum
unalmis Aug 13, 2025
7be33eb
dario comment suggestion
unalmis Aug 13, 2025
775bdf1
Pulling changes down from #1834 which are necessary to address @f0uri…
unalmis Aug 13, 2025
d5ff809
Add comment to address Rory comment
unalmis Aug 13, 2025
2b20f8d
Changing variable name for Rahul
unalmis Aug 14, 2025
fd17cc6
Merge branch 'master' into ku/partialsum
rahulgaur104 Aug 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
291 changes: 139 additions & 152 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
"""Functions for mapping between flux, sfl, and real space coordinates."""

import functools
from functools import partial

import numpy as np

from desc.backend import jit, jnp, root, root_scalar, vmap
from desc.backend import jit, jnp, rfft, root, root_scalar, vmap
from desc.batching import batch_map
from desc.compute import compute as compute_fun
from desc.compute import data_index, get_data_deps, get_profiles, get_transforms
from desc.grid import ConcentricGrid, Grid, LinearGrid, QuadratureGrid
from desc.transform import Transform
from desc.utils import check_posint, errorif, safenorm, setdefault, warnif
from desc.utils import (
ResolutionWarning,
check_posint,
errorif,
safenorm,
setdefault,
warnif,
)


def _periodic(x, period):
Expand All @@ -30,7 +38,7 @@
outbasis=("rho", "theta", "zeta"),
guess=None,
params=None,
period=None,
period=(np.inf, np.inf, np.inf),
tol=1e-6,
maxiter=30,
full_output=False,
Expand Down Expand Up @@ -65,6 +73,7 @@
period : tuple of float
Assumed periodicity for each quantity in ``inbasis``.
Use ``np.inf`` to denote no periodicity.
Default assumes no periodicity.
tol : float
Stopping tolerance.
maxiter : int
Expand Down Expand Up @@ -115,33 +124,27 @@
# TODO (#1382): make this work for permutations of in/out basis
if outbasis == ("rho", "theta", "zeta"):
if inbasis == ("rho", "alpha", "zeta"):
errorif(
np.isfinite(period[1]),
msg=f"Period must be ∞ for inbasis={inbasis}, but got {period[1]}.",
)
if "iota" in kwargs:
iota = kwargs.pop("iota")
elif "profiles" in kwargs:
iota = eq._compute_iota_under_jit(coords, params, **kwargs)

Check warning on line 134 in desc/equilibrium/coords.py

View check run for this annotation

Codecov / codecov/patch

desc/equilibrium/coords.py#L134

Added line #L134 was not covered by tests
else:
if profiles["iota"] is None:
profiles["iota"] = eq.get_profile(
["iota", "iota_r"], params=params, **kwargs
)
iota = profiles["iota"].compute(Grid(coords, sort=False, jitable=True))
return _map_clebsch_coordinates(
coords=coords,
iota=iota,
L_lmn=params["L_lmn"],
L_basis=eq.L_basis,
guess=guess[:, 1] if guess is not None else None,
period=period[1] if period is not None else np.inf,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)
iota = eq._compute_iota_under_jit(coords, params, profiles, **kwargs)
rho, alpha, zeta = coords.T
omega = 0 # TODO(#568)
coords = jnp.column_stack([rho, alpha + iota * (zeta + omega), zeta])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If zeta is the generalized toroidal angle, don't we assume zeta = phi + Omega where phi is the cylindrical toroidal angle?
So shouldn't theta_PEST = alpha + iota * (zeta - omega)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Phi = zeta + omega, and theta_PEST = theta + lambda. The left hand side of these relations are defined quantities. Phi is the cylindrical toroidal angle and theta_PEST is the angle where the field lines are straight in the (theta_PEST, Phi) plane. When we mention generalizing angles, we refer to changing the meaning of the angles "zeta" and "theta". These relations must still hold and hence the stream functions must negate the change in zeta and theta.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alpha = theta_PEST - iota phi

inbasis = ("rho", "theta_PEST", "zeta")
if inbasis == ("rho", "theta_PEST", "zeta"):
return _map_PEST_coordinates(
coords=coords,
L_lmn=params["L_lmn"],
L_basis=eq.L_basis,
guess=guess[:, 1] if guess is not None else None,
period=period[1] if period is not None else np.inf,
period=period[1],
tol=tol,
maxiter=maxiter,
full_output=full_output,
Expand All @@ -154,7 +157,7 @@
params["i_l"] = profiles["iota"].params

rhomin = kwargs.pop("rhomin", tol / 10)
period = np.asarray(setdefault(period, (np.inf, np.inf, np.inf)))
period = np.asarray(period)
coords = _periodic(coords, period)

p = "desc.equilibrium.equilibrium.Equilibrium"
Expand Down Expand Up @@ -340,7 +343,10 @@
Only returned if ``full_output`` is True.

"""
# noqa: D202
errorif(
np.isfinite(period) and period != (2 * jnp.pi),
msg=f"Period must be ∞ or 2π, but got {period}.",
)

# Root finding for θₖ such that r(θₖ) = ϑₖ(ρ, θₖ, ζ) − ϑ = 0.
def rootfun(theta, theta_PEST, rho, zeta):
Expand Down Expand Up @@ -376,131 +382,154 @@
)
)
rho, theta_PEST, zeta = coords.T
theta = vecroot(
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
if full_output:
theta, (res, niter) = vecroot(
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
else:
theta = vecroot(
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
theta, (res, niter) = theta

Check warning on line 393 in desc/equilibrium/coords.py

View check run for this annotation

Codecov / codecov/patch

desc/equilibrium/coords.py#L393

Added line #L393 was not covered by tests
out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])
if full_output:
return out, (res, niter)
return out


# TODO(#568): decide later whether to assume given phi instead of zeta.
def _partial_sum(lmbda, L_lmn, omega, W_lmn, iota):
"""Convert FourierZernikeBasis to set of Fourier series.

TODO(#1243) Do proper partial summation once the DESC
basis are improved to store the padded tensor product modes.
https://github.com/PlasmaControl/DESC/issues/1243#issuecomment-3131182128.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@PlasmaControl/desc-dev in #1508 header it says "figure out how to do FourierZernike". Basically until the basis is padded #1243 (comment) there is no efficient implementation without loops.

The partial summation implemented here has a totally unnecessary FourierZernike
spectral to real transform and unnecessary N^2 FFT's of size N. Still the
performance improvement is significant. To avoid the transform and FFTs,
I suggest padding the FourierZernike basis modes to make the partial summation
trivial. Then this computation will likely take microseconds.

Parameters
----------
lmbda : Transform
FourierZernikeBasis
L_lmn : jnp.ndarray
FourierZernikeBasis basis coefficients for λ.
omega : Transform
FourierZernikeBasis
W_lmn : jnp.ndarray
FourierZernikeBasis basis coefficients for ω.
iota : jnp.ndarray
Shape (lmbda.grid.num_rho, )

Returns
-------
lmbda_minus_iota_omega, modes
Spectral coefficients and modes.
Shape (num rho, num zeta, num modes).

"""
grid = lmbda.grid
errorif(not grid.fft_poloidal, NotImplementedError, msg="See note in docstring.")
# TODO(#1243): assert grid.sym==eq.sym once basis is padded for partial sum
# TODO: (#568)
warnif(
grid.M > lmbda.basis.M,
ResolutionWarning,
msg="Poloidal grid resolution is higher than necessary for coordinate mapping.",
)
warnif(
grid.M < lmbda.basis.M,
ResolutionWarning,
msg="High frequency lambda modes will be truncated in coordinate mapping.",
)
lmbda_minus_iota_omega = lmbda.transform(L_lmn)
lmbda_minus_iota_omega = (
rfft(grid.meshgrid_reshape(lmbda_minus_iota_omega, "rzt"), norm="forward")
.at[..., (0, -1) if ((grid.num_theta % 2) == 0) else 0]
.divide(2)
* 2
)
return lmbda_minus_iota_omega, jnp.fft.rfftfreq(grid.num_theta, 1 / grid.num_theta)


@partial(jit, static_argnames=["tol", "maxiter"])
def _map_clebsch_coordinates(
coords,
iota,
alpha,
zeta,
L_lmn,
L_basis,
lmbda,
guess=None,
period=np.inf,
*,
tol=1e-6,
maxiter=30,
full_output=False,
**kwargs,
):
"""Find θ for given Clebsch field line poloidal label α.

Parameters
----------
coords : ndarray
Shape (k, 3).
Clebsch field line coordinates [ρ, α, ζ]. Assumes ζ = ϕ.
Each row is a different point in space.
iota : ndarray
Shape (k, ).
Rotational transform on each node.
Shape (num iota, ).
Rotational transform.
alpha : ndarray
Shape (num alpha, ).
Field line labels.
zeta : ndarray
Shape (num zeta, ).
DESC toroidal angle.
L_lmn : jnp.ndarray
Spectral coefficients for lambda.
L_basis : Basis
Spectral basis for lambda.
Spectral coefficients for λ.
lmbda : Transform
Transform for λ built on DESC coordinates [ρ, θ, ζ].
guess : jnp.ndarray
Shape (k, ).
Optional initial guess for the computational coordinates.
period : float
Assumed periodicity for α.
Use ``np.inf`` to denote no periodicity.
Shape (num iota, num alpha, num zeta).
Optional initial guess for the DESC computational coordinate θ solution.
tol : float
Stopping tolerance.
maxiter : int
Maximum number of Newton iterations.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.
kwargs : dict, optional
Additional keyword arguments to pass to ``root_scalar`` such as ``maxiter_ls``,
``alpha``.

Returns
-------
out : ndarray
Shape (k, 3).
DESC computational coordinates [ρ, θ, ζ].
info : tuple
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you remove info? This is not a good coding practice. This is basically passive encryption.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if root finding fails for some reason and I want to debug it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the residual is calculated elsewhere, ignore this comment.

Copy link
Collaborator Author

@unalmis unalmis Aug 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a new function, see #1826 (comment), so I disagree with the statement about coding practice. Whatever code users have that uses root finding, nothing has changed. They can still get their info tuple.

In this new function, I have not added functionality to return auxiliary information about the root finding because that is impossible --- functions that are decorated with jnp.vectorize, such as this one, must return arrays. info is not an array.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not convinced that this root finding can fail by the way.

2 element tuple containing residuals and number of iterations for each point.
Only returned if ``full_output`` is True.
theta : ndarray
Shape (num iota, num alpha, num zeta).
DESC computational coordinates θ at given input meshgrid.

"""
# noqa: D202

# Root finding for θₖ such that r(θₖ) = αₖ(ρ, θₖ, ζ) − α = 0.
def rootfun(theta, alpha, rho, zeta, iota):
nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2)
A = L_basis.evaluate(nodes)
lmbda = A @ L_lmn
alpha_k = theta + lmbda - iota * zeta
return _fixup_residual(alpha_k - alpha, period).squeeze()

def jacfun(theta, alpha, rho, zeta, iota):
# Valid everywhere except θ such that θ+λ = k period where k ∈ ℤ.
nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2)
A1 = L_basis.evaluate(nodes, (0, 1, 0))
lmbda_t = jnp.dot(A1, L_lmn)
return 1 + lmbda_t.squeeze()

def fixup(x, *args):
return _periodic(x, period)

vecroot = jit(
vmap(
lambda x0, *p: root_scalar(
rootfun,
x0,
jac=jacfun,
args=p,
fixup=fixup,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)
def rootfun(theta, target, c_m):
c = (jnp.exp(1j * modes * theta) * c_m).real.sum()
target_k = theta + c
return target_k - target

def jacfun(theta, target, c_m):
dc_dt = ((1j * jnp.exp(1j * modes * theta) * c_m).real * modes).sum()
return 1 + dc_dt

@partial(jnp.vectorize, signature="(),(),(m)->()")
def vecroot(guess, target, c_m):
return root_scalar(
rootfun,
guess,
jac=jacfun,
args=(target, c_m),
tol=tol,
maxiter=maxiter,
full_output=False,
**kwargs,
)
)
rho, alpha, zeta = coords.T
if guess is None:
# Assume λ=0 for default initial guess.
guess = alpha + iota * zeta
if full_output:
theta, (res, niter) = vecroot(guess, alpha, rho, zeta, iota)
else:
theta = vecroot(guess, alpha, rho, zeta, iota)

out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])
if full_output:
return out, (res, niter)
return out
c_m, modes = _partial_sum(lmbda, L_lmn, None, None, iota)
c_m = c_m[:, jnp.newaxis]
target = alpha[:, jnp.newaxis] + iota[:, jnp.newaxis, jnp.newaxis] * zeta
# Assume λ − ι ω = 0 for default initial guess.
return vecroot(setdefault(guess, target), target, c_m)


def is_nested(eq, grid=None, R_lmn=None, Z_lmn=None, L_lmn=None, msg=None):
Expand Down Expand Up @@ -770,45 +799,3 @@
**idx,
)
return desc_grid


# TODO(#1383): deprecated, remove eventually
def compute_theta_coords(
eq, flux_coords, L_lmn=None, tol=1e-6, maxiter=20, full_output=False, **kwargs
):
"""Find θ (theta_DESC) for given straight field line ϑ (theta_PEST).

Parameters
----------
eq : Equilibrium
Equilibrium to use.
flux_coords : ndarray
Shape (k, 3).
Straight field line PEST coordinates [ρ, ϑ, ϕ]. Assumes ζ = ϕ.
Each row is a different point in space.
L_lmn : ndarray
Spectral coefficients for lambda. Defaults to ``eq.L_lmn``.
tol : float
Stopping tolerance.
maxiter : int
Maximum number of Newton iterations.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.
kwargs : dict, optional
Additional keyword arguments to pass to ``root_scalar`` such as
``maxiter_ls``, ``alpha``.

Returns
-------
coords : ndarray
Shape (k, 3).
DESC computational coordinates [ρ, θ, ζ].
info : tuple
2 element tuple containing residuals and number of iterations for each
point. Only returned if ``full_output`` is True.

"""
return eq.compute_theta_coords(
flux_coords, L_lmn, tol, maxiter, full_output, **kwargs
)
Loading