Open
Labels
enhancement (New feature or request)
Description
Context
- The pairplot refactoring in "Refactor Plotting functions" (#1631) introduces typed options and an upper/lower/diag API; offdiag is deprecated.
- Pain point remains: users still pass many repeated args (labels, limits, ticks) in tutorials (see docs/advanced_tutorials/17_plotting_functionality.ipynb).
- Proposal below reduces boilerplate and error-prone dimension bookkeeping by encapsulating per-dataset metadata.
Proposal
LabeledSamples dataclass (or NamedTuple) in sbi.analysis
- Fields:
  - data: np.ndarray | torch.Tensor, shape (N, D)
  - dim_labels: list[str] of length D (optional; defaults to ["θ1", "θ2", ...])
  - Optional: name: Optional[str], ticks: Optional[List[Tuple[float, float]]], limits: Optional[List[Tuple[float, float]]]
- Contract:
  - pairplot and marginal_plot accept either raw arrays or LabeledSamples. When a LabeledSamples is provided:
    - Use its dim_labels if the labels arg is not supplied.
    - Use its limits if the limits arg is not supplied.
    - Use its ticks if the ticks arg is not supplied.
  - Support a list of LabeledSamples to overlay multiple sources; FigOptions.samples_labels defaults to each container's .name if present, otherwise to generated labels.
- Interop:
  - A from_xarray(dataarray) constructor could leverage xarray’s named dimensions and coordinates (optional; follow-up).
Migration
- Backward compatible: continue supporting raw arrays; adopt container to reduce passing repeated labels/ticks in every call.
- Precedence: explicit function args override container metadata.
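The precedence rule could look roughly like the sketch below. It is an illustration only: `resolve_metadata` is a hypothetical helper, not an existing sbi function, and the `LabeledSamples` class here is a stand-in for the proposed container. Data is assumed to expose `.shape` as numpy arrays and torch tensors do; nested lists are accepted only so the sketch runs without either library.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

Pair = Tuple[float, float]


@dataclass(frozen=True)
class LabeledSamples:
    """Stand-in for the proposed container (not part of sbi today)."""

    data: Any  # shape (N, D); np.ndarray or torch.Tensor in practice
    dim_labels: Optional[List[str]] = None
    limits: Optional[List[Pair]] = None
    ticks: Optional[List[Pair]] = None


def _num_dims(data: Any) -> int:
    # numpy arrays and torch tensors expose .shape; fall back to nested lists.
    shape = getattr(data, "shape", None)
    return int(shape[1]) if shape is not None else len(data[0])


def resolve_metadata(samples, labels=None, limits=None, ticks=None):
    """Hypothetical helper: explicit args win, then container fields, then defaults."""
    is_container = isinstance(samples, LabeledSamples)
    data = samples.data if is_container else samples

    if labels is None and is_container:
        labels = samples.dim_labels
    if labels is None:
        # Default labels as proposed above: θ1, θ2, ...
        labels = [f"θ{i + 1}" for i in range(_num_dims(data))]
    if limits is None and is_container:
        limits = samples.limits
    if ticks is None and is_container:
        ticks = samples.ticks
    return data, labels, limits, ticks


posterior = LabeledSamples(
    data=[[0.1, 0.2], [0.3, 0.4]],
    dim_labels=["a", "b"],
    limits=[(0.0, 1.0), (0.0, 1.0)],
)
# The explicit labels override the container; limits fall through from it.
_, labels, limits, _ = resolve_metadata(posterior, labels=["x", "y"])
```

Raw arrays take the same path with `is_container` false, so existing call sites would keep working unchanged.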
Benefits
- Encapsulates dimension labeling and axis metadata in one place.
- Simplifies notebook and pipeline code; fewer chances for label/order mismatches.
- Autocomplete-friendly; aligns with the stronger-typing direction of this branch.
Minimal sketch
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

ArrayLike = Union[np.ndarray, torch.Tensor]


@dataclass(frozen=True)
class LabeledSamples:
    data: ArrayLike  # shape (N, D); required, so it has no default
    name: Optional[str] = None  # label for this sample set
    dim_labels: Optional[List[str]] = None  # length D
    limits: Optional[List[Tuple[float, float]]] = None  # one (low, high) per dim
    ticks: Optional[List[Tuple[float, float]]] = None


# prepare_for_plot() detects LabeledSamples, unpacks metadata,
# and prefers explicit function args over container fields.

More context:
Open questions
- Defaults for samples-level labels/colors when mixing raw arrays and LabeledSamples.
- Whether limits are per-sample-set or global; if mixed, choose intersection/union or prefer first.
- Validation surface: check D consistency across overlaid datasets and metadata lengths.
- Immutability (frozen=True) and torch/numpy conversion policy at boundaries.
- Export location and public API: sbi.analysis export for discoverability.
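To make the limits question concrete, here is a small sketch of the three candidate policies (union, intersection, prefer first). `merge_limits` and its `policy` argument are hypothetical names used for discussion, not a proposed final API; datasets without limits are represented as `None` and skipped.

```python
from typing import List, Optional, Sequence, Tuple

Pair = Tuple[float, float]


def merge_limits(
    per_set: Sequence[Optional[List[Pair]]], policy: str = "union"
) -> Optional[List[Pair]]:
    """Combine per-dataset limits into one global list of (low, high) pairs."""
    if policy not in ("union", "intersection", "first"):
        raise ValueError(f"unknown policy: {policy!r}")

    known = [list(lims) for lims in per_set if lims is not None]
    if not known:
        return None  # nothing specified; leave limits to the plotting code
    if policy == "first":
        return known[0]
    if len({len(lims) for lims in known}) != 1:
        raise ValueError("all limits must cover the same number of dimensions")

    merged = []
    for dim_pairs in zip(*known):  # one tuple of (low, high) pairs per dimension
        lows = [low for low, _ in dim_pairs]
        highs = [high for _, high in dim_pairs]
        if policy == "union":
            merged.append((min(lows), max(highs)))
        else:  # intersection
            low, high = max(lows), min(highs)
            if low >= high:
                raise ValueError("limits do not overlap")
            merged.append((low, high))
    return merged
```

Union never clips any dataset, while intersection can fail outright when ranges are disjoint, which may argue for union (or prefer-first) as the default.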
Impact on docs/tests
- Update 17_plotting_functionality.ipynb to show both raw arrays and LabeledSamples flows.
- Add smoke tests for container + precedence rules; type tests for dim_labels length and limits shape.
- Re-export LabeledSamples from sbi.analysis in __init__.py.
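The checks those tests would target might look like the sketch below. `validate_metadata` and `validate_overlay` are hypothetical helper names, and where the checks live (for example in `__post_init__` or in `prepare_for_plot`) is left open. Nested lists are accepted only so the example runs without numpy or torch.

```python
from typing import Any, List, Optional, Sequence, Tuple

Pair = Tuple[float, float]


def _shape(data: Any) -> Tuple[int, ...]:
    # numpy arrays and torch tensors expose .shape; fall back to nested lists.
    shape = getattr(data, "shape", None)
    return tuple(shape) if shape is not None else (len(data), len(data[0]))


def validate_metadata(
    data: Any,
    dim_labels: Optional[List[str]] = None,
    limits: Optional[List[Pair]] = None,
    ticks: Optional[List[Pair]] = None,
) -> int:
    """Check one dataset's metadata against its shape and return D."""
    shape = _shape(data)
    if len(shape) != 2:
        raise ValueError(f"data must have shape (N, D), got {shape}")
    dim = shape[1]

    fields = (("dim_labels", dim_labels), ("limits", limits), ("ticks", ticks))
    for name, value in fields:
        if value is not None and len(value) != dim:
            raise ValueError(f"{name} has length {len(value)}, expected {dim}")

    for name, pairs in (("limits", limits), ("ticks", ticks)):
        for pair in pairs or []:
            if len(pair) != 2:
                raise ValueError(f"each entry of {name} must be a (low, high) pair")
    return dim


def validate_overlay(dims: Sequence[int]) -> int:
    """Check that all overlaid datasets share the same D and return it."""
    if len(set(dims)) != 1:
        raise ValueError(f"overlaid datasets must share one D, got {list(dims)}")
    return dims[0]
```

A smoke test could then assert that mismatched `dim_labels` raise `ValueError` and that matching metadata returns the expected D.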
Suggested steps (one or several PRs depending on size)
- Introduce LabeledSamples, wire into prepare_for_plot, add precedence + validation.
- Update tutorials and add unit tests; add from_xarray helper (optional).
- Consider deprecation notice in docs encouraging container for repeated plotting scenarios.