Skip to content
Draft
200 changes: 149 additions & 51 deletions desc/geometry/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
vmap,
)
from desc.basis import DoubleFourierSeries, ZernikePolynomial
from desc.compute import get_transforms
from desc.grid import Grid, LinearGrid
from desc.io import InputReader
from desc.optimizable import optimizable_parameter
Expand Down Expand Up @@ -666,7 +667,7 @@ def from_shape_parameters(
def constant_offset_surface(
self, offset, grid=None, M=None, N=None, full_output=False
):
"""Create a FourierRZSurface with constant offset from the base surface (self).
"""Create a new FourierRZToroidalSurface with constant offset from self.

Implementation of algorithm described in Appendix B of
"An improved current potential method for fast computation of
Expand All @@ -676,6 +677,10 @@ def constant_offset_surface(
NOTE: Must have the toroidal angle as the cylindrical toroidal angle
in order for this algorithm to work properly

NOTE: if one wants to use this inside of an optimization, one should
use the private method _constant_offset_surface directly, and refer to
the documentation in PR #2016 for more details.

Parameters
----------
base_surface : FourierRZToroidalSurface
Expand All @@ -684,7 +689,7 @@ def constant_offset_surface(
constant offset (in m) of the desired surface from the input surface
offset will be in the normal direction to the surface.
grid : Grid, optional
Grid object of the points on the given surface to evaluate the
Grid object of the points on the offset surface to evaluate the
offset points at, from which the offset surface will be created by fitting
offset points with the basis defined by the given M and N.
If None, defaults to a LinearGrid with M and N and NFP equal to the
Expand Down Expand Up @@ -717,6 +722,8 @@ def constant_offset_surface(
coordinates on the offset surface, corresponding to the
``x`` points on the base_surface (i.e. the points to which the
offset surface was fit)
as well as the transforms bases used to fit R and Z.
Only returned if ``full_output`` is True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docs of the grid can also be updated. The default grid resolution is twice the M,N

info : tuple
2 element tuple containing residuals and number of iterations
for each point. Only returned if ``full_output`` is True
Expand All @@ -725,7 +732,7 @@ def constant_offset_surface(
M = check_nonnegint(M, "M")
N = check_nonnegint(N, "N")

base_surface = self
base_surface = self.copy()
if grid is None:
grid = LinearGrid(
M=base_surface.M * 2,
Expand All @@ -738,57 +745,23 @@ def constant_offset_surface(
), "base_surface must be a FourierRZToroidalSurface!"
M = base_surface.M if M is None else int(M)
N = base_surface.N if N is None else int(N)
base_surface.change_resolution(M=M, N=N)

def n_and_r_jax(nodes):
data = base_surface.compute(
["X", "Y", "Z", "n_rho"],
grid=Grid(nodes, jitable=True, sort=False),
method="jitable",
)

phi = nodes[:, 2]
re = jnp.vstack([data["X"], data["Y"], data["Z"]]).T
n = data["n_rho"]
n = rpz2xyz_vec(n, phi=phi)
r_offset = re + offset * n
return n, re, r_offset

def fun_jax(zeta_hat, theta, zeta):
nodes = jnp.vstack((jnp.ones_like(theta), theta, zeta_hat)).T
n, r, r_offset = n_and_r_jax(nodes)
return jnp.arctan(r_offset[0, 1] / r_offset[0, 0]) - zeta

vecroot = jit(
vmap(
lambda x0, *p: root_scalar(
fun_jax, x0, jac=None, args=p, full_output=full_output
)
)
R_lmn, Z_lmn, data, (res, niter) = _constant_offset_surface(
base_surface,
offset,
grid=grid,
)
if full_output:
zetas, (res, niter) = vecroot(
grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]
)
else:
zetas = vecroot(grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2])

zetas = np.asarray(zetas)
nodes = np.vstack((np.ones_like(grid.nodes[:, 1]), grid.nodes[:, 1], zetas)).T
n, x, x_offsets = n_and_r_jax(nodes)

data = {}
data["n"] = xyz2rpz_vec(n, phi=nodes[:, 1])
data["x"] = xyz2rpz(x)
data["x_offset_surface"] = xyz2rpz(x_offsets)

offset_surface = FourierRZToroidalSurface.from_values(
data["x_offset_surface"],
theta=nodes[:, 1],
M=M,
N=N,
NFP=base_surface.NFP,
sym=base_surface.sym,

offset_surface = FourierRZToroidalSurface(
R_lmn,
Z_lmn,
data["transforms"]["R"].basis.modes[:, 1:],
data["transforms"]["Z"].basis.modes[:, 1:],
base_surface.NFP,
base_surface.sym,
)

if full_output:
return offset_surface, data, (res, niter)
else:
Expand Down Expand Up @@ -1205,3 +1178,128 @@ def _get_ess_scale(self, alpha=1.2, order=np.inf, min_value=1e-7):
scales.update(get_ess_scale(modes, alpha, order, min_value))

return scales


def _constant_offset_surface(
base_surface,
offset,
grid,
transforms=None,
params=None,
):
"""Create a FourierRZToroidalSurface with constant offset from the base surface.

Implementation of algorithm described in Appendix B of
"An improved current potential method for fast computation of
stellarator coil shapes", Landreman (2017)
https://iopscience.iop.org/article/10.1088/1741-4326/aa57d4

NOTE: Must have the toroidal angle as the cylindrical toroidal angle
in order for this algorithm to work properly

NOTE: this function lacks the checks of the constant_offset_surface
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only is differentiable though if params is passed, otherwise the end will not be differentiatedcorrectlt

so that it is jittable/differentiable

Parameters
----------
base_surface : FourierRZToroidalSurface
Surface from which the constant offset surface will be found.
offset : float
constant offset (in m) of the desired surface from the input surface
offset will be in the normal direction to the surface.
grid : Grid, optional
Grid object of the points on the offset surface to evaluate the
offset points at, from which the offset surface will be created by fitting
offset points with the basis defined by the given M and N.
If None, defaults to a LinearGrid with M and N and NFP equal to the
base_surface.M and base_surface.N and base_surface.NFP
transforms: dict, optional
Transforms to use to fit the offset surface's R and Z, respectively. If None,
new transforms will be created using the given surface's M and N.
If given, should contain the keys ["R"] and ["Z"], with the pinv matrices
already built, and the corresponding grid should match the input grid.
params : dict, optional
dictionary of parameters to use when computing data from the base_surface.
If None, uses base_surface.params_dict, however the resulting computation
will not be differentiable with respect to the base_surface parameters
(since the JAX AD inside of an objective traces the params dictionaries
that are passedto their compute methods)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
that are passedto their compute methods)
that are passed to their compute methods)


Returns
-------
R_lmn, Z_lmn : array-like
coefficients describing the offset surface geometry
data : dict
dictionary containing the following data, in the cylindrical basis:
``n`` : (``grid.num_nodes`` x 3) array of the unit surface normal on
the base_surface evaluated at the input ``grid``
``x`` : (``grid.num_nodes`` x 3) array of coordinates on
the base_surface evaluated at the input ``grid``
``x_offset_surface`` : (``grid.num_nodes`` x 3) array of the
coordinates on the offset surface, corresponding to the
``x`` points on the base_surface (i.e. the points to which the
offset surface was fit)
as well as the transforms used to fit R and Z.
info : tuple
2 element tuple containing residuals and number of iterations
for each point.

"""
if params is None:
params = base_surface.params_dict

def n_and_r_jax(nodes):
data = base_surface.compute(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we sure this is differentiable?

["X", "Y", "Z", "n_rho"],
grid=Grid(nodes, jitable=True, sort=False),
method="jitable",
params=params,
)

phi = nodes[:, 2]
re = jnp.vstack([data["X"], data["Y"], data["Z"]]).T
n = data["n_rho"]
n = rpz2xyz_vec(n, phi=phi)
r_offset = re + offset * n
return n, re, r_offset

def fun_jax(zeta_hat, theta, zeta):
nodes = jnp.vstack((jnp.ones_like(theta), theta, zeta_hat)).T
n, r, r_offset = n_and_r_jax(nodes)
return jnp.arctan(r_offset[0, 1] / r_offset[0, 0]) - zeta

vecroot = jit(
vmap(
lambda x0, *p: root_scalar(
fun_jax, x0, jac=None, args=p, full_output=True, tol=1e-12
)
)
)
zetas, (res, niter) = vecroot(grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2])

zetas = jnp.asarray(zetas)
nodes = jnp.vstack((jnp.ones_like(grid.nodes[:, 1]), grid.nodes[:, 1], zetas)).T
n, x, x_offsets = n_and_r_jax(nodes)

data = {}
data["n"] = xyz2rpz_vec(n, phi=nodes[:, 1])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nodes[:, 1] is theta, isn't it?

data["x"] = xyz2rpz(x)
data["x_offset_surface"] = xyz2rpz(x_offsets)

if transforms is None:
# NOTE: we are assuming here that the rootfind was successful for every point,
# so that the zeta=arctan(y/x) of the offset surface point are the same as
# the grid nodes' zeta values. If this is not the case, the fitting
# will be incorrect.
transforms = get_transforms(
obj=base_surface, keys=["R", "Z"], grid=grid, jitable=True
)
transforms["R"].build_pinv()
transforms["Z"].build_pinv()

R_lmn = transforms["R"].fit(data["x_offset_surface"][:, 0])
Z_lmn = transforms["Z"].fit(data["x_offset_surface"][:, 2])

data["transforms"] = transforms

return R_lmn, Z_lmn, data, (res, niter)
Loading