Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions tests/test_palettes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Tests for ggsci.palette functions (discrete and continuous)."""

from __future__ import annotations

import inspect
from collections.abc import Callable

import pytest

from ggsci import palettes as pl
from ggsci.data import PALETTES


def _is_continuous_palette(func: Callable[..., object]) -> bool:
sig = inspect.signature(func)
return "n" in sig.parameters


def _alpha_hex(alpha: float) -> str:
return f"{int(alpha * 255):02x}"


def _all_palette_functions() -> dict[str, Callable[..., object]]:
return {name[4:]: getattr(pl, name) for name in dir(pl) if name.startswith("pal_")}


def _discrete_and_continuous() -> tuple[list[str], list[str]]:
funcs = _all_palette_functions()
discrete = [name for name, fn in funcs.items() if not _is_continuous_palette(fn)]
continuous = [name for name, fn in funcs.items() if _is_continuous_palette(fn)]
return sorted(discrete), sorted(continuous)


DISCRETE_NAMES, CONTINUOUS_NAMES = _discrete_and_continuous()


@pytest.mark.parametrize("name", DISCRETE_NAMES)
def test_discrete_palette_happy_path_and_alpha(name: str):
func: Callable[..., Callable[[int], list[str]]] = getattr(pl, f"pal_{name}")

# Validate each available sub-palette for this family
for palette_key, colors in PALETTES[name].items():
# Happy path
pal_fn = func(palette=palette_key, alpha=1.0)
n = min(3, len(colors))
out = pal_fn(n)
assert isinstance(out, list) and len(out) == n
assert all(c.startswith("#") and len(c) == 7 for c in out)

# Too many requested colors -> error
with pytest.raises(ValueError):
pal_fn(len(colors) + 1)

# Alpha applied in discrete palette function
pal_fn_a = func(palette=palette_key, alpha=0.6)
out_a = pal_fn_a(1)
assert len(out_a) == 1 and out_a[0].startswith("#") and len(out_a[0]) == 9
assert out_a[0][-2:] == _alpha_hex(0.6)


@pytest.mark.parametrize("name", DISCRETE_NAMES)
def test_discrete_palette_errors(name: str):
func: Callable[..., Callable[[int], list[str]]] = getattr(pl, f"pal_{name}")

with pytest.raises(ValueError):
func(palette="__unknown__", alpha=1.0)

for bad_alpha in (0.0, -0.1, 1.0 + 1e-9):
with pytest.raises(ValueError):
func(alpha=bad_alpha)


@pytest.mark.parametrize("name", CONTINUOUS_NAMES)
def test_continuous_palette_happy_path_reverse_alpha(name: str):
func: Callable[..., list[str]] = getattr(pl, f"pal_{name}")

# Exercise all palettes for the family (kept small n for speed)
for palette_key in PALETTES[name].keys():
# Forward
out = func(palette=palette_key, n=7, alpha=1.0, reverse=False)
assert isinstance(out, list) and len(out) == 7
assert all(c.startswith("#") and len(c) == 7 for c in out)

# Reverse
out_r = func(palette=palette_key, n=7, alpha=1.0, reverse=True)
assert out_r == out[::-1]

# Alpha applied post-interpolation
out_a = func(palette=palette_key, n=5, alpha=0.6, reverse=False)
assert all(c.startswith("#") and len(c) == 9 for c in out_a)
assert out_a[0][-2:] == _alpha_hex(0.6)


@pytest.mark.parametrize("name", CONTINUOUS_NAMES)
def test_continuous_palette_errors(name: str):
func: Callable[..., list[str]] = getattr(pl, f"pal_{name}")

with pytest.raises(ValueError):
func(palette="__unknown__")

for bad_alpha in (0.0, -0.1, 1.0 + 1e-9):
with pytest.raises(ValueError):
func(alpha=bad_alpha)
6 changes: 0 additions & 6 deletions tests/test_placeholder.py

This file was deleted.

108 changes: 108 additions & 0 deletions tests/test_scales.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Tests for ggsci.scales (discrete classes, continuous functions, aliases)."""

from __future__ import annotations

import inspect
from collections.abc import Callable

import pytest
from plotnine.scales import scale_color_gradientn, scale_fill_gradientn

import ggsci.scales as sc


def _discrete_scale_classes() -> list[tuple[str, type]]:
items: list[tuple[str, type]] = []
for name, obj in sc.__dict__.items():
if not isinstance(obj, type):
continue
if not issubclass(obj, sc.scale_discrete):
continue
# Skip the base class itself
if name == "scale_discrete":
continue
items.append((name, obj))
return items


def _continuous_scale_factories() -> list[tuple[str, Callable[..., object]]]:
return [
(name, obj)
for name, obj in sc.__dict__.items()
if name.startswith("scale_color_")
and callable(obj)
and not isinstance(obj, type)
and "gradientn" in inspect.getsource(sc.__dict__[name]).lower()
] + [
(name, obj)
for name, obj in sc.__dict__.items()
if name.startswith("scale_fill_")
and callable(obj)
and not isinstance(obj, type)
and "gradientn" in inspect.getsource(sc.__dict__[name]).lower()
]


@pytest.mark.parametrize("name,cls", _discrete_scale_classes())
def test_discrete_scales_aesthetics_and_palette(name: str, cls: type):
s = cls()
if name.startswith("scale_color_") or name.startswith("scale_colour_"):
assert s._aesthetics == ["color"]
elif name.startswith("scale_fill_"):
assert s._aesthetics == ["fill"]
else:
pytest.fail(f"Unexpected discrete scale name: {name}")

# Palette is a callable taking n and returning list[str]
colors = s.palette(3)
assert isinstance(colors, list) and len(colors) == 3
assert all(c.startswith("#") and len(c) == 7 for c in colors)

# Alpha is applied via InitVar
s_alpha = cls(alpha=0.6)
out = s_alpha.palette(1)
assert len(out) == 1 and out[0].startswith("#") and len(out[0]) == 9
assert out[0][-2:] == f"{int(0.6 * 255):02x}"


@pytest.mark.parametrize("name,fn", _continuous_scale_factories())
def test_continuous_scale_return_types(name: str, fn: Callable[..., object]):
obj = fn()
if name.startswith("scale_color_"):
assert isinstance(obj, scale_color_gradientn)
else:
assert isinstance(obj, scale_fill_gradientn)


def test_british_aliases_identity():
# All scale_colour_* should be the same object as scale_color_*
uk_names = [n for n in sc.__dict__ if n.startswith("scale_colour_")]
for uk in uk_names:
us = uk.replace("colour", "color")
assert hasattr(sc, us)
assert getattr(sc, uk) is getattr(sc, us)


def test_init_exports_alignment():
# Importing from package root should expose the same objects
import ggsci as pkg

names = [
# Sample a few across types to ensure import surface
"scale_color_npg",
"scale_fill_npg",
"scale_colour_npg",
"scale_color_gsea",
"scale_fill_bs5",
"pal_npg",
"pal_gsea",
]
for name in names:
assert hasattr(pkg, name)
# Identity with module definitions
mod = (
sc
if name.startswith("scale_")
else __import__("ggsci.palettes", fromlist=[name])
)
assert getattr(pkg, name) is getattr(mod, name)
53 changes: 53 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Tests for ggsci.utils helpers."""

from __future__ import annotations

import pytest

from ggsci.utils import apply_alpha, hex_to_rgb, interpolate_colors, rgb_to_hex


def test_hex_to_rgb_and_back():
# Basic round-trip for full intensity channels
assert hex_to_rgb("#ffffff") == (255, 255, 255)
assert hex_to_rgb("#000000") == (0, 0, 0)

# to_hex returns lowercase, verify known mapping
assert rgb_to_hex((1.0, 0.0, 0.0)) == "#ff0000"
assert rgb_to_hex((0.0, 1.0, 0.0)) == "#00ff00"
assert rgb_to_hex((0.0, 0.0, 1.0)) == "#0000ff"


@pytest.mark.parametrize(
"colors,alpha,expected_suffix",
[
(
[
"#ffffff",
],
0.5,
"7f",
),
(["#000000"], 1.0, "ff"),
(["#abcdef"], 0.0 + 1e-9, "00"),
],
)
def test_apply_alpha(colors: list[str], alpha: float, expected_suffix: str):
out = apply_alpha(colors, alpha)
assert len(out) == len(colors)
assert all(v.startswith("#") and len(v) == 9 for v in out)
# Last two digits encode alpha
assert out[0][-2:] == expected_suffix


def test_interpolate_colors_endpoints_and_sampling():
# Interpolation across two endpoints keeps ends intact
colors = interpolate_colors(["#ff0000", "#0000ff"], 5)
assert len(colors) == 5
assert colors[0] == "#ff0000"
assert colors[-1] == "#0000ff"

# When n <= len(colors), sample endpoints evenly
base = ["#000000", "#111111", "#222222"]
sampled = interpolate_colors(base, 2)
assert sampled == [base[0], base[-1]]
Loading