Skip to content

Better pairplots #505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions bayesflow/diagnostics/plots/pairs_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def pairs_posterior(
height: int = 3,
post_color: str | tuple = "#132a70",
prior_color: str | tuple = "gray",
target_color: str | tuple = "red",
alpha: float = 0.9,
label_fontsize: int = 14,
tick_fontsize: int = 12,
Expand Down Expand Up @@ -109,14 +110,40 @@ def pairs_posterior(
# Create DataFrame with variable names as columns
g.data = pd.DataFrame(targets, columns=targets.variable_names)
g.data["_source"] = "True Parameter"
g.map_diag(plot_true_params)
g.map_diag(plot_true_params_as_lines, color=target_color)
g.map_offdiag(plot_true_params_as_points, color=target_color)

target_handle = plt.Line2D(
[0], [0],
color=target_color,
linestyle="--",
marker="x",
label="Targets"
)

diag_ax = g.axes[0, 0]
# Collect histogram legend handles
hist_handles = getattr(diag_ax, '_legend_handles', [])

# Collect labels and handles from regular plots (if any)
handles, labels = g.axes[0, 0].get_legend_handles_labels()

handles = hist_handles + [target_handle]
labels = [h.get_label() for h in handles] # safer to refresh labels
g.fig.legend(handles=handles, labels=labels, loc="center right", frameon=False, fontsize=legend_fontsize)

return g


def plot_true_params(x, hue=None, **kwargs):
"""Custom function to plot true parameters on the diagonal."""
def plot_true_params_as_lines(x, hue=None, color=None, **kwargs):
"""Custom function to plot true parameters on the diagonal as dashed lines."""
# hue needs to be added to handle the case of plotting both posterior and prior
param = x.iloc[0] # Get the single true value for the diagonal
# only plot on the diagonal a vertical line for the true parameter
plt.axvline(param, color="black", linestyle="--")
plt.axvline(param, color=color, linestyle="--", label="Target (Line)")


def plot_true_params_as_points(x, y, color=None, marker='x', **kwargs):
"""Custom function to plot true parameters on the off-diagonal as a single point."""
if len(x) > 0 and len(y) > 0:
plt.scatter(x.iloc[0], y.iloc[0], color=color, marker=marker, **kwargs)
33 changes: 23 additions & 10 deletions bayesflow/diagnostics/plots/pairs_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import seaborn as sns

import matplotlib.pyplot as plt
from matplotlib.patches import Patch

from bayesflow.utils import logging
from bayesflow.utils.dict_utils import dicts_to_arrays
Expand Down Expand Up @@ -146,11 +147,6 @@ def _pairs_samples(
logging.exception("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)

# need to add legend here such that colors are recognized
if plot_data["priors"] is not None:
g.add_legend(fontsize=legend_fontsize, loc="center right")
g._legend.set_title(None)

# Generate grids
dim = g.axes.shape[0]
for i in range(dim):
Expand All @@ -170,6 +166,11 @@ def _pairs_samples(
g.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize)
g.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize)

# need to add legend here such that colors are recognized
if plot_data["priors"] is not None:
g.add_legend(fontsize=legend_fontsize, loc="center right")
g._legend.set_title(None)

# Return figure
g.tight_layout()

Expand All @@ -181,16 +182,28 @@ def _pairs_samples(
# in independent of the y scaling of the off-diagonal plots
def histplot_twinx(x, **kwargs):
# Create a twin axis
ax2 = plt.gca().twinx()
# ax2 = plt.gca().twinx()

label = kwargs.pop("labels", None)
color = kwargs.get("colors", None)

ax = plt.gca()

# create a histogram on the twin axis
sns.histplot(x, **kwargs, ax=ax2)
sns.histplot(x, **kwargs)

if label is not None:
legend_artist = Patch(color=color, label=label)
# Store the artist for later
if not hasattr(ax, '_legend_handles'):
ax._legend_handles = []
ax._legend_handles.append(legend_artist)

# make the twin axis invisible
plt.gca().spines["right"].set_visible(False)
plt.gca().spines["top"].set_visible(False)
ax2.set_ylabel("")
ax2.set_yticks([])
ax2.set_yticklabels([])
# ax2.set_ylabel("")
# ax2.set_yticks([])
# ax2.set_yticklabels([])

return None