-
Notifications
You must be signed in to change notification settings - Fork 41
Partial summation in coordinate mapping #1826
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
45693b9
479c497
7181f79
a85895a
c390d07
c2a0e09
c12c884
9e87fd1
3151bc7
c1cab79
7dc7916
1bef805
f22e7d4
39d1912
8558e89
b3a872c
b961570
c985023
bdd174a
74f2b12
03174ba
7be33eb
775bdf1
d5ff809
2b20f8d
fd17cc6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,16 +1,24 @@ | ||
| """Functions for mapping between flux, sfl, and real space coordinates.""" | ||
|
|
||
| import functools | ||
| from functools import partial | ||
|
|
||
| import numpy as np | ||
|
|
||
| from desc.backend import jit, jnp, root, root_scalar, vmap | ||
| from desc.backend import jit, jnp, rfft, root, root_scalar, vmap | ||
| from desc.batching import batch_map | ||
| from desc.compute import compute as compute_fun | ||
| from desc.compute import data_index, get_data_deps, get_profiles, get_transforms | ||
| from desc.grid import ConcentricGrid, Grid, LinearGrid, QuadratureGrid | ||
| from desc.transform import Transform | ||
| from desc.utils import check_posint, errorif, safenorm, setdefault, warnif | ||
| from desc.utils import ( | ||
| ResolutionWarning, | ||
| check_posint, | ||
| errorif, | ||
| safenorm, | ||
| setdefault, | ||
| warnif, | ||
| ) | ||
|
|
||
|
|
||
| def _periodic(x, period): | ||
|
|
@@ -30,7 +38,7 @@ | |
| outbasis=("rho", "theta", "zeta"), | ||
| guess=None, | ||
| params=None, | ||
| period=None, | ||
| period=(np.inf, np.inf, np.inf), | ||
| tol=1e-6, | ||
| maxiter=30, | ||
| full_output=False, | ||
|
|
@@ -65,6 +73,7 @@ | |
| period : tuple of float | ||
| Assumed periodicity for each quantity in ``inbasis``. | ||
| Use ``np.inf`` to denote no periodicity. | ||
| Default assumes no periodicity. | ||
| tol : float | ||
| Stopping tolerance. | ||
| maxiter : int | ||
|
|
@@ -115,33 +124,27 @@ | |
| # TODO (#1382): make this work for permutations of in/out basis | ||
| if outbasis == ("rho", "theta", "zeta"): | ||
| if inbasis == ("rho", "alpha", "zeta"): | ||
| errorif( | ||
| np.isfinite(period[1]), | ||
| msg=f"Period must be ∞ for inbasis={inbasis}, but got {period[1]}.", | ||
| ) | ||
| if "iota" in kwargs: | ||
| iota = kwargs.pop("iota") | ||
| elif "profiles" in kwargs: | ||
| iota = eq._compute_iota_under_jit(coords, params, **kwargs) | ||
| else: | ||
| if profiles["iota"] is None: | ||
| profiles["iota"] = eq.get_profile( | ||
| ["iota", "iota_r"], params=params, **kwargs | ||
| ) | ||
| iota = profiles["iota"].compute(Grid(coords, sort=False, jitable=True)) | ||
| return _map_clebsch_coordinates( | ||
| coords=coords, | ||
| iota=iota, | ||
| L_lmn=params["L_lmn"], | ||
| L_basis=eq.L_basis, | ||
| guess=guess[:, 1] if guess is not None else None, | ||
| period=period[1] if period is not None else np.inf, | ||
| tol=tol, | ||
| maxiter=maxiter, | ||
| full_output=full_output, | ||
| **kwargs, | ||
| ) | ||
| iota = eq._compute_iota_under_jit(coords, params, profiles, **kwargs) | ||
| rho, alpha, zeta = coords.T | ||
| omega = 0 # TODO(#568) | ||
| coords = jnp.column_stack([rho, alpha + iota * (zeta + omega), zeta]) | ||
| inbasis = ("rho", "theta_PEST", "zeta") | ||
| if inbasis == ("rho", "theta_PEST", "zeta"): | ||
| return _map_PEST_coordinates( | ||
| coords=coords, | ||
| L_lmn=params["L_lmn"], | ||
| L_basis=eq.L_basis, | ||
| guess=guess[:, 1] if guess is not None else None, | ||
| period=period[1] if period is not None else np.inf, | ||
| period=period[1], | ||
| tol=tol, | ||
| maxiter=maxiter, | ||
| full_output=full_output, | ||
|
|
@@ -154,7 +157,7 @@ | |
| params["i_l"] = profiles["iota"].params | ||
|
|
||
| rhomin = kwargs.pop("rhomin", tol / 10) | ||
| period = np.asarray(setdefault(period, (np.inf, np.inf, np.inf))) | ||
| period = np.asarray(period) | ||
| coords = _periodic(coords, period) | ||
|
|
||
| p = "desc.equilibrium.equilibrium.Equilibrium" | ||
|
|
@@ -340,7 +343,10 @@ | |
| Only returned if ``full_output`` is True. | ||
|
|
||
| """ | ||
| # noqa: D202 | ||
| errorif( | ||
| np.isfinite(period) and period != (2 * jnp.pi), | ||
| msg=f"Period must be ∞ or 2π, but got {period}.", | ||
| ) | ||
|
|
||
| # Root finding for θₖ such that r(θₖ) = ϑₖ(ρ, θₖ, ζ) − ϑ = 0. | ||
| def rootfun(theta, theta_PEST, rho, zeta): | ||
|
|
@@ -376,131 +382,154 @@ | |
| ) | ||
| ) | ||
| rho, theta_PEST, zeta = coords.T | ||
| theta = vecroot( | ||
| # Assume λ=0 for default initial guess. | ||
| setdefault(guess, theta_PEST), | ||
| theta_PEST, | ||
| rho, | ||
| zeta, | ||
| ) | ||
| if full_output: | ||
| theta, (res, niter) = vecroot( | ||
| # Assume λ=0 for default initial guess. | ||
| setdefault(guess, theta_PEST), | ||
| theta_PEST, | ||
| rho, | ||
| zeta, | ||
| ) | ||
| else: | ||
| theta = vecroot( | ||
| # Assume λ=0 for default initial guess. | ||
| setdefault(guess, theta_PEST), | ||
| theta_PEST, | ||
| rho, | ||
| zeta, | ||
| ) | ||
| theta, (res, niter) = theta | ||
| out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta]) | ||
| if full_output: | ||
| return out, (res, niter) | ||
| return out | ||
|
|
||
|
|
||
| # TODO(#568): decide later whether to assume given phi instead of zeta. | ||
| def _partial_sum(lmbda, L_lmn, omega, W_lmn, iota): | ||
| """Convert FourierZernikeBasis to set of Fourier series. | ||
|
|
||
| TODO(#1243) Do proper partial summation once the DESC | ||
| basis are improved to store the padded tensor product modes. | ||
| https://github.com/PlasmaControl/DESC/issues/1243#issuecomment-3131182128. | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @PlasmaControl/desc-dev in #1508 header it says "figure out how to do FourierZernike". Basically until the basis is padded #1243 (comment) there is no efficient implementation without loops. |
||
| The partial summation implemented here has a totally unnecessary FourierZernike | ||
| spectral to real transform and unnecessary N^2 FFT's of size N. Still the | ||
| performance improvement is significant. To avoid the transform and FFTs, | ||
| I suggest padding the FourierZernike basis modes to make the partial summation | ||
| trivial. Then this computation will likely take microseconds. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| lmbda : Transform | ||
| FourierZernikeBasis | ||
| L_lmn : jnp.ndarray | ||
| FourierZernikeBasis basis coefficients for λ. | ||
| omega : Transform | ||
| FourierZernikeBasis | ||
| W_lmn : jnp.ndarray | ||
| FourierZernikeBasis basis coefficients for ω. | ||
| iota : jnp.ndarray | ||
| Shape (lmbda.grid.num_rho, ) | ||
|
|
||
| Returns | ||
| ------- | ||
| lmbda_minus_iota_omega, modes | ||
| Spectral coefficients and modes. | ||
| Shape (num rho, num zeta, num modes). | ||
|
|
||
| """ | ||
| grid = lmbda.grid | ||
| errorif(not grid.fft_poloidal, NotImplementedError, msg="See note in docstring.") | ||
| # TODO(#1243): assert grid.sym==eq.sym once basis is padded for partial sum | ||
| # TODO: (#568) | ||
| warnif( | ||
| grid.M > lmbda.basis.M, | ||
| ResolutionWarning, | ||
| msg="Poloidal grid resolution is higher than necessary for coordinate mapping.", | ||
| ) | ||
| warnif( | ||
| grid.M < lmbda.basis.M, | ||
| ResolutionWarning, | ||
| msg="High frequency lambda modes will be truncated in coordinate mapping.", | ||
| ) | ||
| lmbda_minus_iota_omega = lmbda.transform(L_lmn) | ||
| lmbda_minus_iota_omega = ( | ||
unalmis marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| rfft(grid.meshgrid_reshape(lmbda_minus_iota_omega, "rzt"), norm="forward") | ||
| .at[..., (0, -1) if ((grid.num_theta % 2) == 0) else 0] | ||
| .divide(2) | ||
| * 2 | ||
| ) | ||
| return lmbda_minus_iota_omega, jnp.fft.rfftfreq(grid.num_theta, 1 / grid.num_theta) | ||
|
|
||
|
|
||
| @partial(jit, static_argnames=["tol", "maxiter"]) | ||
| def _map_clebsch_coordinates( | ||
| coords, | ||
| iota, | ||
| alpha, | ||
| zeta, | ||
| L_lmn, | ||
| L_basis, | ||
| lmbda, | ||
| guess=None, | ||
| period=np.inf, | ||
| *, | ||
| tol=1e-6, | ||
| maxiter=30, | ||
| full_output=False, | ||
| **kwargs, | ||
| ): | ||
| """Find θ for given Clebsch field line poloidal label α. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| coords : ndarray | ||
| Shape (k, 3). | ||
| Clebsch field line coordinates [ρ, α, ζ]. Assumes ζ = ϕ. | ||
| Each row is a different point in space. | ||
| iota : ndarray | ||
| Shape (k, ). | ||
| Rotational transform on each node. | ||
| Shape (num iota, ). | ||
| Rotational transform. | ||
| alpha : ndarray | ||
| Shape (num alpha, ). | ||
| Field line labels. | ||
| zeta : ndarray | ||
| Shape (num zeta, ). | ||
| DESC toroidal angle. | ||
| L_lmn : jnp.ndarray | ||
| Spectral coefficients for lambda. | ||
| L_basis : Basis | ||
| Spectral basis for lambda. | ||
| Spectral coefficients for λ. | ||
| lmbda : Transform | ||
| Transform for λ built on DESC coordinates [ρ, θ, ζ]. | ||
| guess : jnp.ndarray | ||
| Shape (k, ). | ||
| Optional initial guess for the computational coordinates. | ||
| period : float | ||
| Assumed periodicity for α. | ||
| Use ``np.inf`` to denote no periodicity. | ||
| Shape (num iota, num alpha, num zeta). | ||
| Optional initial guess for the DESC computational coordinate θ solution. | ||
| tol : float | ||
| Stopping tolerance. | ||
| maxiter : int | ||
| Maximum number of Newton iterations. | ||
| full_output : bool, optional | ||
| If True, also return a tuple where the first element is the residual from | ||
| the root finding and the second is the number of iterations. | ||
| kwargs : dict, optional | ||
| Additional keyword arguments to pass to ``root_scalar`` such as ``maxiter_ls``, | ||
| ``alpha``. | ||
|
|
||
| Returns | ||
| ------- | ||
| out : ndarray | ||
| Shape (k, 3). | ||
| DESC computational coordinates [ρ, θ, ζ]. | ||
| info : tuple | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why did you remove info? This is not a good coding practice. This is basically passive encryption.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens if root finding fails for some reason and I want to debug it.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the residual is calculated elsewhere, ignore this comment.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a new function, see #1826 (comment), so I disagree with the statement about coding practice. Whatever code users have that uses root finding, nothing has changed. They can still get their In this new function, I have not added functionality to return auxiliary information about the root finding because that is impossible --- functions that are decorated with
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not convinced that this root finding can fail by the way. |
||
| 2 element tuple containing residuals and number of iterations for each point. | ||
| Only returned if ``full_output`` is True. | ||
| theta : ndarray | ||
| Shape (num iota, num alpha, num zeta). | ||
| DESC computational coordinates θ at given input meshgrid. | ||
|
|
||
| """ | ||
| # noqa: D202 | ||
|
|
||
| # Root finding for θₖ such that r(θₖ) = αₖ(ρ, θₖ, ζ) − α = 0. | ||
| def rootfun(theta, alpha, rho, zeta, iota): | ||
| nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2) | ||
| A = L_basis.evaluate(nodes) | ||
| lmbda = A @ L_lmn | ||
| alpha_k = theta + lmbda - iota * zeta | ||
| return _fixup_residual(alpha_k - alpha, period).squeeze() | ||
|
|
||
| def jacfun(theta, alpha, rho, zeta, iota): | ||
| # Valid everywhere except θ such that θ+λ = k period where k ∈ ℤ. | ||
| nodes = jnp.array([rho.squeeze(), theta.squeeze(), zeta.squeeze()], ndmin=2) | ||
| A1 = L_basis.evaluate(nodes, (0, 1, 0)) | ||
| lmbda_t = jnp.dot(A1, L_lmn) | ||
| return 1 + lmbda_t.squeeze() | ||
|
|
||
| def fixup(x, *args): | ||
| return _periodic(x, period) | ||
|
|
||
| vecroot = jit( | ||
| vmap( | ||
| lambda x0, *p: root_scalar( | ||
| rootfun, | ||
| x0, | ||
| jac=jacfun, | ||
| args=p, | ||
| fixup=fixup, | ||
| tol=tol, | ||
| maxiter=maxiter, | ||
| full_output=full_output, | ||
| **kwargs, | ||
| ) | ||
| def rootfun(theta, target, c_m): | ||
| c = (jnp.exp(1j * modes * theta) * c_m).real.sum() | ||
| target_k = theta + c | ||
| return target_k - target | ||
|
|
||
| def jacfun(theta, target, c_m): | ||
| dc_dt = ((1j * jnp.exp(1j * modes * theta) * c_m).real * modes).sum() | ||
| return 1 + dc_dt | ||
|
|
||
| @partial(jnp.vectorize, signature="(),(),(m)->()") | ||
| def vecroot(guess, target, c_m): | ||
| return root_scalar( | ||
| rootfun, | ||
| guess, | ||
| jac=jacfun, | ||
| args=(target, c_m), | ||
| tol=tol, | ||
| maxiter=maxiter, | ||
| full_output=False, | ||
| **kwargs, | ||
| ) | ||
| ) | ||
| rho, alpha, zeta = coords.T | ||
| if guess is None: | ||
| # Assume λ=0 for default initial guess. | ||
| guess = alpha + iota * zeta | ||
| if full_output: | ||
| theta, (res, niter) = vecroot(guess, alpha, rho, zeta, iota) | ||
| else: | ||
| theta = vecroot(guess, alpha, rho, zeta, iota) | ||
|
|
||
| out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta]) | ||
| if full_output: | ||
| return out, (res, niter) | ||
| return out | ||
| c_m, modes = _partial_sum(lmbda, L_lmn, None, None, iota) | ||
| c_m = c_m[:, jnp.newaxis] | ||
| target = alpha[:, jnp.newaxis] + iota[:, jnp.newaxis, jnp.newaxis] * zeta | ||
| # Assume λ − ι ω = 0 for default initial guess. | ||
| return vecroot(setdefault(guess, target), target, c_m) | ||
|
|
||
|
|
||
| def is_nested(eq, grid=None, R_lmn=None, Z_lmn=None, L_lmn=None, msg=None): | ||
|
|
@@ -770,45 +799,3 @@ | |
| **idx, | ||
| ) | ||
| return desc_grid | ||
|
|
||
|
|
||
| # TODO(#1383): deprecated, remove eventually | ||
| def compute_theta_coords( | ||
| eq, flux_coords, L_lmn=None, tol=1e-6, maxiter=20, full_output=False, **kwargs | ||
| ): | ||
| """Find θ (theta_DESC) for given straight field line ϑ (theta_PEST). | ||
|
|
||
| Parameters | ||
| ---------- | ||
| eq : Equilibrium | ||
| Equilibrium to use. | ||
| flux_coords : ndarray | ||
| Shape (k, 3). | ||
| Straight field line PEST coordinates [ρ, ϑ, ϕ]. Assumes ζ = ϕ. | ||
| Each row is a different point in space. | ||
| L_lmn : ndarray | ||
| Spectral coefficients for lambda. Defaults to ``eq.L_lmn``. | ||
| tol : float | ||
| Stopping tolerance. | ||
| maxiter : int | ||
| Maximum number of Newton iterations. | ||
| full_output : bool, optional | ||
| If True, also return a tuple where the first element is the residual from | ||
| the root finding and the second is the number of iterations. | ||
| kwargs : dict, optional | ||
| Additional keyword arguments to pass to ``root_scalar`` such as | ||
| ``maxiter_ls``, ``alpha``. | ||
|
|
||
| Returns | ||
| ------- | ||
| coords : ndarray | ||
| Shape (k, 3). | ||
| DESC computational coordinates [ρ, θ, ζ]. | ||
| info : tuple | ||
| 2 element tuple containing residuals and number of iterations for each | ||
| point. Only returned if ``full_output`` is True. | ||
|
|
||
| """ | ||
| return eq.compute_theta_coords( | ||
| flux_coords, L_lmn, tol, maxiter, full_output, **kwargs | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If zeta is the generalized toroidal angle, don't we assume zeta = phi + Omega where phi is the cylindrical toroidal angle?
So shouldn't theta_PEST = alpha + iota * (zeta - omega)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Phi = zeta + omega, and theta_PEST = theta + lambda. The left hand side of these relations are defined quantities. Phi is the cylindrical toroidal angle and theta_PEST is the angle where the field lines are straight in the (theta_PEST, Phi) plane. When we mention generalizing angles, we refer to changing the meaning of the angles "zeta" and "theta". These relations must still hold and hence the stream functions must negate the change in zeta and theta.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
alpha = theta_PEST - iota phi