Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
7958cf5
Add new helper functions for curve resampling and Fourier fitting in …
rogeriojorge Sep 6, 2025
a8ceb71
Refactor curve resampling functions in coils.py for improved clarity …
rogeriojorge Sep 6, 2025
7c5ae51
Update coil parameters and environment settings in coils_from_nearaxi…
rogeriojorge Sep 6, 2025
84905e7
Update coil parameters and fitting logic in coils_from_nearaxis.py fo…
rogeriojorge Sep 6, 2025
7585cd7
Started working on phiboozer instead of phi
rogeriojorge Sep 8, 2025
3f98fff
Same
rogeriojorge Sep 8, 2025
ad1d5b1
Update near_axis class to include additional parameters and enhance b…
rogeriojorge Sep 9, 2025
c518db6
Computing phi off axis from varphi off axis and phi on axis
rogeriojorge Sep 10, 2025
caefce7
Add new script for Boozer transformation and visualization; include o…
rogeriojorge Sep 10, 2025
22481dc
Add parallelization support and environment variable setup for XLA in…
rogeriojorge Sep 11, 2025
9794642
Refactor coils_from_BOOZXFORM.py to enhance surface plotting and fiel…
rogeriojorge Sep 11, 2025
857d69a
Remove obsolete low-resolution output file for Boozer transformation.
rogeriojorge Sep 11, 2025
3c699db
Refactor dynamics.py to use NoProgressMeter for progress tracking and…
rogeriojorge Sep 12, 2025
0510d4b
Refactor dynamics.py to switch progress meter to TqdmProgressMeter; u…
rogeriojorge Sep 13, 2025
7c85737
Add axis labels to 2D coil plot for improved clarity
rogeriojorge Sep 13, 2025
8d1c013
Refactor coils_from_BOOZ_XFORM.py and coils_from_nearaxis.py to enhan…
rogeriojorge Sep 21, 2025
ac08848
Fixed sharding with random keys and coils calling curvature
rogeriojorge Sep 21, 2025
4df878f
Updated examples
eduardolneto Sep 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 101 additions & 3 deletions essos/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,8 @@ def RotatedCurve(curve, phi, flip):
if flip:
rotmat = rotmat @ jnp.array(
[[1, 0, 0],
[0, -1, 0],
[0, 0, -1]])
[0, -1, 0],
[0, 0, -1]])
return curve @ rotmat

@partial(jit, static_argnames=['nfp', 'stellsym'])
Expand Down Expand Up @@ -559,4 +559,102 @@ def apply_symmetries_to_currents(base_currents, nfp, stellsym):
for i in range(len(base_currents)):
current = -base_currents[i] if flip else base_currents[i]
currents.append(current)
return jnp.array(currents)
return jnp.array(currents)

def _resample_closed_curve_uniform_one(g: jnp.ndarray, n_segments: int) -> jnp.ndarray:
"""
One-curve arclength resample to n_segments points on t∈[0,1), piecewise linear.
g: (M,3) closed curve (first≈last not required; we close internally).
Returns: (n_segments,3)
"""
# Close the loop
g0 = g[0:1, :]
g_ext = jnp.concatenate([g, g0], axis=0) # (M+1,3)
seg = g_ext[1:] - g_ext[:-1] # (M,3)
seg_len = jnp.linalg.norm(seg, axis=1) # (M,)
cum = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(seg_len)], axis=0) # (M+1,)
total = cum[-1]
# Uniform targets in arclength (exclude total to avoid duplicate)
s_targets = jnp.linspace(0.0, total, n_segments, endpoint=False) # (n_segments,)
# For each s_t, find i with cum[i] <= s_t < cum[i+1]
idx = jnp.searchsorted(cum, s_targets, side='right') - 1 # (n_segments,)
idx = jnp.clip(idx, 0, seg.shape[0]-1)
s0 = cum[idx]
s1 = cum[idx+1]
w = (s_targets - s0) / jnp.maximum(s1 - s0, 1e-20) # (n_segments,)
p0 = g_ext[idx]
p1 = g_ext[idx+1]
return p0 + w[:, None] * (p1 - p0) # (n_segments,3)

def _resample_closed_curve_uniform_batch(gammas: jnp.ndarray, n_segments: int) -> jnp.ndarray:
"""
Batch arclength resample.
gammas: (Ncoils, M, 3) (all curves same M; if not, pre-interp in index space).
Returns: (Ncoils, n_segments, 3)
"""
return vmap(_resample_closed_curve_uniform_one, in_axes=(0, None))(gammas, n_segments)

@partial(jit, static_argnames=('order',))
def _fit_real_fourier_batch(gamma_uni: jnp.ndarray, order: int) -> jnp.ndarray:
"""
gamma_uni: (Ncoils, Nseg, 3), samples at t_j = j/Nseg, j=0..Nseg-1
Returns dofs: (Ncoils, 3, 2*order+1) with [a0, sin1, cos1, ..., sinK, cosK].
"""
Ncoils, Nseg, _ = gamma_uni.shape # Nseg is static if n_segments was static upstream
Kmax = min(order, Nseg // 2) # <-- Python int (static)

g = jnp.transpose(gamma_uni, (0, 2, 1)) # (Ncoils, 3, Nseg)
F = jnp.fft.rfft(g, axis=-1) / Nseg # (Ncoils, 3, Nseg//2 + 1)

a0 = F[..., 0].real # (Ncoils, 3)

# Static slice (OK under jit)
Fk = F[..., 1:1 + Kmax] # (Ncoils, 3, Kmax)

cos_k = 2.0 * Fk.real # (Ncoils, 3, Kmax)
sin_k = -2.0 * Fk.imag # (Ncoils, 3, Kmax)

# Pad to 'order' if needed (pad width is also static here)
if Kmax < order:
pad = order - Kmax
zshape = (cos_k.shape[0], cos_k.shape[1], pad)
z = jnp.zeros(zshape, dtype=gamma_uni.dtype)
cos_k = jnp.concatenate([cos_k, z], axis=-1) # (Ncoils, 3, order)
sin_k = jnp.concatenate([sin_k, z], axis=-1) # (Ncoils, 3, order)

inter = jnp.empty((Ncoils, 3, 2*order), dtype=gamma_uni.dtype)
inter = inter.at[..., 0::2].set(sin_k) # sin₁, sin₂, ...
inter = inter.at[..., 1::2].set(cos_k) # cos₁, cos₂, ...

dofs = jnp.concatenate([a0[..., None], inter], axis=-1) # (Ncoils, 3, 2*order+1)
return dofs

@partial(jit, static_argnames=('order','n_segments','assume_uniform'))
def fit_dofs_from_coils(
coils_gamma: jnp.ndarray,
order: int,
n_segments: int,
assume_uniform: bool = False,
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""
Fast path (batched + JIT + rFFT).
coils_gamma: (Ncoils, M, 3) JAX array. If M != n_segments and assume_uniform=True,
curves are uniformly subsampled in index space. If assume_uniform=False,
do arclength resampling (slower but accurate).
Returns:
dofs: (Ncoils, 3, 2*order+1)
gamma_resampled: (Ncoils, n_segments, 3)
"""
Ncoils, M, _ = coils_gamma.shape
if assume_uniform:
if M == n_segments:
gamma_uni = coils_gamma
else:
# uniform subsampling in index space (fast)
idx = jnp.floor(jnp.linspace(0, M, n_segments, endpoint=False)).astype(int) % M
gamma_uni = coils_gamma[:, idx, :]
else:
gamma_uni = _resample_closed_curve_uniform_batch(coils_gamma, n_segments) # arclength (vmapped)

dofs = _fit_real_fourier_batch(gamma_uni, order) # rFFT-based fit
return dofs, gamma_uni
27 changes: 14 additions & 13 deletions essos/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax import jit, vmap, tree_util, random, lax, device_put
from functools import partial
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5, PIDController, Event, TqdmProgressMeter
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5, PIDController, Event, TqdmProgressMeter, NoProgressMeter
from diffrax import ControlTerm,UnsafeBrownianPath,MultiTerm,ItoMilstein,ClipStepSizeController #For collisions we need this to solve stochastic differential equation
import diffrax
from essos.coils import Coils
Expand Down Expand Up @@ -501,6 +501,7 @@ def __init__(self, trajectories_input=None, initial_conditions=None, times_to_tr
self.particles = particles
self.species=species
self.tag_gc=tag_gc
self.progress_meter = TqdmProgressMeter() # NoProgressMeter() # TqdmProgressMeter()
if condition is None:
self.condition = lambda t, y, args, **kwargs: False
if isinstance(field, Vmec):
Expand Down Expand Up @@ -694,7 +695,7 @@ def update_state(state, _):
#stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.tol_step_size, atol=self.tol_step_size),
max_steps=10000000000,
event = Event(self.condition),
progress_meter=TqdmProgressMeter(),
progress_meter=self.progress_meter,
).ys
elif self.model == 'GuidingCenterCollisionsMuAdaptative':
import warnings
Expand All @@ -720,7 +721,7 @@ def update_state(state, _):
stepsize_controller=ClipStepSizeController(controller=PIDController(pcoeff=0.1, icoeff=0.3, dcoeff=0.0, rtol=self.rtol, atol=self.atol,dtmin=dt0,dtmax=1.e-4,force_dtmin=True),step_ts=self.times,store_rejected_steps=self.rejected_steps),
max_steps=10000000000,
event = Event(self.condition),
progress_meter=TqdmProgressMeter(),
progress_meter=self.progress_meter,
).ys
elif self.model == 'GuidingCenterCollisionsMuFixed':
import warnings
Expand All @@ -744,7 +745,7 @@ def update_state(state, _):
# adjoint=DirectAdjoint(),
max_steps=10000000000,
event = Event(self.condition),
progress_meter=TqdmProgressMeter(),
progress_meter=self.progress_meter,
).ys
elif self.model == 'GuidingCenterCollisionsMuIto':
import warnings
Expand All @@ -768,7 +769,7 @@ def update_state(state, _):
# adjoint=DirectAdjoint(),
max_steps=10000000000,
event = Event(self.condition),
progress_meter=TqdmProgressMeter(),
progress_meter=self.progress_meter,
).ys
elif self.model == 'FullOrbitCollisions':
import warnings
Expand All @@ -794,7 +795,7 @@ def update_state(state, _):
stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.tol_step_size, atol=self.tol_step_size,dtmin=dt0),
max_steps=10000000000,
event = Event(self.condition),
progress_meter=TqdmProgressMeter()
progress_meter=self.progress_meter,
).ys
elif self.model == 'GuidingCenterAdaptative' :
import warnings
Expand All @@ -805,12 +806,12 @@ def update_state(state, _):
t1=self.maxtime,
dt0=self.timestep,#self.maxtime / self.timesteps,
y0=initial_condition,
solver=diffrax.Tsit5(),
solver=diffrax.Dopri8(),
args=self.args,
saveat=SaveAt(ts=self.times),
throw=False,
# adjoint=DirectAdjoint(),
progress_meter=TqdmProgressMeter(),
progress_meter=self.progress_meter,
stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.rtol, atol=self.atol),
max_steps=10000000000,
event = Event(self.condition)
Expand All @@ -824,12 +825,12 @@ def update_state(state, _):
t1=self.maxtime,
dt0=self.timestep,#self.maxtime / self.timesteps,
y0=initial_condition,
solver=diffrax.Tsit5(),
solver=diffrax.Dopri8(),
args=self.args,
saveat=SaveAt(ts=self.times),
throw=False,
# adjoint=DirectAdjoint(),
progress_meter=TqdmProgressMeter(),
progress_meter=self.progress_meter,
stepsize_controller = PIDController(pcoeff=0.4, icoeff=0.3, dcoeff=0, rtol=self.rtol, atol=self.atol),
max_steps=10000000000,
event = Event(self.condition)
Expand All @@ -844,19 +845,19 @@ def update_state(state, _):
t1=self.maxtime,
dt0=self.timestep,#self.maxtime / self.timesteps,
y0=initial_condition,
solver=diffrax.Tsit5(),
solver=diffrax.Dopri8(),
args=self.args,
saveat=SaveAt(ts=self.times),
throw=False,
# adjoint=DirectAdjoint(),
progress_meter=TqdmProgressMeter(),
progress_meter=self.progress_meter,
max_steps=10000000000,
event = Event(self.condition)
).ys
return trajectory

return jit(vmap(compute_trajectory,in_axes=(0,0)), in_shardings=(sharding,sharding_index), out_shardings=sharding)(
device_put(self.initial_conditions, sharding), device_put(self.particles.random_keys if self.particles else None, sharding_index))
device_put(self.initial_conditions, sharding), device_put(self.particles.random_keys if self.particles else None, sharding_index))
#x=jax.device_put(self.initial_conditions, sharding)
#y=jax.device_put(self.particles.random_keys, sharding_index)
#sharded_fun = jax.jit(jax.shard_map(jax.vmap(compute_trajectory,in_axes=(0,0)), mesh=mesh, in_specs=(spec,spec_index), out_specs=spec))
Expand Down
Loading