improve interface for custom density estimators passed to inference classes #1531

@janfb

Description

In all inference methods (child classes of NeuralInference), we generally allow users either to pass a string, in which case the density estimator is built internally (e.g., NPE(density_estimator="maf")), or to pass a plain Callable, which we then call internally as density_estimator(theta, x) to build the network.
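To make the two accepted forms concrete, here is a minimal, dependency-free sketch of the dispatch described above. It is not the sbi implementation: the registry _BUILDERS, the helper build_density_estimator, and ToyEstimator are hypothetical, and nested lists stand in for the theta and x tensors.

```python
from typing import Any, Callable, Dict, Union

# Hypothetical stand-in for the network that a real builder would return.
class ToyEstimator:
    def __init__(self, kind: str, theta_dim: int, x_dim: int) -> None:
        self.kind = kind
        self.theta_dim = theta_dim
        self.x_dim = x_dim

# Hypothetical registry mapping string names to builder functions.
_BUILDERS: Dict[str, Callable[[Any, Any], Any]] = {
    "maf": lambda theta, x: ToyEstimator("maf", len(theta[0]), len(x[0])),
    "nsf": lambda theta, x: ToyEstimator("nsf", len(theta[0]), len(x[0])),
}

def build_density_estimator(density_estimator: Union[str, Callable], theta, x):
    """Resolve the user's argument (string or callable) into a network."""
    if isinstance(density_estimator, str):
        try:
            builder = _BUILDERS[density_estimator]
        except KeyError:
            raise ValueError(f"Unknown density estimator: {density_estimator!r}") from None
    elif callable(density_estimator):
        builder = density_estimator
    else:
        raise TypeError("density_estimator must be a string or a callable")
    # In both cases the network is built by calling builder(theta, x).
    # Nothing here checks what a user-supplied callable actually returns.
    return builder(theta, x)

# Usage: a built-in name and a custom callable go through the same path.
theta = [[0.1, 0.2], [0.3, 0.4]]
x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
net = build_density_estimator("maf", theta, x)
custom = build_density_estimator(lambda t, xx: ToyEstimator("custom", len(t[0]), len(xx[0])), theta, x)
```

The last comment in the function is the gap the issue points at: a custom callable can return anything, and the mistake only surfaces later during training or sampling.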

Two problems here:

  1. We should make it more explicit to the user what to pass here, e.g., by defining a protocol DensityEstimatorBuilder that ensures the custom density estimator builder actually returns an object we can work with downstream. As a blueprint, see sbi/inference/trainers/npse/vector_field_inference.py in #1497 (Unify flow matching and score-based models).
  2. More generally, the naming is a bit confusing, as pointed out by @StarostinV: the argument is not really a density_estimator but rather a density_estimator_build_fn. However, renaming it would be a change to the central API.
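One possible shape for the protocol in point 1, sketched with the standard-library typing.Protocol. This is an illustration under stated assumptions, not the design in #1497: EstimatorLike, its log_prob/sample methods, build_and_check, and the toy classes are all hypothetical, and the parameter is named density_estimator_build_fn only to illustrate point 2.

```python
from typing import Any, Protocol, Sequence, runtime_checkable

@runtime_checkable
class EstimatorLike(Protocol):
    """Hypothetical minimal interface that downstream training code relies on."""
    def log_prob(self, theta: Any, x: Any) -> Any: ...
    def sample(self, num_samples: int, x: Any) -> Any: ...

class DensityEstimatorBuilder(Protocol):
    """Hypothetical protocol: a callable mapping training data to an estimator."""
    def __call__(self, theta: Sequence[Any], x: Sequence[Any]) -> EstimatorLike: ...

def build_and_check(density_estimator_build_fn: DensityEstimatorBuilder, theta, x) -> EstimatorLike:
    estimator = density_estimator_build_fn(theta, x)
    # DensityEstimatorBuilder only helps static type checkers; at runtime,
    # isinstance on a runtime_checkable protocol can only verify that the
    # required methods exist, not their signatures or return types.
    if not isinstance(estimator, EstimatorLike):
        raise TypeError(
            f"Builder returned {type(estimator).__name__}, which lacks log_prob/sample"
        )
    return estimator

# Usage: a toy estimator that satisfies the protocol.
class ToyEstimator:
    def log_prob(self, theta, x):
        return [0.0 for _ in theta]

    def sample(self, num_samples, x):
        return [[0.0] for _ in range(num_samples)]

def toy_builder(theta, x):
    return ToyEstimator()

estimator = build_and_check(toy_builder, [[0.0]], [[1.0]])
```

The design trade-off: the protocol documents the contract for users and type checkers, while the runtime check fails early with a clear message if a builder returns something unusable (for example, a bare module without the expected methods).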

Metadata

Assignees

No one assigned

Labels

API changes: This impacts the public API of the project (e.g. inference class).

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests