Skip to content

Better pairplots #505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions bayesflow/diagnostics/plots/pairs_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import seaborn as sns

from bayesflow.utils.dict_utils import dicts_to_arrays
from bayesflow.utils.plot_utils import create_legends

from .pairs_samples import _pairs_samples

Expand All @@ -21,6 +22,7 @@ def pairs_posterior(
height: int = 3,
post_color: str | tuple = "#132a70",
prior_color: str | tuple = "gray",
target_color: str | tuple = "red",
alpha: float = 0.9,
label_fontsize: int = 14,
tick_fontsize: int = 12,
Expand All @@ -37,25 +39,27 @@ def pairs_posterior(
Optional true parameter values that have generated the observed dataset.
priors : np.ndarray of shape (n_prior_draws, n_params) or None, optional (default: None)
Optional prior samples obtained from the prior.
dataset_id: Optional ID of the dataset for whose posterior the pairs plot shall be generated.
Should only be specified if estimates contains posterior draws from multiple datasets.
dataset_id: Optional ID of the dataset for whose posterior the pair plots shall be generated.
Should only be specified if estimates contain posterior draws from multiple datasets.
variable_keys : list or None, optional, default: None
Select keys from the dictionary provided in samples.
By default, select all keys.
variable_names : list or None, optional, default: None
The parameter names for nice plot titles. Inferred if None
height : float, optional, default: 3
The height of the pairplot
The height of the pair plots
label_fontsize : int, optional, default: 14
The font size of the x and y-label texts (parameter names)
tick_fontsize : int, optional, default: 12
The font size of the axis ticklabels
The font size of the axis tick labels
legend_fontsize : int, optional, default: 16
The font size of the legend text
post_color : str, optional, default: '#132a70'
The color for the posterior histograms and KDEs
prior_color : str, optional, default: gray
The color for the optional prior histograms and KDEs
target_color : str, optional, default: red
The color for the optional true parameter lines and points
alpha : float in [0, 1], optional, default: 0.9
The opacity of the posterior plots

Expand All @@ -81,7 +85,7 @@ def pairs_posterior(
variable_names=variable_names,
)

# dicts_to_arrays will keep dataset axis even if it is of length 1
# dicts_to_arrays will keep the dataset axis even if it is of length 1
# however, pairs plotting requires the dataset axis to be removed
estimates_shape = plot_data["estimates"].shape
if len(estimates_shape) == 3 and estimates_shape[0] == 1:
Expand Down Expand Up @@ -109,14 +113,30 @@ def pairs_posterior(
# Create DataFrame with variable names as columns
g.data = pd.DataFrame(targets, columns=targets.variable_names)
g.data["_source"] = "True Parameter"
g.map_diag(plot_true_params)
g.map_diag(plot_true_params_as_lines, color=target_color)
g.map_offdiag(plot_true_params_as_points, color=target_color)

create_legends(
g,
plot_data,
color=post_color,
color2=prior_color,
legend_fontsize=legend_fontsize,
show_single_legend=False,
)

return g


def plot_true_params(x, hue=None, **kwargs):
"""Custom function to plot true parameters on the diagonal."""
def plot_true_params_as_lines(x, hue=None, color=None, **kwargs):
"""Custom function to plot true parameters on the diagonal as dashed lines."""
# hue needs to be added to handle the case of plotting both posterior and prior
param = x.iloc[0] # Get the single true value for the diagonal
# only plot on the diagonal a vertical line for the true parameter
plt.axvline(param, color="black", linestyle="--")
plt.axvline(param, color=color, linestyle="--")


def plot_true_params_as_points(x, y, color=None, marker="x", **kwargs):
"""Custom function to plot true parameters on the off-diagonal as a single point."""
if len(x) > 0 and len(y) > 0:
plt.scatter(x.iloc[0], y.iloc[0], color=color, marker=marker, **kwargs)
78 changes: 55 additions & 23 deletions bayesflow/diagnostics/plots/pairs_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from bayesflow.utils import logging
from bayesflow.utils.dict_utils import dicts_to_arrays
from bayesflow.utils.plot_utils import create_legends


def pairs_samples(
Expand All @@ -17,8 +18,10 @@ def pairs_samples(
height: float = 2.5,
color: str | tuple = "#132a70",
alpha: float = 0.9,
label: str = "Posterior",
label_fontsize: int = 14,
tick_fontsize: int = 12,
show_single_legend: bool = False,
**kwargs,
) -> sns.PairGrid:
"""
Expand All @@ -37,13 +40,18 @@ def pairs_samples(
height : float, optional, default: 2.5
The height of the pair plot
color : str, optional, default : '#8f2727'
The color of the plot
The primary color of the plot
alpha : float in [0, 1], optional, default: 0.9
The opacity of the plot
label : str, optional, default: "Posterior"
Label for the dataset to plot
label_fontsize : int, optional, default: 14
The font size of the x and y-label texts (parameter names)
tick_fontsize : int, optional, default: 12
The font size of the axis ticklabels
The font size of the axis tick labels
show_single_legend : bool, optional, default: False
Optional toggle for the user to choose whether a single dataset
should also display legend
**kwargs : dict, optional
Additional keyword arguments passed to the sns.PairGrid constructor
"""
Expand All @@ -59,8 +67,11 @@ def pairs_samples(
height=height,
color=color,
alpha=alpha,
label=label,
label_fontsize=label_fontsize,
tick_fontsize=tick_fontsize,
show_single_legend=show_single_legend,
**kwargs,
)

return g
Expand All @@ -72,17 +83,27 @@ def _pairs_samples(
color: str | tuple = "#132a70",
color2: str | tuple = "gray",
alpha: float = 0.9,
label: str = "Posterior",
label_fontsize: int = 14,
tick_fontsize: int = 12,
legend_fontsize: int = 14,
show_single_legend: bool = False,
**kwargs,
) -> sns.PairGrid:
# internal version of pairs_samples creating the seaborn plot
"""
Internal version of pairs_samples creating the seaborn PairPlot
for both a single dataset and multiple datasets.

# Parameters
# ----------
# plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
# other arguments are documented in pairs_samples
Parameters
----------
plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
Formatted data to plot from the sample dataset
color2 : str, optional, default: 'gray'
Secondary color for the pair plots.
This is the color used for the prior draws.

Other arguments are documented in pairs_samples
"""

estimates_shape = plot_data["estimates"].shape
if len(estimates_shape) != 2:
Expand Down Expand Up @@ -136,7 +157,7 @@ def _pairs_samples(
common_norm=False,
)

# add scatterplots to the upper diagonal
# add scatter plots to the upper diagonal
g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)

# add KDEs to the lower diagonal
Expand All @@ -146,11 +167,6 @@ def _pairs_samples(
logging.exception("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)

# need to add legend here such that colors are recognized
if plot_data["priors"] is not None:
g.add_legend(fontsize=legend_fontsize, loc="center right")
g._legend.set_title(None)

# Generate grids
dim = g.axes.shape[0]
for i in range(dim):
Expand All @@ -165,32 +181,48 @@ def _pairs_samples(
g.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)

# adjust font size of labels
# adjust the font size of labels
# the labels themselves remain the same as before, i.e., variable_names
g.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize)
g.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize)

# need to add legend here such that colors are recognized
# if plot_data["priors"] is not None:
# g.add_legend(fontsize=legend_fontsize, loc="center right")
# g._legend.set_title(None)

create_legends(
g,
plot_data,
color=color,
color2=color2,
legend_fontsize=legend_fontsize,
label=label,
show_single_legend=show_single_legend,
)

# Return figure
g.tight_layout()

return g


# create a histogram plot on a twin y axis
# this ensures that the y scaling of the diagonal plots
# in independent of the y scaling of the off-diagonal plots
def histplot_twinx(x, **kwargs):
# Create a twin axis
ax2 = plt.gca().twinx()
"""
# create a histogram plot on a twin y-axis
# this ensures that the y scaling of the diagonal plots
# in independent of the y scaling of the off-diagonal plots

Parameters
----------
x : np.ndarray
Data to be plotted.
"""
# create a histogram on the twin axis
sns.histplot(x, **kwargs, ax=ax2)
sns.histplot(x, legend=False, **kwargs)

# make the twin axis invisible
plt.gca().spines["right"].set_visible(False)
plt.gca().spines["top"].set_visible(False)
ax2.set_ylabel("")
ax2.set_yticks([])
ax2.set_yticklabels([])

return None
72 changes: 67 additions & 5 deletions bayesflow/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize
from matplotlib.patches import Rectangle
from matplotlib.patches import Rectangle, Patch
from matplotlib.legend_handler import HandlerPatch

from .validators import check_estimates_prior_shapes
Expand Down Expand Up @@ -67,7 +67,7 @@ def prepare_plot_data(
)
check_estimates_prior_shapes(plot_data["estimates"], plot_data["targets"])

# store variable information at top level for easy access
# store variable information at the top level for easy access
variable_names = plot_data["estimates"].variable_names
num_variables = len(variable_names)
plot_data["variable_names"] = variable_names
Expand Down Expand Up @@ -249,7 +249,7 @@ def prettify_subplots(axes: np.ndarray, num_subplots: int, tick: bool = True, ti

def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray):
"""
Utility to make a subplots quadratic in order to avoid visual illusions
Utility to make subplots quadratic to avoid visual illusions
in, e.g., recovery plots.
"""

Expand All @@ -269,7 +269,7 @@ def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray):

def gradient_line(x, y, c=None, cmap: str = "viridis", lw: float = 2.0, alpha: float = 1, ax=None):
"""
Plot a 1D line with color gradient determined by `c` (same shape as x and y).
Plot a 1D line with a color gradient determined by `c` (same shape as x and y).
"""
if ax is None:
ax = plt.gca()
Expand Down Expand Up @@ -304,7 +304,7 @@ def gradient_legend(ax, label, cmap, norm, loc="upper right"):
- loc: legend location (default 'upper right')
"""

# Custom dummy handle to represent the gradient
# Custom placeholder handle to represent the gradient
class _GradientSwatch(Rectangle):
pass

Expand Down Expand Up @@ -358,3 +358,65 @@ def add_gradient_plot(
label=label,
alpha=0.01,
)


def create_legends(
g,
plot_data: dict,
color: str | tuple = "#132a70",
color2: str | tuple = "gray",
label: str = "Posterior",
show_single_legend: bool = False,
legend_fontsize: int = 14,
):
"""
Helper function to create legends for pairplots.
Parameters
----------
g : sns.PairGrid
Seaborn object for the pair plots
plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
Formatted data to plot from the sample dataset
color : str, optional, default : '#8f2727'
The primary color of the plot
color2 : str, optional, default: 'gray'
The secondary color for the plot
label : str, optional, default: "Posterior"
Label for the dataset to plot
show_single_legend : bool, optional, default: False
Optional toggle for the user to choose whether a single dataset
should also display legend
legend_fontsize : int, optional, default: 14
fontsize for the legend
"""
handles = []
labels = []

if plot_data.get("priors") is not None:
prior_handle = Patch(color=color2, label="Prior")
prior_label = "Prior"
handles.append(prior_handle)
labels.append(prior_label)

posterior_handle = Patch(color=color, label="Posterior")
posterior_label = label
handles.append(posterior_handle)
labels.append(posterior_label)

if plot_data.get("targets") is not None:
target_handle = plt.Line2D([0], [0], color="r", linestyle="--", marker="x", label="Targets")
target_label = "Targets"
handles.append(target_handle)
labels.append(target_label)

# If there are more than one dataset to plot,
if len(handles) > 1 or show_single_legend:
g.figure.legend(
handles=handles,
labels=labels,
loc="center left",
bbox_to_anchor=(1, 0.5),
frameon=False,
fontsize=legend_fontsize,
)