Coordinax enables calculations with coordinates in JAX. It is built on equinox and quax, with unit support provided by unxt.
Install with pip:

```bash
pip install coordinax
```

- Specialized quantities: scalar coordinate quantities with units, including `Angle` (directional values on $S^1$ with explicit wrapping) and `Distance` (a length-valued quantity), plus astronomy-facing forms like `Parallax` and `DistanceModulus`.
- Charts: a coordinate chart / component schema (names + physical dimensions). A chart does not store numerical values.
- Representation: the geometric meaning of components, encoded as (geometry, basis, semantics), e.g. `point`, `coord_vel`, `phys_acc`.
- Point: data + chart + representation, with conversion and arithmetic behavior defined by chart transition maps and tangent pushforwards.
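Explicit wrapping means an angle is treated as a point on the circle, so values a full turn apart describe the same direction. The sketch below illustrates the idea in plain Python; the helper `wrap_deg` and its default interval $[-180°, 180°)$ are illustrative assumptions, not coordinax's `Angle` API or its actual wrapping convention.

```python
def wrap_deg(angle_deg: float, low: float = -180.0) -> float:
    """Map an angle in degrees into the half-open interval [low, low + 360).

    Hypothetical helper for illustration; not part of coordinax.
    """
    # Shift so the interval starts at 0, reduce modulo one full turn, shift back.
    return (angle_deg - low) % 360.0 + low


# 370 deg and 10 deg name the same direction on S^1.
print(wrap_deg(370.0))   # 10.0
print(wrap_deg(-190.0))  # 170.0
print(wrap_deg(180.0))   # -180.0 (upper end of the interval is excluded)
```

Choosing a half-open interval makes the wrapped value unique: exactly one representative exists for every direction.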
The most common import is the high-level user API:
```python
import coordinax.main as cx
```

```python
>>> import coordinax.main as cx
>>> import unxt as u
>>> a = cx.Angle(30.0, "deg")
>>> d = cx.Distance(10.0, "kpc")
>>> u.uconvert("rad", a)
Angle(0.52359878, 'rad')
```
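The `Distance` above (10 kpc) connects to `Parallax` and `DistanceModulus` through the standard relations $p\,[\mathrm{arcsec}] = 1 / d\,[\mathrm{pc}]$ and $\mu = 5 \log_{10}(d / 10\,\mathrm{pc})$. A plain-Python sketch of those relations follows; the helper names are hypothetical, and this does not show how coordinax itself converts between these types.

```python
import math


def parallax_mas(distance_pc: float) -> float:
    """Parallax in milliarcseconds for a distance in parsecs (p["] = 1 / d[pc])."""
    return 1000.0 / distance_pc


def distance_modulus(distance_pc: float) -> float:
    """Distance modulus mu = 5 log10(d / 10 pc), in magnitudes."""
    return 5.0 * math.log10(distance_pc / 10.0)


def distance_from_modulus(mu: float) -> float:
    """Invert the distance modulus to a distance in parsecs."""
    return 10.0 ** (mu / 5.0 + 1.0)


d_pc = 10.0 * 1000.0  # 10 kpc, the same distance as cx.Distance(10.0, "kpc")
print(parallax_mas(d_pc))      # 0.1
print(distance_modulus(d_pc))  # 15.0
```

Inverting $\mu = 15$ with `distance_from_modulus` recovers 10 000 pc, so the three quantities carry the same information in different units.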
Transform point coordinates between charts with `pt_map`:

```python
>>> import coordinax.main as cx
>>> import unxt as u
>>> q = {"x": u.Q(1.0, "km"), "y": u.Q(2.0, "km"), "z": u.Q(3.0, "km")}
>>> q_sph = cx.pt_map(q, cx.cart3d, cx.sph3d)
>>> q_sph
{'r': Q(3.74165739, 'km'), 'theta': Q(0.64052231, 'rad'), 'phi': Q(1.10714872, 'rad')}
```

`Point` carries chart + representation metadata, so conversions preserve semantics:
```python
>>> import coordinax.main as cx
>>> vec = cx.Point.from_([1, 2, 3], "m")
>>> print(vec)
<Point: chart=Cart3D (x, y, z) [m]
[1 2 3]>
>>> sph_vec = vec.cconvert(cx.sph3d)
>>> print(sph_vec)
<Point: chart=Spherical3D (r[m], theta[rad], phi[rad])
[3.742 0.641 1.107]>
```

Common representation constants are available from the high-level module:
```python
import coordinax.main as cx

cx.point      # point location data
cx.coord_vel  # coordinate-basis velocity components
cx.phys_vel   # physical-basis velocity components
```

Use the built-in
```python
>>> import jax.numpy as jnp
>>> import coordinax.charts as cxc
>>> import coordinax.manifolds as cxm
>>> import unxt as u
>>> # Unit two-sphere S^2 with its intrinsic round metric
>>> cxm.S2
HyperSphericalManifold(ndim=2)
>>> cxm.S2.metric
RoundMetric(ndim=2)
>>> # At the equator, measure the angle between northward and eastward tangents
>>> at = {"theta": u.Angle(jnp.pi / 2, "rad"), "phi": u.Angle(0.0, "rad")}
>>> u_north = {"theta": u.Angle(1.0, "rad"), "phi": u.Angle(0.0, "rad")}
>>> v_east = {"theta": u.Angle(0.0, "rad"), "phi": u.Angle(1.0, "rad")}
>>> cxm.S2.angle_between(cxc.sph2, u_north, v_east, at=at)
Angle(1.57079633, 'rad')
```

Astronomy frames require the `astro` extra (`pip install "coordinax[astro]"`) or a separate install of the `coordinax-astro` package.
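Returning to the spherical examples above: the values printed by `pt_map` and by `cxm.S2.angle_between` can be reproduced with plain NumPy. This sketch assumes the physics convention (θ measured from +z, φ the azimuth), which matches the printed values, and the standard round metric $g = \mathrm{diag}(1, \sin^2\theta)$ on the unit sphere. The function names are hypothetical and not part of coordinax.

```python
import numpy as np


def cart_to_sph(x, y, z):
    """Cartesian -> spherical, physics convention (hypothetical helper)."""
    r = np.sqrt(x**2 + y**2 + z**2)
    theta = np.arccos(z / r)  # polar angle from +z, in [0, pi]
    phi = np.arctan2(y, x)    # azimuth, in (-pi, pi]
    return r, theta, phi


def round_metric(theta):
    """Metric tensor of the unit 2-sphere in (theta, phi) coordinates."""
    return np.diag([1.0, np.sin(theta) ** 2])


def angle_between(u, v, theta):
    """Angle between tangent vectors u, v (components in (theta, phi)) at polar angle theta."""
    g = round_metric(theta)
    uv = u @ g @ v
    norm = np.sqrt((u @ g @ u) * (v @ g @ v))
    # Clip guards against round-off pushing the cosine slightly outside [-1, 1].
    return np.arccos(np.clip(uv / norm, -1.0, 1.0))


print(cart_to_sph(1.0, 2.0, 3.0))  # approx (3.742, 0.641, 1.107), as in the pt_map output
print(angle_between(np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.pi / 2))  # approx pi/2
```

Off the equator the metric matters: at θ = π/4 the φ direction is shortened by a factor sin θ, so the angle between (1, 1) and (1, 0) is arccos(1/√1.5) rather than the Euclidean π/4.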
`to_frame` composes the full transformation chain automatically. The example below converts from ICRS to the Galactic bar frame, which co-rotates at pattern speed $\Omega_b$. A time-dependent `Rotate` operator captures the rotation, `TransformedReferenceFrame` wraps the base frame with it, and `frame_transition` fuses the resulting ICRS -> GCF -> bar chain on the fly:
```python
>>> import jax.numpy as jnp
>>> import coordinax.main as cx
>>> import coordinax.astro as cxastro
>>> import coordinax.frames as cxf
>>> import coordinax.transforms as cxfm
>>> import unxt as u
>>> # ICRS -> Galactocentric (static: rotate, translate, rotate)
>>> sun = cx.Point.from_([0, 0, 0], "pc", cxastro.ICRS())
>>> print(sun.to_frame(cxastro.Galactocentric()))
<Point: chart=Cart3D (x, y, z) [pc]
[-8121.973 0. 20.8 ]>
>>> # Bar frame co-rotating at Omega_b: Rotate accepts a callable for t-dependence
>>> Omega_b = u.Q(0.0409, "rad/Myr")  # approx 40 km/s/kpc
>>> def R_bar(t):
...     theta = u.ustrip("rad", Omega_b * t)
...     ct, st = jnp.cos(theta), jnp.sin(theta)
...     return jnp.array([[ct, st, 0.0], [-st, ct, 0.0], [0.0, 0.0, 1.0]])
...
>>> bar_frame = cxf.TransformedReferenceFrame(cxastro.Galactocentric(), cxfm.Rotate(R_bar))
>>> # ICRS -> bar: frame_transition fuses all four operators
>>> cx.frame_transition(cxastro.ICRS(), bar_frame)
Composed((...))
>>> # Sun's ICRS-origin position expressed in the bar frame at t = 500 Myr
>>> print(sun.to_frame(bar_frame, t=u.Q(500.0, "Myr")))
<Point: chart=Cart3D (x, y, z) [pc]
[ 240.763 8118.404 20.8 ]>
```

If you find this library useful in academic work, please cite it.
We welcome contributions!
For the local development workflow, see `docs/dev.md`. For pull request expectations, see `docs/contributing.md`.