Skip to content
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@
"wadler_lindig>=0.1.6",
"xmmutablemap>=0.1",
"zeroth>=1.0",
"spexial @ git+https://github.com/JAXtronomy/spexial.git@main",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unfortunately, spexial is unlikely to be released. Can you move the function into potential/scf/?

"hypothesis>=6.135.14",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a test-time dependency. It can be added by uv add --group test hypothesis.

"gala>=1.9.1",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gala is already a test-time dependency. :)

]

[project.optional-dependencies]
Expand Down
4 changes: 2 additions & 2 deletions src/galax/potential/_src/params/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __call__(

Parameters
----------
t : `~galax.typing.BBtQuSz0`
t : `~galax._custom_types.BBtQuSz0`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True. We should fix this. Probably my just not having the types in the docstring. The actual annotations are authoritative 🤷 .

Time(s) at which to compute the parameter value.
ustrip : Unit | None
Unit to strip from the parameter value.
Expand Down Expand Up @@ -62,7 +62,7 @@ def __call__(

Parameters
----------
t : `~galax.typing.BBtQuSz0`
t : `~galax._custom_types.BBtQuSz0`
The time(s) at which to compute the parameter value.
ustrip : Unit | None
The unit to strip from the parameter value. If None, the
Expand Down
2 changes: 1 addition & 1 deletion src/galax/potential/_src/params/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __call__(

Parameters
----------
t : `~galax.typing.BBtQuSz0`, optional
t : `~galax._custom_types.BBtQuSz0`, optional
This is ignored and is thus optional. Note that for most
:class:`~galax.potential.AbstractParameter` the time is required.
ustrip : Unit | None
Expand Down
11 changes: 11 additions & 0 deletions src/galax/potential/_src/scf/__init__.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For static analyses' sake we prefer the long-form exports. Annoying, I know, but it makes mypy happier.

Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from . import bfe, bfe_helper, coeffs, coeffs_helper
from .bfe import *
from .bfe_helper import *
from .coeffs import *
from .coeffs_helper import *

__all__: list[str] = []
__all__ += bfe.__all__
__all__ += bfe_helper.__all__
__all__ += coeffs.__all__
__all__ += coeffs_helper.__all__
189 changes: 189 additions & 0 deletions src/galax/potential/_src/scf/bfe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
"""Self-Consistent Field Potential."""

__all__ = ["SCFPotential", "STnlmSnapshotParameter"]

from collections.abc import Callable
from functools import partial
from typing import Any

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

import galax._custom_types as gt
from .bfe_helper import phi_nl_vec, rho_nl as calculate_rho_nl
from .coeffs import compute_coeffs_discrete
from .utils import cartesian_to_spherical, real_Ylm
from galax.potential import AbstractPotential
from galax.potential._src.params.base import AbstractParameter
from galax.potential._src.params.field import ParameterField

##############################################################################


class SCFPotential(AbstractPotential):
r"""Self-Consistent Field (SCF) potential.

A gravitational potential represented as a basis function expansion. This
uses the self-consistent field (SCF) method of Hernquist & Ostriker (1992)
and Lowing et al. (2011), and represents all coefficients as real
quantities.

Parameters
----------
m : numeric
Scale mass.
r_s : numeric
Scale length.
Snlm : Array[float, (nmax+1, lmax+1, lmax+1)] | Callable
Array of coefficients for the cos() terms of the expansion. This should
be a 3D array with shape `(nmax+1, lmax+1, lmax+1)`, where `nmax` is the
number of radial expansion terms and `lmax` is the number of spherical
harmonic `l` terms. If a callable is provided, it should accept a
single argument `t` and return the array of coefficients for that time.
Tnlm : Array[float, (nmax+1, lmax+1, lmax+1)] | Callable
Array of coefficients for the sin() terms of the expansion. This should
be a 3D array with shape `(nmax+1, lmax+1, lmax+1)`, where `nmax` is the
number of radial expansion terms and `lmax` is the number of spherical
harmonic `l` terms. If a callable is provided, it should accept a
single argument `t` and return the array of coefficients for that time.
units : iterable
Unique list of non-reducable units that specify (at minimum) the length,
mass, time, and angle units.
"""

m: AbstractParameter = ParameterField(dimensions="mass")
r_s: AbstractParameter = ParameterField(dimensions="length")
Snlm: AbstractParameter = ParameterField(dimensions="dimensionless")
Tnlm: AbstractParameter = ParameterField(dimensions="dimensionless")

nmax: int = eqx.field(init=False, static=True, repr=False)
lmax: int = eqx.field(init=False, static=True, repr=False)

def __post_init__(self) -> None:
super().__post_init__()

# shape parameters
shape = self.Snlm(0).shape
object.__setattr__(self, "nmax", shape[0] - 1)
object.__setattr__(self, "lmax", shape[1] - 1)

# ==========================================================================

@partial(jax.jit, inline=True)
def _potential(
self, xyz: gt.BtQuSz3, t: gt.BtQuSz0, /
) -> gt.SzN | gt.FloatSz0:
r_s = self.r_s(t)
r, theta, phi = cartesian_to_spherical(xyz).T

s = jnp.atleast_1d(r / r_s) # ([n],[l],[m],[N])
theta = jnp.atleast_1d(theta)[None, None, None] # ([n],[l],[m],[N])
phi = jnp.atleast_1d(phi)[None, None, None] # ([n],[l],[m],[N])

ns = jnp.arange(self.nmax + 1)[:, None, None] # (n, [l], [m])
ls = jnp.arange(self.lmax + 1)[None, :, None] # ([n], l, [m])
phi_nl = phi_nl_vec(s, ns, ls) # (n, l, [m], N)

li, mi = jnp.tril_indices(self.lmax + 1) # (l*(l+1)//2,)
shape = (1, self.lmax + 1, self.lmax + 1, 1) # ([n], l, m, [N])
midx = jnp.zeros(shape, dtype=int).at[:, li, mi, 0].set(mi) # ([n], l, m, [N])

Ylm = jnp.zeros(shape[:-1] + (len(s),))
Ylm = Ylm.at[0, li, mi, :].set(
real_Ylm(theta[:, 0, 0, :], li[..., None], mi[..., None])
)

Snlm = self.Snlm(t, r_s=r_s)[..., None]
Tnlm = self.Tnlm(t, r_s=r_s)[..., None]

out = (self._G * self.m(t) / r_s) * jnp.sum(
Ylm * phi_nl * (Snlm * jnp.cos(midx * phi) + Tnlm * jnp.sin(midx * phi)),
axis=(0, 1, 2),
)
return out[0] if len(xyz.shape) == 1 else out

@partial(jax.jit, inline=True)
@eqx.filter_vmap(in_axes=(None, 1, None)) # type: ignore[misc] # on `q` axis 1
def _density(self, q: gt.QuSz3, /, t: gt.QuSz0) -> Float[Array, "N"]: # type: ignore[name-defined]
"""Compute the density at the given position(s)."""
r, theta, phi = cartesian_to_spherical(q)
r_s = self.r_s(t)
s = jnp.atleast_1d(r / r_s)[:, None, None, None]
theta = jnp.atleast_1d(theta)[:, None, None, None]
phi = jnp.atleast_1d(phi)[:, None, None, None]

ns = jnp.arange(self.nmax + 1)[:, None, None] # (n, [l], [m])
ls = jnp.arange(self.lmax + 1)[None, :, None] # ([n], l, [m])

phi_nl = calculate_rho_nl(s, ns[None], ls[None])

li, mi = jnp.tril_indices(self.lmax + 1) # (l*(l+1)//2,)
shape = (1, 1, self.lmax + 1, self.lmax + 1)
midx = jnp.zeros(shape, dtype=int).at[:, :, li, mi].set(mi)
Ylm = jnp.zeros((len(theta), 1, self.lmax + 1, self.lmax + 1))
Ylm = Ylm.at[:, li, mi, :].set(real_Ylm(li[None], mi[None], theta[:, :, 0, 0]))

Snlm = self.Snlm(t, r_s=r_s)[None]
Tnlm = self.Tnlm(t, r_s=r_s)[None]

out = (self._G * self.m(t) / r_s) * jnp.sum(
Ylm * phi_nl * (Snlm * jnp.cos(midx * phi) + Tnlm * jnp.sin(midx * phi)),
axis=(1, 2, 3),
)
return out[0] if len(q.shape) == 1 else out


# =============================================================================


class STnlmSnapshotParameter(AbstractParameter): # type: ignore[misc]
"""Parameter for the STnlm coefficients."""

snapshot: Callable[ # type: ignore[name-defined]
[Float[Array, "N"]],
tuple[Float[Array, "3 N"], Float[Array, "N"]],
]
"""Cartesian coordinates of the snapshot.

This should be a callable that accepts a single argument `t` and returns
the cartesian coordinates and the masses of the snapshot at that time.
"""

nmax: int = eqx.field(static=True, converter=int)
"""Radial expansion term."""

lmax: int = eqx.field(static=True, converter=int)
"""Spherical harmonic term."""

def __call__(
self, t: gt.QuSz0, *, r_s: gt.QuSz0, **_: Any
) -> tuple[
Float[Array, "{self.nmax}+1 {self.lmax}+1 {self.lmax}+1"],
Float[Array, "{self.nmax}+1 {self.lmax}+1 {self.lmax}+1"],
]:
"""Return the coefficients at the given time(s).

TODO: are the types correct here? Should they be quantity specific?
Parameters
----------
t : float | Array[float, ()]
Time at which to evaluate the coefficients.
r_s : float | Array[float, ()]
Scale length of the potential at the given time(s.
**kwargs : Any
Additional keyword arguments are ignored.

Returns
-------
Snlm : Array[float, (nmax+1, lmax+1, lmax+1)]
The value of the cosine expansion coefficient.
Tnlm : Array[float, (nmax+1, lmax+1, lmax+1)]
The value of the sine expansion coefficient.
"""
xyz, m = self.snapshot(t)
coeffs: tuple[Array, Array] = compute_coeffs_discrete(
xyz, m, nmax=self.nmax, lmax=self.lmax, r_s=r_s
)
return coeffs
81 changes: 81 additions & 0 deletions src/galax/potential/_src/scf/bfe_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Self-Consistent Field Potential."""

__all__: list[str] = []

from functools import partial

import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from .coeffs_helper import normalization_Knl
from .utils import psi_of_r
from spexial import eval_gegenbauers
import galax._custom_types as gt


def rho_nl(n: gt.IntSz0, l: gt.IntSz0, s: gt.FloatSz0,
) -> gt.FloatSz0:
r"""Radial density expansion terms.

Parameters
----------
s : Array[float, (n,)]
Scaled radius :math:`r/r_s`.
n : int
Radial expansion term.
l : int
Spherical harmonic term.

Returns
-------
Array[float, (n,)]
"""
return (
jnp.sqrt(4 * jnp.pi)
* (normalization_Knl(n=n, l=l) / (2 * jnp.pi))
* (s**l / (s * (1 + s) ** (2 * l + 3)))
* eval_gegenbauers(n, 2 * l + 1.5, psi_of_r(s))
)

rho_nl_jit_vec = jax.jit(
jax.vmap( jax.vmap(rho_nl, in_axes=(None, 0, None),), in_axes=(None, None, 0)), static_argnames="n"
)

# ======================================================================


def phi_nl(n: gt.IntSz0, l: gt.IntSz0, s: gt.FloatSz0,
) -> gt.FloatSz0:
r"""Angular density expansion terms.

Parameters
----------
n : int
Max Radial expansion term.
l : int
Spherical harmonic term.
s : Float
Scaled radius :math:`r/r_s`.

Returns
-------
Array[float, (n + 1,)]

Examples
--------
>>> import jax.numpy as jnp
>>> phi_nl(0.5, 1, 1)
Array(0.5, dtype=float32)
>>> phi_nl(jnp.array([0.5, 0.5]), 1, 1)
Array([0.5, 0.5], dtype=float32)
"""
return (
-jnp.sqrt(4 * jnp.pi)
* (s**l / (1.0 + s) ** (2 * l + 1))
* eval_gegenbauers(n, 2 * l + 1.5, psi_of_r(s))
)

phi_nl_jit_vec = jax.jit(
jax.vmap( jax.vmap(phi_nl, in_axes=(None, 0, None),), in_axes=(None, None, 0)), static_argnames="n"
)
Loading
Loading