
diagnostics.pairs_posterior() plots targets behind other elements #430

Closed
@thegialeo

Description


The diagnostics.pairs_posterior() function currently draws targets behind the prior and posterior distributions, which makes them hard to see. There is also little control over how the targets look and how they appear in the legend. Their rendering order cannot be adjusted either, because the diagonal subplots draw the targets and the distributions on separate axes. Non-diagonal subplots are not affected.
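The rendering-order problem can be reproduced with plain matplotlib. The sketch below assumes that the diagonal twin axis behaves like one created by `Axes.twinx()`; this has not been checked against BayesFlow's `histplot_twinx`. It draws a histogram on the twin and a target line with a very large `zorder` on the original axes.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no display required
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
samples = rng.normal(size=500)

fig, ax = plt.subplots()
twin = ax.twinx()  # stand-in for the twin axis on a diagonal subplot

# Distribution on the twin axis, target on the original axis with a huge zorder.
twin.hist(samples, bins=20, color="tab:blue")
target_line = ax.axvline(0.0, color="red", linestyle="--", zorder=100)

# zorder=100 only ranks the line among the artists of `ax`. Both axes share
# the same default zorder, so the figure draws them in creation order: `twin`
# is drawn last and its bars are painted over the line.
print(ax.get_zorder() == twin.get_zorder())
print(fig.axes.index(ax) < fig.axes.index(twin))
```

In short, raising the target's own `zorder` cannot fix this. Either the target must be drawn on the twin axis, or the original axes must be raised above the twin.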

Minimal example to reproduce

import numpy as np
import bayesflow as bf

dummy_prior = np.random.multivariate_normal([0, 0], np.eye(2), size=100)
dummy_posterior = np.random.multivariate_normal([0, 0], 0.1 * np.eye(2), size=100)
dummy_targets = np.array([0, 0])

fig = bf.diagnostics.plots.pairs_posterior(
    estimates=dummy_posterior,
    priors=dummy_prior,
    targets=dummy_targets,
)

fig.savefig("test_plots.png")

Output: test_plots.png


Suggested improvements

  • Plot targets on top of prior and posterior plots.
  • Ensure targets appear in both diagonal and non-diagonal subplots.
  • Allow user to specify the color of the targets.
  • Return or attach the diagonal axes (from histplot_twinx) so users can adjust them or add custom elements.
  • Provide more control over legends, such as adding targets to legend.
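Several of the suggestions above (targets on top, on every subplot, with a configurable color and a single legend entry) can be illustrated with a minimal sketch. `topmost_axes_at` and `draw_targets` are hypothetical helpers, not part of BayesFlow. The example mimics a 2x2 pairs plot with plain matplotlib and assumes that the diagonal twin axes sit at the same figure position as their parent, which is what `Axes.twinx()` produces.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no display required
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D


def topmost_axes_at(ax):
    """Return the axes drawn last at ax's position (e.g. a twin from ax.twinx())."""
    fig = ax.figure
    pos = ax.get_position().bounds
    stacked = [a for a in fig.axes if np.allclose(a.get_position().bounds, pos)]
    # The figure draws axes by ascending zorder, ties broken by creation order.
    return max(stacked, key=lambda a: (a.get_zorder(), fig.axes.index(a)))


def draw_targets(axes, targets, color="red", label="Target"):
    """Hypothetical helper: mark targets on every panel of a square grid of axes."""
    n = len(targets)
    for i in range(n):
        for j in range(n):
            if i == j:
                # Diagonal: draw on the top-most axes so the distributions cannot cover it.
                top = topmost_axes_at(axes[i, i])
                top.axvline(targets[i], color=color, linestyle="--", zorder=10)
            else:
                # Off-diagonal: column j is the x variable, row i the y variable.
                axes[i, j].scatter(targets[j], targets[i], color=color,
                                   marker="x", s=100, zorder=10)
    # One proxy artist yields a single legend entry for all panels.
    return Line2D([], [], color=color, marker="x", linestyle="--", label=label)


# Mimic a 2x2 pairs plot whose diagonal histograms live on twin axes.
rng = np.random.default_rng(0)
post = rng.multivariate_normal([0, 0], 0.1 * np.eye(2), size=100)
fig, axes = plt.subplots(2, 2)
for k in range(2):
    axes[k, k].twinx().hist(post[:, k], bins=15)
axes[1, 0].scatter(post[:, 0], post[:, 1], s=5)
axes[0, 1].scatter(post[:, 1], post[:, 0], s=5)

handle = draw_targets(axes, np.array([0.0, 0.0]), color="black")
fig.legend(handles=[handle], loc="upper right")
```

Returning a proxy handle, rather than labeling every marker, avoids duplicate "Target" entries when the legend is built from all panels.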

Root cause / possible fix

The issue seems to originate from the histplot_twinx function used in _pairs_samples(). It creates a new twin axis that is not included in the returned g object. As a result, the targets and the prior/posterior elements end up on different axes. Matplotlib's zorder only orders artists within a single axes, and the twin axis is drawn after the original one, so users cannot bring the targets to the front via zorder. Non-diagonal plots are not affected because all their elements share the same axes.

import numpy as np
import bayesflow as bf

dummy_prior = np.random.multivariate_normal([0, 0], np.eye(2), size=100)
dummy_posterior = np.random.multivariate_normal([0, 0], 0.1 * np.eye(2), size=100)
dummy_targets = np.array([0, 0])

fig = bf.diagnostics.plots.pairs_posterior(
    estimates=dummy_posterior,
    priors=dummy_prior,
)

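for i in range(len(dummy_targets)):
    for j in range(len(dummy_targets)):
        ax = fig.axes[i, j]
        if i == j:
            # Diagonal: vertical line at the target value. This lands on the
            # original axes, not on the twin axis created by histplot_twinx.
            ax.axvline(dummy_targets[i], color="red", linestyle="--", label="Target")
        else:
            # Off-diagonal: column j is the x variable, row i the y variable.
            ax.scatter(
                dummy_targets[j],
                dummy_targets[i],
                color="red",
                marker="x",
                s=100,
                label="Target",
            )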

fig.savefig("example_plots.png")

Output: example_plots.png
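Another user-side workaround is to lift the original diagonal axes above its twin with a common matplotlib pattern: raise the axes' `zorder` and hide its background patch so the twin's contents stay visible. The sketch below uses a hypothetical helper, `raise_above_twins`, and mimics one diagonal panel with plain matplotlib. It assumes the twin sits at the same figure position as the original axes, as `Axes.twinx()` does.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no display required
import matplotlib.pyplot as plt
import numpy as np


def raise_above_twins(ax):
    """Hypothetical helper: make `ax` draw after any other axes stacked at its position."""
    fig = ax.figure
    pos = ax.get_position().bounds
    others = [a for a in fig.axes
              if a is not ax and np.allclose(a.get_position().bounds, pos)]
    if others:
        ax.set_zorder(max(a.get_zorder() for a in others) + 1)
        # Otherwise ax's opaque background patch would hide the twin's contents.
        ax.patch.set_visible(False)
    return ax


# Mimic one diagonal panel: histogram on a twin, target on the original axes.
rng = np.random.default_rng(0)
fig, ax = plt.subplots()
twin = ax.twinx()
twin.hist(rng.normal(size=500), bins=20)
ax.axvline(0.0, color="red", linestyle="--", label="Target")
raise_above_twins(ax)
```

In the loop above, this would correspond to calling `raise_above_twins(fig.axes[i, i])` in the diagonal branch. Note that every artist already on the original axes also moves above the twin's contents, not just the target line.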


Metadata



Labels

refactoring (Some code shall be redesigned), user interface (Changes to the user interface and improvements in usability)



Status

Done

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
