Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ New Features
- `field_line_integrate` function doesn't accept additional keyword-arguments related to `diffrax`, if it is necessary, they must be given through `options` dictionary.
- ``poincare_plot`` and ``plot_field_lines`` functions can now plot partial results if the integration failed. Previously, user had to pass ``throw=False`` or change the integration parameters. Users can ignore the warnings that are caused by hitting the bounds (i.e. `Terminating differential equation solve because an event occurred.`).
- `chunk_size` argument is now used for chunking the number of field lines. For the chunking of Biot-Savart integration for the magnetic field, users can use `bs_chunk_size` instead.
- Adds Fourier and Chebyshev differentiation matrices that can be used by calling the functions from ``desc.diffmat_utils``.



Bug Fixes

Expand Down
56 changes: 42 additions & 14 deletions desc/compute/_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def _g_balloon(params, transforms, profiles, data, **kwargs):
description="Normalized squared ideal ballooning growth rate",
dim=4,
params=[],
transforms={"grid": []},
transforms={"grid": [], "diffmat": []},
profiles=[],
coordinates="rtz",
data=["c ballooning", "f ballooning", "g ballooning"],
Expand All @@ -406,35 +406,62 @@ def _ideal_ballooning_lambda(params, transforms, profiles, data, **kwargs):
Returns
-------
Ideal-ballooning lambda eigenvalues
Shape (num rho, num alpha, num zeta0, num eigvals).
Shape (num_rho, num alpha, num zeta0, num eigvals).

"""
Neigvals = kwargs.get("Neigvals", 1)
grid = transforms["grid"].source_grid
# toroidal step size between points along field lines is assumed uniform
dz = grid.nodes[grid.unique_zeta_idx[:2], 2]
dz = dz[1] - dz[0]

num_zeta = grid.num_zeta
num_zeta0 = data["c ballooning"].shape[0]

def reshape(f):
assert f.shape == (num_zeta0, grid.num_nodes)
f = grid.meshgrid_reshape(f.T, "raz").swapaxes(-1, -2)
f = jnp.swapaxes(grid.meshgrid_reshape(f.T, "raz"), -1, -2)
assert f.shape == (grid.num_rho, grid.num_alpha, num_zeta0, grid.num_zeta)
return f

c = reshape(data["c ballooning"])
f = reshape(data["f ballooning"])
g = reshape(data["g ballooning"])

# Approximate derivative along field line with second order finite differencing.
# Use g on the half grid for numerical stability.
g_half = (g[..., 1:] + g[..., :-1]) / (2 * dz**2)
b_inv = jnp.reciprocal(f[..., 1:-1])
diag_inner = (c[..., 1:-1] - g_half[..., 1:] - g_half[..., :-1]) * b_inv
diag_outer = g_half[..., 1:-1] * jnp.sqrt(b_inv[..., :-1] * b_inv[..., 1:])
if transforms["diffmat"].D_zeta is not None:

# Check that the gradients of D_zeta are not calculated
D_zeta = transforms["diffmat"].D_zeta
W_zeta = transforms["diffmat"].W_zeta

# W_zeta is purely diagonal for all the quadratures used
# This will give wrong answers for a non-diagonal W_zeta
w = jnp.diag(W_zeta)

wg = -1 * w * g
A = D_zeta.T @ (wg[..., :, None] * D_zeta)

idx = jnp.arange(num_zeta)
A = A.at[..., idx, idx].add(w * c)

# TODO: Issue #1750
w, v = eigh_tridiagonal(diag_inner, diag_outer)
b_inv = jnp.sqrt(jnp.reciprocal(w * f))

A = (b_inv[..., :, None] * A) * b_inv[..., None, :]

# apply dirichlet BC to X
w, v = jnp.linalg.eigh(A[..., 1:-1, 1:-1])

else:
# toroidal step size between points along field lines is assumed uniform
dz = grid.nodes[grid.unique_zeta_idx[:2], 2]
dz = dz[1] - dz[0]

# Approximate derivative along field line with second order finite differencing.
# Use g on the half grid for numerical stability.
g_half = (g[..., 1:] + g[..., :-1]) / (2 * dz**2)
b_inv = jnp.reciprocal(f[..., 1:-1])
diag_inner = (c[..., 1:-1] - g_half[..., 1:] - g_half[..., :-1]) * b_inv
diag_outer = g_half[..., 1:-1] * jnp.sqrt(b_inv[..., :-1] * b_inv[..., 1:])

# TODO: Issue #1750
w, v = eigh_tridiagonal(diag_inner, diag_outer)

w, top_idx = jax.lax.top_k(w, k=Neigvals)
assert w.shape == (grid.num_rho, grid.num_alpha, num_zeta0, Neigvals)
Expand All @@ -444,6 +471,7 @@ def reshape(f):
# stop_gradient prevents that.
v = jax.lax.stop_gradient(v)
v = jnp.take_along_axis(v, top_idx[..., jnp.newaxis, :], axis=-1)

assert v.shape == (
grid.num_rho,
grid.num_alpha,
Expand Down
38 changes: 37 additions & 1 deletion desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from desc.backend import execute_on_cpu, jnp
from desc.grid import Grid

from ..diffmat_utils import DiffMat
from ..utils import errorif, rpz2xyz, rpz2xyz_vec
from .data_index import allowed_kwargs, data_index, deprecated_names

Expand Down Expand Up @@ -66,12 +67,14 @@

"""
basis = kwargs.pop("basis", "rpz").lower()

errorif(basis not in {"rpz", "xyz"}, NotImplementedError)
p = _parse_parameterization(parameterization)
if isinstance(names, str):
names = [names]
if basis == "xyz" and "phi" not in names:
names = names + ["phi"]

# this allows the DeprecationWarning to be thrown in this file
with warnings.catch_warnings():
warnings.simplefilter("always", DeprecationWarning)
Expand All @@ -87,6 +90,24 @@
"instead.",
DeprecationWarning,
)

# RG: normalize transforms so diffmat is passed via transforms, not as a kwarg ---
# We only absorb 'diffmat'. We intentionally DO NOT move 'grid' from kwargs into
# transforms here, because Equilibrium.compute's existing plumbing correctly
# constructs/wraps the source_grid when grid is passed as a kwarg.
if transforms is None:
transforms = {}

Check warning on line 99 in desc/compute/utils.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/utils.py#L99

Added line #L99 was not covered by tests
else:
transforms = dict(transforms)

# Always remove `diffmat` from kwargs so it never reaches the bad-kwarg check,
# regardless of whether get_transforms already set transforms["diffmat"].
dm_kw = kwargs.pop("diffmat", None)

# If get_transforms didn't already provide transforms["diffmat"], wire it now:
if "diffmat" not in transforms and dm_kw is not None:
transforms["diffmat"] = dm_kw

Check warning on line 109 in desc/compute/utils.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/utils.py#L109

Added line #L109 was not covered by tests

bad_kwargs = kwargs.keys() - allowed_kwargs
if len(bad_kwargs) > 0:
raise ValueError(f"Unrecognized argument(s): {bad_kwargs}")
Expand Down Expand Up @@ -512,7 +533,7 @@


@execute_on_cpu
def get_transforms(
def get_transforms( # noqa: C901
keys, obj, grid, jitable=False, has_axis=False, basis="rpz", **kwargs
):
"""Get transforms needed to compute a given quantity on a given grid.
Expand Down Expand Up @@ -548,6 +569,13 @@
has_axis = has_axis or (grid is not None and grid.axis.size)
derivs = get_derivs(keys, obj, has_axis=has_axis, basis=basis)
transforms = {"grid": grid}

# We do not build a Transform, just ensure the dict is present.
# If not in transforms, Look in kwargs here.
if "diffmat" in kwargs and kwargs["diffmat"] is not None:
dm = kwargs["diffmat"]
transforms["diffmat"] = dm if isinstance(dm, DiffMat) else DiffMat(**dm)

for c in derivs.keys():
if hasattr(obj, c + "_basis"): # regular stuff like R, Z, lambda etc.
basis = getattr(obj, c + "_basis")
Expand Down Expand Up @@ -637,6 +665,14 @@
build_pinv=False,
method=method,
)
elif c == "diffmat":
errorif(
"diffmat" not in transforms,
ValueError,
"Compute requested 'diffmat' but none was provided. "
"Call eq.compute(..., diffmat=DiffMat(...)) or set eq.diffmat first.",
)

elif c not in transforms: # possible other stuff lumped in with transforms
transforms[c] = getattr(obj, c)

Expand Down
Loading
Loading