-
Notifications
You must be signed in to change notification settings - Fork 8
WIP: Add SCF Potential #751
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e93c337
476a29a
0e5c7a7
052d0e9
872003d
93b784f
72053c8
c00a5ed
fab17fa
058052e
f9dc960
92bc71b
3f12e64
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,6 +53,9 @@ | |
| "wadler_lindig>=0.1.6", | ||
| "xmmutablemap>=0.1", | ||
| "zeroth>=1.0", | ||
| "spexial @ git+https://github.com/JAXtronomy/spexial.git@main", | ||
| "hypothesis>=6.135.14", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a test-time dependency. It can be added by |
||
| "gala>=1.9.1", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Gala is already a test-time dependency. :) |
||
| ] | ||
|
|
||
| [project.optional-dependencies] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,7 +26,7 @@ def __call__( | |
|
|
||
| Parameters | ||
| ---------- | ||
| t : `~galax.typing.BBtQuSz0` | ||
| t : `~galax._custom_types.BBtQuSz0` | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. True. We should fix this. Probably my just not having the types in the docstring. The actual annotations are authoritative 🤷 . |
||
| Time(s) at which to compute the parameter value. | ||
| ustrip : Unit | None | ||
| Unit to strip from the parameter value. | ||
|
|
@@ -62,7 +62,7 @@ def __call__( | |
|
|
||
| Parameters | ||
| ---------- | ||
| t : `~galax.typing.BBtQuSz0` | ||
| t : `~galax._custom_types.BBtQuSz0` | ||
| The time(s) at which to compute the parameter value. | ||
| ustrip : Unit | None | ||
| The unit to strip from the parameter value. If None, the | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For static analyses' sake we prefer the long-form exports. Annoying, I know, but it makes mypy happier. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| from . import bfe, bfe_helper, coeffs, coeffs_helper | ||
| from .bfe import * | ||
| from .bfe_helper import * | ||
| from .coeffs import * | ||
| from .coeffs_helper import * | ||
|
|
||
| __all__: list[str] = [] | ||
| __all__ += bfe.__all__ | ||
| __all__ += bfe_helper.__all__ | ||
| __all__ += coeffs.__all__ | ||
| __all__ += coeffs_helper.__all__ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,189 @@ | ||
| """Self-Consistent Field Potential.""" | ||
|
|
||
| __all__ = ["SCFPotential", "STnlmSnapshotParameter"] | ||
|
|
||
| from collections.abc import Callable | ||
| from functools import partial | ||
| from typing import Any | ||
|
|
||
| import equinox as eqx | ||
| import jax | ||
| import jax.numpy as jnp | ||
| from jaxtyping import Array, Float | ||
|
|
||
| import galax._custom_types as gt | ||
| from .bfe_helper import phi_nl_vec, rho_nl as calculate_rho_nl | ||
| from .coeffs import compute_coeffs_discrete | ||
| from .utils import cartesian_to_spherical, real_Ylm | ||
| from galax.potential import AbstractPotential | ||
| from galax.potential._src.params.base import AbstractParameter | ||
| from galax.potential._src.params.field import ParameterField | ||
|
|
||
| ############################################################################## | ||
|
|
||
|
|
||
| class SCFPotential(AbstractPotential): | ||
| r"""Self-Consistent Field (SCF) potential. | ||
|
|
||
| A gravitational potential represented as a basis function expansion. This | ||
| uses the self-consistent field (SCF) method of Hernquist & Ostriker (1992) | ||
| and Lowing et al. (2011), and represents all coefficients as real | ||
| quantities. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| m : numeric | ||
| Scale mass. | ||
| r_s : numeric | ||
| Scale length. | ||
| Snlm : Array[float, (nmax+1, lmax+1, lmax+1)] | Callable | ||
| Array of coefficients for the cos() terms of the expansion. This should | ||
| be a 3D array with shape `(nmax+1, lmax+1, lmax+1)`, where `nmax` is the | ||
| number of radial expansion terms and `lmax` is the number of spherical | ||
| harmonic `l` terms. If a callable is provided, it should accept a | ||
| single argument `t` and return the array of coefficients for that time. | ||
| Tnlm : Array[float, (nmax+1, lmax+1, lmax+1)] | Callable | ||
| Array of coefficients for the sin() terms of the expansion. This should | ||
| be a 3D array with shape `(nmax+1, lmax+1, lmax+1)`, where `nmax` is the | ||
| number of radial expansion terms and `lmax` is the number of spherical | ||
| harmonic `l` terms. If a callable is provided, it should accept a | ||
| single argument `t` and return the array of coefficients for that time. | ||
| units : iterable | ||
| Unique list of non-reducable units that specify (at minimum) the length, | ||
| mass, time, and angle units. | ||
| """ | ||
|
|
||
| m: AbstractParameter = ParameterField(dimensions="mass") | ||
| r_s: AbstractParameter = ParameterField(dimensions="length") | ||
| Snlm: AbstractParameter = ParameterField(dimensions="dimensionless") | ||
| Tnlm: AbstractParameter = ParameterField(dimensions="dimensionless") | ||
|
|
||
| nmax: int = eqx.field(init=False, static=True, repr=False) | ||
| lmax: int = eqx.field(init=False, static=True, repr=False) | ||
|
|
||
| def __post_init__(self) -> None: | ||
| super().__post_init__() | ||
|
|
||
| # shape parameters | ||
| shape = self.Snlm(0).shape | ||
| object.__setattr__(self, "nmax", shape[0] - 1) | ||
| object.__setattr__(self, "lmax", shape[1] - 1) | ||
|
|
||
| # ========================================================================== | ||
|
|
||
| @partial(jax.jit, inline=True) | ||
| def _potential( | ||
| self, xyz: gt.BtQuSz3, t: gt.BtQuSz0, / | ||
| ) -> gt.SzN | gt.FloatSz0: | ||
| r_s = self.r_s(t) | ||
| r, theta, phi = cartesian_to_spherical(xyz).T | ||
|
|
||
| s = jnp.atleast_1d(r / r_s) # ([n],[l],[m],[N]) | ||
| theta = jnp.atleast_1d(theta)[None, None, None] # ([n],[l],[m],[N]) | ||
| phi = jnp.atleast_1d(phi)[None, None, None] # ([n],[l],[m],[N]) | ||
|
|
||
| ns = jnp.arange(self.nmax + 1)[:, None, None] # (n, [l], [m]) | ||
| ls = jnp.arange(self.lmax + 1)[None, :, None] # ([n], l, [m]) | ||
| phi_nl = phi_nl_vec(s, ns, ls) # (n, l, [m], N) | ||
|
|
||
| li, mi = jnp.tril_indices(self.lmax + 1) # (l*(l+1)//2,) | ||
| shape = (1, self.lmax + 1, self.lmax + 1, 1) # ([n], l, m, [N]) | ||
| midx = jnp.zeros(shape, dtype=int).at[:, li, mi, 0].set(mi) # ([n], l, m, [N]) | ||
|
|
||
| Ylm = jnp.zeros(shape[:-1] + (len(s),)) | ||
| Ylm = Ylm.at[0, li, mi, :].set( | ||
| real_Ylm(theta[:, 0, 0, :], li[..., None], mi[..., None]) | ||
| ) | ||
|
|
||
| Snlm = self.Snlm(t, r_s=r_s)[..., None] | ||
| Tnlm = self.Tnlm(t, r_s=r_s)[..., None] | ||
|
|
||
| out = (self._G * self.m(t) / r_s) * jnp.sum( | ||
| Ylm * phi_nl * (Snlm * jnp.cos(midx * phi) + Tnlm * jnp.sin(midx * phi)), | ||
| axis=(0, 1, 2), | ||
| ) | ||
| return out[0] if len(xyz.shape) == 1 else out | ||
|
|
||
| @partial(jax.jit, inline=True) | ||
| @eqx.filter_vmap(in_axes=(None, 1, None)) # type: ignore[misc] # on `q` axis 1 | ||
| def _density(self, q: gt.QuSz3, /, t: gt.QuSz0) -> Float[Array, "N"]: # type: ignore[name-defined] | ||
| """Compute the density at the given position(s).""" | ||
| r, theta, phi = cartesian_to_spherical(q) | ||
| r_s = self.r_s(t) | ||
| s = jnp.atleast_1d(r / r_s)[:, None, None, None] | ||
| theta = jnp.atleast_1d(theta)[:, None, None, None] | ||
| phi = jnp.atleast_1d(phi)[:, None, None, None] | ||
|
|
||
| ns = jnp.arange(self.nmax + 1)[:, None, None] # (n, [l], [m]) | ||
| ls = jnp.arange(self.lmax + 1)[None, :, None] # ([n], l, [m]) | ||
|
|
||
| phi_nl = calculate_rho_nl(s, ns[None], ls[None]) | ||
|
|
||
| li, mi = jnp.tril_indices(self.lmax + 1) # (l*(l+1)//2,) | ||
| shape = (1, 1, self.lmax + 1, self.lmax + 1) | ||
| midx = jnp.zeros(shape, dtype=int).at[:, :, li, mi].set(mi) | ||
| Ylm = jnp.zeros((len(theta), 1, self.lmax + 1, self.lmax + 1)) | ||
| Ylm = Ylm.at[:, li, mi, :].set(real_Ylm(li[None], mi[None], theta[:, :, 0, 0])) | ||
|
|
||
| Snlm = self.Snlm(t, r_s=r_s)[None] | ||
| Tnlm = self.Tnlm(t, r_s=r_s)[None] | ||
|
|
||
| out = (self._G * self.m(t) / r_s) * jnp.sum( | ||
| Ylm * phi_nl * (Snlm * jnp.cos(midx * phi) + Tnlm * jnp.sin(midx * phi)), | ||
| axis=(1, 2, 3), | ||
| ) | ||
| return out[0] if len(q.shape) == 1 else out | ||
|
|
||
|
|
||
| # ============================================================================= | ||
|
|
||
|
|
||
| class STnlmSnapshotParameter(AbstractParameter): # type: ignore[misc] | ||
| """Parameter for the STnlm coefficients.""" | ||
|
|
||
| snapshot: Callable[ # type: ignore[name-defined] | ||
| [Float[Array, "N"]], | ||
| tuple[Float[Array, "3 N"], Float[Array, "N"]], | ||
| ] | ||
| """Cartesian coordinates of the snapshot. | ||
|
|
||
| This should be a callable that accepts a single argument `t` and returns | ||
| the cartesian coordinates and the masses of the snapshot at that time. | ||
| """ | ||
|
|
||
| nmax: int = eqx.field(static=True, converter=int) | ||
| """Radial expansion term.""" | ||
|
|
||
| lmax: int = eqx.field(static=True, converter=int) | ||
| """Spherical harmonic term.""" | ||
|
|
||
| def __call__( | ||
| self, t: gt.QuSz0, *, r_s: gt.QuSz0, **_: Any | ||
| ) -> tuple[ | ||
| Float[Array, "{self.nmax}+1 {self.lmax}+1 {self.lmax}+1"], | ||
| Float[Array, "{self.nmax}+1 {self.lmax}+1 {self.lmax}+1"], | ||
| ]: | ||
| """Return the coefficients at the given time(s). | ||
|
|
||
| TODO: are the types correct here? Should they be quantity specific? | ||
| Parameters | ||
| ---------- | ||
| t : float | Array[float, ()] | ||
| Time at which to evaluate the coefficients. | ||
| r_s : float | Array[float, ()] | ||
| Scale length of the potential at the given time(s. | ||
| **kwargs : Any | ||
| Additional keyword arguments are ignored. | ||
|
|
||
| Returns | ||
| ------- | ||
| Snlm : Array[float, (nmax+1, lmax+1, lmax+1)] | ||
| The value of the cosine expansion coefficient. | ||
| Tnlm : Array[float, (nmax+1, lmax+1, lmax+1)] | ||
| The value of the sine expansion coefficient. | ||
| """ | ||
| xyz, m = self.snapshot(t) | ||
| coeffs: tuple[Array, Array] = compute_coeffs_discrete( | ||
| xyz, m, nmax=self.nmax, lmax=self.lmax, r_s=r_s | ||
| ) | ||
| return coeffs |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| """Self-Consistent Field Potential.""" | ||
|
|
||
| __all__: list[str] = [] | ||
|
|
||
| from functools import partial | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| from jaxtyping import Array, Float | ||
|
|
||
| from .coeffs_helper import normalization_Knl | ||
| from .utils import psi_of_r | ||
| from spexial import eval_gegenbauers | ||
| import galax._custom_types as gt | ||
|
|
||
|
|
||
| def rho_nl(n: gt.IntSz0, l: gt.IntSz0, s: gt.FloatSz0, | ||
| ) -> gt.FloatSz0: | ||
| r"""Radial density expansion terms. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| s : Array[float, (n,)] | ||
| Scaled radius :math:`r/r_s`. | ||
| n : int | ||
| Radial expansion term. | ||
| l : int | ||
| Spherical harmonic term. | ||
|
|
||
| Returns | ||
| ------- | ||
| Array[float, (n,)] | ||
| """ | ||
| return ( | ||
| jnp.sqrt(4 * jnp.pi) | ||
| * (normalization_Knl(n=n, l=l) / (2 * jnp.pi)) | ||
| * (s**l / (s * (1 + s) ** (2 * l + 3))) | ||
| * eval_gegenbauers(n, 2 * l + 1.5, psi_of_r(s)) | ||
| ) | ||
|
|
||
| rho_nl_jit_vec = jax.jit( | ||
| jax.vmap( jax.vmap(rho_nl, in_axes=(None, 0, None),), in_axes=(None, None, 0)), static_argnames="n" | ||
| ) | ||
|
|
||
| # ====================================================================== | ||
|
|
||
|
|
||
| def phi_nl(n: gt.IntSz0, l: gt.IntSz0, s: gt.FloatSz0, | ||
| ) -> gt.FloatSz0: | ||
| r"""Angular density expansion terms. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| n : int | ||
| Max Radial expansion term. | ||
| l : int | ||
| Spherical harmonic term. | ||
| s : Float | ||
| Scaled radius :math:`r/r_s`. | ||
|
|
||
| Returns | ||
| ------- | ||
| Array[float, (n + 1,)] | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> import jax.numpy as jnp | ||
| >>> phi_nl(0.5, 1, 1) | ||
| Array(0.5, dtype=float32) | ||
| >>> phi_nl(jnp.array([0.5, 0.5]), 1, 1) | ||
| Array([0.5, 0.5], dtype=float32) | ||
| """ | ||
| return ( | ||
| -jnp.sqrt(4 * jnp.pi) | ||
| * (s**l / (1.0 + s) ** (2 * l + 1)) | ||
| * eval_gegenbauers(n, 2 * l + 1.5, psi_of_r(s)) | ||
| ) | ||
|
|
||
| phi_nl_jit_vec = jax.jit( | ||
| jax.vmap( jax.vmap(phi_nl, in_axes=(None, 0, None),), in_axes=(None, None, 0)), static_argnames="n" | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unfortunately, spexial is unlikely to be released. Can you move the function into
potential/scf/?