Description
The diagnostics.pairs_posterior() function currently plots targets behind the prior and posterior distributions, making them hard to see. Additionally, the appearance and legend handling of the targets offer little customization, and their rendering order cannot be adjusted because the diagonal subplots draw on separate axes. The non-diagonal subplots are not affected by this.
Minimal example to reproduce
import numpy as np
import bayesflow as bf
dummy_prior = np.random.multivariate_normal([0, 0], np.eye(2), size=100)
dummy_posterior = np.random.multivariate_normal([0, 0], 0.1 * np.eye(2), size=100)
dummy_targets = np.array([0, 0])
fig = bf.diagnostics.plots.pairs_posterior(
    estimates=dummy_posterior,
    priors=dummy_prior,
    targets=dummy_targets,
)
fig.savefig("test_plots.png")
Output: test_plots.png
Suggested improvements
- Plot targets on top of the prior and posterior plots.
- Ensure targets appear in both diagonal and non-diagonal subplots.
- Allow the user to specify the color of the targets.
- Return or attach the diagonal axes (from histplot_twinx) so users can adjust them or add custom elements.
- Provide more control over legends, such as adding targets to the legend.
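A minimal sketch of most of these suggestions, using matplotlib only. It assumes the returned object exposes a square grid of Axes (as the fig.axes[i, j] indexing elsewhere in this issue does) and that each diagonal histogram sits on a twin created by Axes.twinx(), as the root-cause analysis describes. find_twin and overlay_targets are hypothetical helpers, not part of BayesFlow; the demo builds its own 2x2 grid with twin axes to stand in for the pairs plot.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no display needed
import matplotlib.pyplot as plt
import numpy as np


def find_twin(ax):
    """Hypothetical helper: return the twinx() partner of `ax`, or None."""
    for other in ax.figure.axes:
        if other is ax:
            continue
        # twinx() places the twin at the same position and shares the x-axis
        same_position = (
            other.get_position(original=True).bounds
            == ax.get_position(original=True).bounds
        )
        if same_position and ax.get_shared_x_axes().joined(ax, other):
            return other
    return None


def overlay_targets(axes, targets, color="red", label="Target", zorder=10):
    """Hypothetical helper: draw targets on top of every panel of a square grid."""
    targets = np.asarray(targets)
    artists = []
    for i in range(len(targets)):
        for j in range(len(targets)):
            ax = axes[i, j]
            if i == j:
                # Diagonal: draw on the twin (if any) so the line is above the histogram
                twin = find_twin(ax)
                target_ax = twin if twin is not None else ax
                artist = target_ax.axvline(
                    targets[i], color=color, linestyle="--", label=label, zorder=zorder
                )
            else:
                # Off-diagonal: mark the (x, y) = (targets[j], targets[i]) location
                artist = ax.scatter(
                    targets[j], targets[i], color=color, marker="x", s=100,
                    label=label, zorder=zorder,
                )
            artists.append(artist)
    return artists


# Stand-in for the pairs plot: 2x2 grid with histograms on twin axes on the diagonal
rng = np.random.default_rng(0)
samples = rng.normal(size=(100, 2))
fig, axes = plt.subplots(2, 2)
twins = []
for k in range(2):
    twin = axes[k, k].twinx()
    twin.hist(samples[:, k], color="C0")
    twins.append(twin)
for i, j in [(0, 1), (1, 0)]:
    axes[i, j].scatter(samples[:, j], samples[:, i], s=5)

artists = overlay_targets(axes, [0.0, 0.0], color="crimson")
# First two artists: a diagonal line and an off-diagonal marker, one legend entry each
fig.legend(handles=artists[:2], labels=["Target (diagonal)", "Target (off-diagonal)"])
fig.savefig("targets_on_top.png")
```

Drawing on the twin keeps the change local to the target artists. Because axvline spans the full panel height in axes coordinates, the twin's different y-scale does not shift the line.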
Root cause / possible fix
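The issue seems to originate from the histplot_twinx function used in _pairs_samples(). It creates a new twin axis that is not included in the returned g object. As a result, the targets and the prior/posterior elements end up on different axes, which prevents the user from controlling their stacking (e.g., via zorder). Non-diagonal plots are not affected, since all their elements share the same axes. In the example below, the targets are added manually via fig.axes, so on the diagonal the vertical lines land on the original axis rather than on the twin axis that holds the histograms.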
import numpy as np
import bayesflow as bf
dummy_prior = np.random.multivariate_normal([0, 0], np.eye(2), size=100)
dummy_posterior = np.random.multivariate_normal([0, 0], 0.1 * np.eye(2), size=100)
dummy_targets = np.array([0, 0])
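# Targets are not passed here; they are added manually to fig.axes below
fig = bf.diagnostics.plots.pairs_posterior(
    estimates=dummy_posterior,
    priors=dummy_prior,
)
for i in range(2):
    for j in range(2):
        ax = fig.axes[i, j]
        if i == j:
            # Diagonal: this is the original axis, not the twin axis holding the histogram
            ax.axvline(dummy_targets[i], color="red", linestyle="--", label="Target")
        else:
            # Off-diagonal: mark the (x, y) target location
            ax.scatter(
                dummy_targets[j],
                dummy_targets[i],
                color="red",
                marker="x",
                s=100,
                label="Target",
            )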
fig.savefig("example_plots.png")
Output: example_plots.png
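The stacking problem itself can be reproduced with matplotlib alone. A zorder passed to an artist only orders artists within the same Axes; the Axes themselves are drawn in order of their own zorder, with ties resolved by creation order. Since a twin created by Axes.twinx() has the same default zorder as its parent and is added later, the histogram on the twin is painted over anything on the parent. A minimal sketch of one possible workaround, assuming direct access to both the parent and the twin (which pairs_posterior does not currently provide): raise the parent above the twin and hide the parent's background patch, since twinx() already makes the twin's patch invisible.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no display needed
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
fig, ax = plt.subplots()
twin = ax.twinx()  # mimics one diagonal panel
twin.hist(rng.normal(size=200), color="C0")
line = ax.axvline(0.0, color="red", linestyle="--", zorder=100)

# zorder=100 only ranks the line among the artists of `ax`; both Axes sit at the
# default zorder 0 and the twin was added later, so the twin is drawn on top.
parent_below = ax.get_zorder() <= twin.get_zorder()

# Workaround: lift the parent Axes above the twin and make its background
# transparent so the histogram on the twin still shows through.
ax.set_zorder(twin.get_zorder() + 1)
ax.patch.set_visible(False)
fig.savefig("twin_zorder_fix.png")
```

Unlike drawing on the twin, this re-orders whole Axes, so every artist on the parent moves above the histogram, which may or may not be desired.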