Skip to content

[Don't merge] MacOS GHA healpy failure experiments #310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,8 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-latest]
include:
- os: macos-latest
python-version: "3.8"
python-version: ["3.13"]
os: [macos-latest]
fail-fast: false
env:
CMAKE_POLICY_VERSION_MINIMUM: 3.5
Expand All @@ -48,23 +45,23 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
cache: pip

- name: Set up Miniforge
uses: conda-incubator/setup-miniconda@v3
with:
miniforge-version: latest

- name: Setup tmate session
uses: mxschmitt/action-tmate@v3
with:
limit-access-to-actor: true

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[tests]

- name: Run tests (skipping slow)
if: github.event_name == 'pull_request'
run: |
pytest -v --cov-report=xml --cov=s2fft --cov-config=.coveragerc -m "not slow"

- name: Run tests
if: github.event_name != 'pull_request'
run: |
pytest -v --cov-report=xml --cov=s2fft --cov-config=.coveragerc

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
pytest -vv tests/test_fourier_wigner.py::test_forward_fourier_wigner_transform_high_N[8966433580120847635-True-mwss-64]
5 changes: 4 additions & 1 deletion benchmarks/benchmarking.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,10 @@ def _format_results_entry(results_entry: dict) -> str:

def _dict_product(dicts: dict[str, Iterable[Any]]) -> Iterable[dict[str, Any]]:
"""Generator corresponding to Cartesian product of dictionaries."""
return (dict(zip(dicts.keys(), values)) for values in product(*dicts.values()))
return (
dict(zip(dicts.keys(), values, strict=False))
for values in product(*dicts.values())
)


def _parse_value(value: str) -> Any:
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,10 @@ def plot_results_against_bandlimit(
squeeze=False,
)
axes = axes.T if functions_along_columns else axes
for axes_row, function in zip(axes, functions):
for axes_row, function in zip(axes, functions, strict=False):
results = benchmark_results["results"][function]
l_values = np.array([r["parameters"]["L"] for r in results])
for ax, measurement in zip(axes_row, measurements):
for ax, measurement in zip(axes_row, measurements, strict=False):
plot_function, label = _measurement_plot_functions_and_labels[measurement]
try:
plot_function(ax, "L", l_values, results)
Expand Down
10 changes: 4 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
requires = [
"setuptools",
"setuptools-scm",
"scikit-build-core >=0.4.3",
"nanobind >=1.3.2"
"scikit-build-core>=0.4.3",
"nanobind>=1.3.2"
]
build-backend = "scikit_build_core.build"

Expand All @@ -16,11 +16,9 @@ authors = [
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Operating System :: OS Independent",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
Expand All @@ -38,7 +36,7 @@ keywords = [
]
name = "s2fft"
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.11"
license.file = "LICENCE.txt"
urls.homepage = "https://github.com/astro-informatics/s2fft"

Expand Down
5 changes: 2 additions & 3 deletions s2fft/precompute_transforms/construct.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import Tuple
from warnings import warn

import jax
Expand Down Expand Up @@ -612,7 +611,7 @@ def wigner_kernel_jax(
wigner_kernel_torch = torch_wrapper.wrap_as_torch_function(wigner_kernel_jax)


def fourier_wigner_kernel(L: int) -> Tuple[np.ndarray, np.ndarray]:
def fourier_wigner_kernel(L: int) -> tuple[np.ndarray, np.ndarray]:
"""
Computes Fourier coefficients of the reduced Wigner d-functions and quadrature
weights upsampled for the forward Fourier-Wigner transform.
Expand Down Expand Up @@ -640,7 +639,7 @@ def fourier_wigner_kernel(L: int) -> Tuple[np.ndarray, np.ndarray]:
return deltas, w


def fourier_wigner_kernel_jax(L: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
def fourier_wigner_kernel_jax(L: int) -> tuple[jnp.ndarray, jnp.ndarray]:
"""
Computes Fourier coefficients of the reduced Wigner d-functions and quadrature
weights upsampled for the forward Fourier-Wigner transform (JAX implementation).
Expand Down
13 changes: 6 additions & 7 deletions s2fft/precompute_transforms/custom_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from functools import partial
from typing import Tuple

import jax.numpy as jnp
import numpy as np
Expand All @@ -9,7 +8,7 @@
def wigner_subset_to_s2(
flmn: np.ndarray,
spins: np.ndarray,
DW: Tuple[np.ndarray, np.ndarray],
DW: tuple[np.ndarray, np.ndarray],
L: int,
sampling: str = "mw",
) -> np.ndarray:
Expand Down Expand Up @@ -91,7 +90,7 @@ def wigner_subset_to_s2(
def wigner_subset_to_s2_jax(
flmn: jnp.ndarray,
spins: jnp.ndarray,
DW: Tuple[jnp.ndarray, jnp.ndarray],
DW: tuple[jnp.ndarray, jnp.ndarray],
L: int,
sampling: str = "mw",
) -> jnp.ndarray:
Expand Down Expand Up @@ -173,7 +172,7 @@ def wigner_subset_to_s2_jax(
def so3_to_wigner_subset(
f: np.ndarray,
spins: np.ndarray,
DW: Tuple[np.ndarray, np.ndarray],
DW: tuple[np.ndarray, np.ndarray],
L: int,
N: int,
sampling: str = "mw",
Expand Down Expand Up @@ -214,7 +213,7 @@ def so3_to_wigner_subset(
def so3_to_wigner_subset_jax(
f: jnp.ndarray,
spins: jnp.ndarray,
DW: Tuple[jnp.ndarray, jnp.ndarray],
DW: tuple[jnp.ndarray, jnp.ndarray],
L: int,
N: int,
sampling: str = "mw",
Expand Down Expand Up @@ -257,7 +256,7 @@ def so3_to_wigner_subset_jax(
def s2_to_wigner_subset(
fs: np.ndarray,
spins: np.ndarray,
DW: Tuple[np.ndarray, np.ndarray],
DW: tuple[np.ndarray, np.ndarray],
L: int,
sampling: str = "mw",
) -> np.ndarray:
Expand Down Expand Up @@ -343,7 +342,7 @@ def s2_to_wigner_subset(
def s2_to_wigner_subset_jax(
fs: jnp.ndarray,
spins: jnp.ndarray,
DW: Tuple[jnp.ndarray, jnp.ndarray],
DW: tuple[jnp.ndarray, jnp.ndarray],
L: int,
sampling: str = "mw",
) -> jnp.ndarray:
Expand Down
9 changes: 4 additions & 5 deletions s2fft/precompute_transforms/spherical.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from functools import partial
from typing import Optional
from warnings import warn

import jax.numpy as jnp
Expand All @@ -21,11 +20,11 @@ def inverse(
flm: np.ndarray,
L: int,
spin: int = 0,
kernel: Optional[np.ndarray] = None,
kernel: np.ndarray | None = None,
sampling: str = "mw",
reality: bool = False,
method: str = "jax",
nside: Optional[int] = None,
nside: int | None = None,
) -> np.ndarray:
r"""
Compute the inverse spherical harmonic transform via precompute.
Expand Down Expand Up @@ -228,11 +227,11 @@ def forward(
f: np.ndarray,
L: int,
spin: int = 0,
kernel: Optional[np.ndarray] = None,
kernel: np.ndarray | None = None,
sampling: str = "mw",
reality: bool = False,
method: str = "jax",
nside: Optional[int] = None,
nside: int | None = None,
iter: int = 0,
) -> np.ndarray:
r"""
Expand Down
9 changes: 4 additions & 5 deletions s2fft/recursions/price_mcewen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import warnings
from functools import partial
from typing import List

import jax.lax as lax
import jax.numpy as jnp
Expand All @@ -19,7 +18,7 @@ def generate_precomputes(
nside: int = None,
forward: bool = False,
L_lower: int = 0,
) -> List[np.ndarray]:
) -> list[np.ndarray]:
r"""
Compute recursion coefficients with :math:`\mathcal{O}(L^3)` memory overhead.

Expand Down Expand Up @@ -125,7 +124,7 @@ def generate_precomputes_jax(
forward: bool = False,
L_lower: int = 0,
betas: jnp.ndarray = None,
) -> List[jnp.ndarray]:
) -> list[jnp.ndarray]:
r"""
Compute recursion coefficients with :math:`\mathcal{O}(L^2)` memory overhead.
In practice one could compute these on-the-fly but the memory overhead is
Expand Down Expand Up @@ -264,7 +263,7 @@ def generate_precomputes_wigner(
forward: bool = False,
reality: bool = False,
L_lower: int = 0,
) -> List[List[np.ndarray]]:
) -> list[list[np.ndarray]]:
r"""
Compute recursion coefficients with :math:`\mathcal{O}(L^2)` memory overhead.
In practice one could compute these on-the-fly but the memory overhead is
Expand Down Expand Up @@ -316,7 +315,7 @@ def generate_precomputes_wigner_jax(
forward: bool = False,
reality: bool = False,
L_lower: int = 0,
) -> List[List[jnp.ndarray]]:
) -> list[list[jnp.ndarray]]:
r"""
Compute recursion coefficients with :math:`\mathcal{O}(L^2)` memory overhead.
In practice one could compute these on-the-fly but the memory overhead is
Expand Down
8 changes: 3 additions & 5 deletions s2fft/sampling/s2_samples.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Tuple

import numpy as np


Expand Down Expand Up @@ -125,7 +123,7 @@ def nphi_equiang(L: int, sampling: str = "mw") -> int:
return 1


def ftm_shape(L: int, sampling: str = "mw", nside: int = None) -> Tuple[int, int]:
def ftm_shape(L: int, sampling: str = "mw", nside: int = None) -> tuple[int, int]:
r"""
Shape of intermediate array, before/after latitudinal step.

Expand Down Expand Up @@ -445,7 +443,7 @@ def ring_phase_shift_hp(
return np.exp(sign * 1j * np.arange(m_start_ind, L) * phi_offset)


def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> Tuple[int]:
def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> tuple[int]:
r"""
Shape of spherical signal.

Expand Down Expand Up @@ -480,7 +478,7 @@ def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> Tuple[int
return ntheta(L, sampling), nphi_equiang(L, sampling)


def flm_shape(L: int) -> Tuple[int, int]:
def flm_shape(L: int) -> tuple[int, int]:
r"""
Standard shape of harmonic coefficients.

Expand Down
8 changes: 3 additions & 5 deletions s2fft/sampling/so3_samples.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from typing import Tuple

import numpy as np

from s2fft.sampling import s2_samples as samples


def f_shape(
L: int, N: int, sampling: str = "mw", nside: int = None
) -> Tuple[int, int, int]:
) -> tuple[int, int, int]:
r"""
Computes the pixel-space sampling shape for signal on the rotation group
:math:`SO(3)`.
Expand Down Expand Up @@ -49,7 +47,7 @@ def f_shape(
raise ValueError(f"Sampling scheme sampling={sampling} not supported")


def flmn_shape(L: int, N: int) -> Tuple[int, int, int]:
def flmn_shape(L: int, N: int) -> tuple[int, int, int]:
r"""
Computes the shape of Wigner coefficients for signal on the rotation group
:math:`SO(3)`.
Expand All @@ -69,7 +67,7 @@ def flmn_shape(L: int, N: int) -> Tuple[int, int, int]:

def fnab_shape(
L: int, N: int, sampling: str = "mw", nside: int = None
) -> Tuple[int, int, int]:
) -> tuple[int, int, int]:
r"""
Computes the shape of Wigner coefficients for signal on the rotation group
:math:`SO(3)`.
Expand Down
9 changes: 4 additions & 5 deletions s2fft/transforms/otf_recursions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from functools import partial
from typing import List

import jax.lax as lax
import jax.numpy as jnp
Expand All @@ -21,7 +20,7 @@ def inverse_latitudinal_step(
nside: int,
sampling: str = "mw",
reality: bool = False,
precomps: List = None,
precomps: list = None,
L_lower: int = 0,
) -> np.ndarray:
r"""
Expand Down Expand Up @@ -181,7 +180,7 @@ def inverse_latitudinal_step_jax(
nside: int,
sampling: str = "mw",
reality: bool = False,
precomps: List = None,
precomps: list = None,
spmd: bool = False,
L_lower: int = 0,
) -> jnp.ndarray:
Expand Down Expand Up @@ -438,7 +437,7 @@ def forward_latitudinal_step(
nside: int,
sampling: str = "mw",
reality: bool = False,
precomps: List = None,
precomps: list = None,
L_lower: int = 0,
) -> np.ndarray:
r"""
Expand Down Expand Up @@ -598,7 +597,7 @@ def forward_latitudinal_step_jax(
nside: int,
sampling: str = "mw",
reality: bool = False,
precomps: List = None,
precomps: list = None,
spmd: bool = False,
L_lower: int = 0,
) -> jnp.ndarray:
Expand Down
Loading
Loading