Labeled Samples object for less verbose plotting #1667

@janfb

Context

  • The pairplot refactoring in #1631 (Refactor Plotting functions) introduces typed options and an upper/lower/diag API; offdiag is deprecated.
  • A pain point remains: users still pass many repeated args (labels, limits, ticks) in tutorials (see docs/advanced_tutorials/17_plotting_functionality.ipynb).
  • The proposal below reduces boilerplate and error-prone dimension bookkeeping by encapsulating per-dataset metadata in one object.

Proposal

LabeledSamples dataclass (or NamedTuple) in sbi.analysis

  • Fields:
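    • data: np.ndarray | torch.Tensor, shape (N, D)
    • name: str (optional; label for the sample set, used when overlaying several sources)
    • dim_labels: list[str] of length D (optional; defaults to ["θ1", "θ2", ...])
    • Optional: ticks: Optional[List[Tuple[float, float]]], limits: Optional[List[Tuple[float, float]]], each with one entry per dimension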
  • Contract:
    • pairplot and marginal_plot accept either raw arrays or LabeledSamples. When a LabeledSamples is provided:
      • Use its dim_labels if the labels arg is not supplied.
      • Use its limits if the limits arg is not supplied.
      • Use its ticks if the ticks arg is not supplied.
    • Support a list of LabeledSamples to overlay multiple sources; FigOptions.samples_labels defaults to each container's .name if present, otherwise to generated labels.
  • Interop:
    • A from_xarray(dataarray) constructor could leverage xarray’s named dimensions and coordinates (optional; follow-up).
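
The default for samples-level labels described in the contract could look roughly like the sketch below. The helper name default_samples_labels and the "samples {i}" fallback format are assumptions made for illustration; neither exists in sbi. A None entry stands for a raw array or a container without a name.

```python
from typing import List, Optional, Sequence


def default_samples_labels(names: Sequence[Optional[str]]) -> List[str]:
    """Hypothetical helper: keep a given name, generate a label otherwise.

    `names` holds one entry per overlaid sample set; None marks a raw array
    or a LabeledSamples without a name. The fallback format is an assumption.
    """
    return [
        name if name is not None else f"samples {i}"
        for i, name in enumerate(names, start=1)
    ]


# Mixing named containers with an unnamed raw array.
labels = default_samples_labels(["posterior", None, "prior"])
```

Generated labels use the position in the list, so a mixed call still yields a stable legend.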

Migration

  • Backward compatible: raw arrays remain supported; the container is opt-in and removes the need to pass the same labels/limits/ticks in every call.
  • Precedence: explicit function args override container metadata.
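
The precedence rule can be sketched as follows. This is a minimal illustration, not the sbi implementation: resolve_plot_metadata is a hypothetical helper, and the container here is a stand-in that types data as Any so the example runs without numpy or torch.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass(frozen=True)
class LabeledSamples:
    """Stand-in container for this sketch; data would be an (N, D) array."""

    data: Any
    name: Optional[str] = None
    dim_labels: Optional[List[str]] = None
    limits: Optional[List[Tuple[float, float]]] = None
    ticks: Optional[List[Tuple[float, float]]] = None


def resolve_plot_metadata(samples, labels=None, limits=None, ticks=None):
    """Hypothetical helper: explicit args win, container fields fill the gaps.

    Raw arrays pass through unchanged, which keeps existing calls working.
    """
    if not isinstance(samples, LabeledSamples):
        return samples, labels, limits, ticks
    return (
        samples.data,
        labels if labels is not None else samples.dim_labels,
        limits if limits is not None else samples.limits,
        ticks if ticks is not None else samples.ticks,
    )


container = LabeledSamples(
    data=[[0.1, 0.2]],
    dim_labels=["a", "b"],
    limits=[(0.0, 1.0), (0.0, 1.0)],
)
# labels is passed explicitly and overrides dim_labels; limits come from the container.
data, labels, limits, ticks = resolve_plot_metadata(container, labels=["x", "y"])
```

The checks use `is not None` rather than truthiness so that an explicitly passed empty list is still treated as an explicit argument.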

Benefits

  • Encapsulates dimension labeling and axis metadata in one place.
  • Simplifies notebook and pipeline code; fewer chances for label/order mismatches.
  • Autocomplete-friendly and aligned with the stronger typing direction of this branch.

Minimal sketch

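from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch

ArrayLike = Union[np.ndarray, torch.Tensor]

@dataclass(frozen=True)
class LabeledSamples:
    # Required field goes first: a None default would let an empty
    # container be constructed and contradicts the ArrayLike annotation.
    data: ArrayLike  # shape (N, D)
    name: Optional[str] = None  # label for this sample set
    dim_labels: Optional[List[str]] = None  # length D
    limits: Optional[List[Tuple[float, float]]] = None  # one (low, high) per dimension
    ticks: Optional[List[Tuple[float, float]]] = None  # one entry per dimension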

# prepare_for_plot() detects LabeledSamples, unpacks metadata,
# and prefers explicit function args over container fields.
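
The validation and default-label behaviour mentioned in the proposal could be sketched like this. Both function names are hypothetical. The number of dimensions is passed in as a plain int (a caller would use data.shape[1]) so the example has no numpy or torch dependency.

```python
from typing import List, Optional, Sequence, Tuple


def default_dim_labels(n_dims: int) -> List[str]:
    """Generate the proposed fallback labels "θ1", "θ2", ..."""
    return [f"θ{i}" for i in range(1, n_dims + 1)]


def validate_metadata(
    n_dims: int,
    dim_labels: Optional[Sequence[str]] = None,
    limits: Optional[Sequence[Tuple[float, float]]] = None,
    ticks: Optional[Sequence[Tuple[float, float]]] = None,
) -> None:
    """Hypothetical check: every supplied field needs one entry per dimension."""
    fields = (("dim_labels", dim_labels), ("limits", limits), ("ticks", ticks))
    for field_name, value in fields:
        if value is not None and len(value) != n_dims:
            raise ValueError(
                f"{field_name} has {len(value)} entries, expected {n_dims}."
            )
    # Each limit must be an increasing (low, high) pair.
    for low, high in limits or []:
        if not low < high:
            raise ValueError(f"Invalid limits ({low}, {high}): low must be < high.")
```

Such a check could run in __post_init__ of the dataclass or at the plotting boundary; which of the two is preferable is part of the open questions below.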

More context:

  • Open questions

    • Defaults for samples-level labels/colors when mixing raw arrays and LabeledSamples.
    • Whether limits are per-sample-set or global; if mixed, choose intersection/union or prefer first.
    • Validation surface: check D consistency across overlaid datasets and metadata lengths.
    • Immutability (frozen=True) and torch/numpy conversion policy at boundaries.
    • Export location and public API: sbi.analysis export for discoverability.
  • Impact on docs/tests

    • Update 17_plotting_functionality.ipynb to show both raw arrays and LabeledSamples flows.
    • Add smoke tests for container + precedence rules; validation tests for dim_labels length and limits shape.
    • Re-export LabeledSamples from sbi.analysis in __init__.py.
  • Suggested steps (one or several PRs depending on size)

    • Introduce LabeledSamples, wire into prepare_for_plot, add precedence + validation.
    • Update tutorials and add unit tests; add from_xarray helper (optional).
    • Consider a deprecation notice in the docs that encourages using the container for repeated plotting scenarios.
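
For the open question on per-sample-set versus global limits, the three candidate policies (union, intersection, prefer first) can be compared with a small sketch. merge_limits and its strategy argument are assumptions for illustration only; the sketch also includes the D-consistency check across overlaid datasets.

```python
from typing import List, Sequence, Tuple

Limits = List[Tuple[float, float]]


def merge_limits(per_set_limits: Sequence[Limits], strategy: str = "union") -> Limits:
    """Hypothetical helper: combine per-dataset limits into one global set."""
    if not per_set_limits:
        raise ValueError("Need at least one set of limits.")
    n_dims = len(per_set_limits[0])
    # D must agree across all overlaid datasets.
    if any(len(limits) != n_dims for limits in per_set_limits):
        raise ValueError("All sample sets must have the same number of dimensions.")
    if strategy == "first":
        return list(per_set_limits[0])
    if strategy not in ("union", "intersection"):
        raise ValueError(f"Unknown strategy: {strategy!r}")

    merged: Limits = []
    for dim in range(n_dims):
        lows = [limits[dim][0] for limits in per_set_limits]
        highs = [limits[dim][1] for limits in per_set_limits]
        if strategy == "union":
            merged.append((min(lows), max(highs)))
        else:
            low, high = max(lows), min(highs)
            if low >= high:
                raise ValueError(f"Limits do not overlap in dimension {dim}.")
            merged.append((low, high))
    return merged


a = [(0.0, 2.0), (-1.0, 1.0)]
b = [(1.0, 3.0), (0.0, 4.0)]
union = merge_limits([a, b], "union")
intersection = merge_limits([a, b], "intersection")
```

Union never clips any dataset, while intersection can fail when ranges do not overlap, which may argue for union or "first" as the default.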

Labels: enhancement (New feature or request)