Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 251 additions & 2 deletions desc/magnetic_fields/_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Classes for magnetic fields."""

import functools
import warnings
from abc import ABC, abstractmethod
from collections.abc import MutableSequence
Expand All @@ -15,9 +16,9 @@
)
from interpax import approx_df, interp1d, interp2d, interp3d
from netCDF4 import Dataset, chartostring, stringtochar
from scipy.constants import mu_0
from scipy.constants import mu_0, physical_constants

from desc.backend import jit, jnp, sign
from desc.backend import jit, jnp, sign, vmap
from desc.basis import (
ChebyshevDoubleFourierBasis,
ChebyshevPolynomial,
Expand All @@ -36,6 +37,7 @@
from desc.transform import Transform
from desc.utils import (
copy_coeffs,
cross,
dot,
errorif,
flatten_list,
Expand Down Expand Up @@ -2654,6 +2656,253 @@
return r, z


def trace_particles(
x0,
lambda0,
ts,
field,
m=4,
q=2,
E=3.52e6,
gyrophase=0.0,
mode="gc-vac",
params=None,
source_grid=None,
basis="rpz",
rtol=1e-8,
atol=1e-8,
maxstep=1000,
min_step_size=1e-8,
solver=Tsit5(),
bounds_R=(0, np.inf),
bounds_Z=(-np.inf, np.inf),
**kwargs,
):
"""Trace particles in external magnetic field.

Parameters
----------
x0 : array-like, shape(num_particles,3)
Initial starting coordinates for r, phi, z (if basis="rpz")
or x, y, z (if basis="xyz")
lambda0 : array-like, shape(num_particles,)
Initial value for pitch angle parameter λ = v²_⊥ / v²
ts : array-like
Strictly increasing array of times where output is desired.
field : MagneticField
Source of magnetic field to integrate
m : float or array-like, shape(num_particles,)
Mass of particles, in units of proton masses
q : float or array-like, shape(num_particles,)
Charge of particles, in units of elementary charge.
E : float or array-like, shape(num_particles,)
Kinetic energy of particles, in eV.
gyrophase : float or array-like, shape(num_particles,)
Initial gyrophase of particles in [0, 2pi]. Only used if ``mode="full-orbit"``
mode : {"gc-vac", "full-orbit"}
Set of equations to solve.
params: dict, optional
Parameters passed to field
source_grid : Grid, optional
Grid to use to discretize field
basis : {"rpz", "xyz"}
Whether to use cylindrical or cartesian coordinates.
rtol, atol : float
relative and absolute tolerances for ode integration
maxstep : int
maximum number of steps between different output times
min_step_size: float
minimum step size (in t) that the integration can take. default is 1e-8
solver: diffrax.Solver
diffrax Solver object to use in integration,
defaults to Tsit5(), a RK45 explicit solver
bounds_R : tuple of (float,float), optional
R bounds for bounding box. Trajectories that leave this box will be stopped,
and NaN returned for points outside the box. Defaults to (0,np.inf)
bounds_Z : tuple of (float,float), optional
Z bounds for bounding box. Trajectories that leave this box will be stopped,
and NaN returned for points outside the box. Defaults to (-np.inf,np.inf)
kwargs: dict
keyword arguments to be passed into the ``diffrax.diffeqsolve``

Returns
-------
x : ndarray, shape(num_particles, num_timesteps, 3)
Position of each particle at each requested time, in
either r,phi,z or x,y,z depending on basis argument.
v : ndarray
Velocity of each particle at specified times. For ``mode="gc-vac"`` this is
the parallel velocity of shape shape(num_particles, num_timesteps, 1), for
``mode="full-orbit"`` this is the velocity vector in whichever basis was
specified, of shape(num_particles, num_timesteps, 3).

"""
errorif(

Check warning on line 2740 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2740

Added line #L2740 was not covered by tests
mode not in ["gc-vac", "full-orbit"],
ValueError,
f"mode should be one of 'gc-vac' or 'full-orbit', got {mode}",
)
x0, lambda0, m, q, E = map(jnp.asarray, (x0, lambda0, m, q, E))
n_particles = x0.shape[0]
lambda0, m, q, E, gyrophase = map(

Check warning on line 2747 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2745-L2747

Added lines #L2745 - L2747 were not covered by tests
lambda x: jnp.broadcast_to(x, n_particles), (lambda0, m, q, E, gyrophase)
)

@jit
def field_compute(x):
return field.compute_magnetic_field(

Check warning on line 2753 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2751-L2753

Added lines #L2751 - L2753 were not covered by tests
jnp.atleast_2d(x), params=params, basis=basis, grid=source_grid
).squeeze()

m *= physical_constants["proton mass"][0]
q *= physical_constants["elementary charge"][0]
E *= physical_constants["electron volt"][0]

Check warning on line 2759 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2757-L2759

Added lines #L2757 - L2759 were not covered by tests

modv0 = jnp.sqrt(2 * E / m) # speed |v|
vperp0 = jnp.sqrt(lambda0) * modv0
vpar0 = jnp.sqrt(modv0**2 - vperp0**2)
B = field_compute(x0)

Check warning on line 2764 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2761-L2764

Added lines #L2761 - L2764 were not covered by tests

if mode == "gc-vac":
y0 = jnp.hstack([x0, vpar0[:, None]])
modB = jnp.linalg.norm(B, axis=-1)
mu = vperp0**2 / (2 * modB)
args = (m / q, mu)
odefun = functools.partial(

Check warning on line 2771 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2766-L2771

Added lines #L2766 - L2771 were not covered by tests
_guiding_center_vacuum, field_compute=field_compute, basis=basis
)
elif mode == "full-orbit":
x0, v0 = _full_orbit_ic_from_gc(x0, vpar0, vperp0, B, gyrophase, m, q)
y0 = jnp.hstack([x0, v0])
args = q / m
odefun = functools.partial(_full_orbit, field_compute=field_compute)

Check warning on line 2778 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2774-L2778

Added lines #L2774 - L2778 were not covered by tests

# diffrax parameters

def default_terminating_event_fxn(state, **kwargs):
R_out = jnp.any(jnp.array([state.y[0] < bounds_R[0], state.y[0] > bounds_R[1]]))
Z_out = jnp.any(jnp.array([state.y[2] < bounds_Z[0], state.y[2] > bounds_Z[1]]))
return jnp.any(jnp.array([R_out, Z_out]))

Check warning on line 2785 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2782-L2785

Added lines #L2782 - L2785 were not covered by tests

kwargs.setdefault(

Check warning on line 2787 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2787

Added line #L2787 was not covered by tests
"stepsize_controller", PIDController(rtol=rtol, atol=atol, dtmin=min_step_size)
)
kwargs.setdefault(

Check warning on line 2790 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2790

Added line #L2790 was not covered by tests
"discrete_terminating_event",
DiscreteTerminatingEvent(default_terminating_event_fxn),
)

term = ODETerm(odefun)
saveat = SaveAt(ts=ts)

Check warning on line 2796 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2795-L2796

Added lines #L2795 - L2796 were not covered by tests

intfun = lambda x, args: diffeqsolve(

Check warning on line 2798 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2798

Added line #L2798 was not covered by tests
term,
solver,
y0=x,
t0=ts[0],
t1=ts[-1],
saveat=saveat,
max_steps=maxstep * len(ts),
dt0=min_step_size,
args=args,
**kwargs,
).ys

# suppress warnings till its fixed upstream:
# https://github.com/patrick-kidger/diffrax/issues/445
# also ignore deprecation warning for now until we actually need to deal with it
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="unhashable type")
warnings.filterwarnings("ignore", message="`diffrax.*discrete_terminating")
yt = jit(vmap(intfun))(y0, args)

Check warning on line 2817 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2814-L2817

Added lines #L2814 - L2817 were not covered by tests

yt = jnp.where(jnp.isinf(yt), jnp.nan, yt)

Check warning on line 2819 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2819

Added line #L2819 was not covered by tests

x = yt[:, :, :3]
v = yt[:, :, 3:]

Check warning on line 2822 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2821-L2822

Added lines #L2821 - L2822 were not covered by tests

return x, v

Check warning on line 2824 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2824

Added line #L2824 was not covered by tests


def _guiding_center_vacuum(t, y, args, field_compute, basis):
# this is the one implemented in simsopt for method="gc_vac"
# should be equivalent to full lagrangian from Cary & Brizard in vacuum
m_over_q, mu = args
vpar = y[-1]
x = y[:-1]
B = field_compute(x)
dB = jnp.vectorize(

Check warning on line 2834 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2830-L2834

Added lines #L2830 - L2834 were not covered by tests
Derivative(
field_compute,
mode="fwd",
),
signature="(n)->(n,n)",
)(x).squeeze()

modB = jnp.linalg.norm(B, axis=-1)
b = B / modB
grad_B = jnp.sum(b[:, None] * dB, axis=0)
if basis == "rpz":
g1, g2, g3 = grad_B

Check warning on line 2846 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2842-L2846

Added lines #L2842 - L2846 were not covered by tests
# factor of R from grad in cylindrical coordinates
g2 /= x[0]
grad_B = jnp.array([g1, g2, g3])

Check warning on line 2849 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2848-L2849

Added lines #L2848 - L2849 were not covered by tests

dRdt = vpar * b + (m_over_q / modB**2 * (mu * modB + vpar**2)) * cross(b, grad_B)
if basis == "rpz":
d1, d2, d3 = dRdt
d2 /= x[0]
dRdt = jnp.array([d1, d2, d3])

Check warning on line 2855 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2851-L2855

Added lines #L2851 - L2855 were not covered by tests

dvdt = -mu * dot(b, grad_B)
dxdt = jnp.append(dRdt, dvdt)
return dxdt.flatten()

Check warning on line 2859 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2857-L2859

Added lines #L2857 - L2859 were not covered by tests


def _full_orbit(t, y, args, field_compute, basis):
q_over_m = args[0]
x, v = y[:3], y[3:]
B = field_compute(x)
dx = v
dv = q_over_m * cross(v, B)
return jnp.concatenate([dx, dv]).flatten()

Check warning on line 2868 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2863-L2868

Added lines #L2863 - L2868 were not covered by tests


def _gc_radius(vperp, modB, m, q):
"""Radius of guiding center orbit."""
return m * vperp / (jnp.abs(q) * modB)

Check warning on line 2873 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2873

Added line #L2873 was not covered by tests


def _full_orbit_ic_from_gc(x0, vpar, vperp, B, eta, m, q):
modB = jnp.linalg.norm(B)
b = B / modB

Check warning on line 2878 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2877-L2878

Added lines #L2877 - L2878 were not covered by tests
# heavily borrowed from simsopt
# https://github.com/hiddenSymmetries/simsopt/blob/
# 3362805d306dff96de099da3c576850e1ec603f2/src/simsopt/field/tracing.py#L40

# construct 3 unit vectors, not necessarily orthogonal
# (but at least linearly independent)
p1 = b

Check warning on line 2885 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2885

Added line #L2885 was not covered by tests
# anything other than b
# note this is ok in rpz or xyz, since b shouldn't be purely in R or X directions
# (if it is, something else is probably wrong)
p2 = jnp.array([1, 0, 0])
p3 = -jnp.cross(p1, p2) # some third vector not parallel to p1 or p2
p3 /= jnp.linalg.norm(p3)

Check warning on line 2891 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2889-L2891

Added lines #L2889 - L2891 were not covered by tests
# now do Gram-Schmidt to find orthogonal basis around B
q1 = p1 # b
q2 = p2 - dot(q1, p2) * q1
q2 /= jnp.linalg.norm(q2)
q3 = p3 - dot(q1, p3) * q1 - dot(q2, p3) * q2
q3 /= jnp.linalg.norm(q3)
r = _gc_radius(vperp, modB, m, q)

Check warning on line 2898 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2893-L2898

Added lines #L2893 - L2898 were not covered by tests

# transform from guiding center frame to particle frame
x0 = x0 + r * jnp.sin(eta) * q2 + r * jnp.cos(eta) * q3
v0 = vpar * q1 + vperp * (-jnp.cos(eta) * q2 + jnp.sin(eta) * q3)
return x0, v0

Check warning on line 2903 in desc/magnetic_fields/_core.py

View check run for this annotation

Codecov / codecov/patch

desc/magnetic_fields/_core.py#L2901-L2903

Added lines #L2901 - L2903 were not covered by tests


class OmnigenousField(Optimizable, IOAble):
"""A magnetic field with perfect omnigenity (but is not necessarily analytic).

Expand Down
Loading