🐛 Bug Description
`ConditionalDensityEstimator` and `RatioEstimator` on sbi `main` use different names, orders, and shape expectations. Also, `RatioEstimator`'s documented shape expectations are incompatible with what it actually expects.

Details
Details
While `ConditionalDensityEstimator` (CDE) and `RatioEstimator` (RE) do not share a common parent type, ideally they would still be as consistent as possible. I assume here that `x` in RE and `input` in CDE are roughly the same, and that `theta` and `condition` are roughly the same, respectively.
Inconsistent attributes

CDE uses the attributes `(input_shape, condition_shape)`:

sbi/sbi/neural_nets/estimators/base.py
Line 135 in a1d7555

RE uses the attributes `(theta_shape, x_shape)`:

sbi/sbi/neural_nets/ratio_estimators.py
Lines 29 to 30 in a1d7555

Their order in the constructors is reversed.
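To make the naming mismatch concrete, here is a minimal sketch with stand-in classes. These are NOT the actual sbi signatures, only illustrative placeholders showing how roughly equivalent quantities end up under different attribute names and in a different constructor order:

```python
# Illustrative stand-ins only, not sbi's real classes.
class ConditionalDensityEstimatorSketch:
    # "data-like" argument first, "parameter-like" argument second
    def __init__(self, input_shape, condition_shape):
        self.input_shape = input_shape
        self.condition_shape = condition_shape

class RatioEstimatorSketch:
    # "parameter-like" argument first, "data-like" argument second
    def __init__(self, theta_shape, x_shape):
        self.theta_shape = theta_shape
        self.x_shape = x_shape

cde = ConditionalDensityEstimatorSketch(input_shape=(2,), condition_shape=(3,))
re_ = RatioEstimatorSketch(theta_shape=(3,), x_shape=(2,))

# Under the rough correspondence x ~ input and theta ~ condition,
# the same quantities live under differently named attributes:
assert cde.input_shape == re_.x_shape
assert cde.condition_shape == re_.theta_shape
```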
Inconsistent shapes

CDE documents that the shape of `input` is `(sample_dim, batch_dim, *input_shape)` and the shape of `condition` is `(batch_dim, *condition_shape)`. While it doesn't enforce this, at least some children do, e.g. `MixedDensityEstimator`:

RE documents that the shape of `x` is `(batch_dim, *x_shape)` and that the shape of `theta` is `(sample_dim, batch_dim, *theta_shape)`. Note that the two classes differ in which of the arguments has a `sample_dim`. However, RE actually enforces that `x` is `(*batch_shape, *x_shape)` and `theta` is `(*batch_shape, *theta_shape)`, i.e. the two arguments share the same prefix, which is incompatible with the documented shapes:

sbi/sbi/neural_nets/ratio_estimators.py
Lines 126 to 128 in a1d7555
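The incompatibility can be demonstrated with plain shape tuples. The helper below is a hypothetical re-statement of the prefix check that RE's enforcement implies (names and logic are illustrative, not copied from sbi): the documented shapes fail it, while same-prefix shapes pass.

```python
# Hypothetical check mirroring what RE enforces: after stripping the
# trailing event dims, theta and x must share the same (*batch_shape) prefix.
def shares_batch_prefix(theta_full, x_full, theta_event, x_event):
    theta_batch = theta_full[: len(theta_full) - len(theta_event)]
    x_batch = x_full[: len(x_full) - len(x_event)]
    return theta_batch == x_batch

theta_event, x_event = (3,), (2,)  # illustrative event shapes

# Shapes as documented: theta has a leading sample_dim, x does not.
documented_theta = (10, 5, 3)  # (sample_dim, batch_dim, *theta_shape)
documented_x = (5, 2)          # (batch_dim, *x_shape)
assert not shares_batch_prefix(documented_theta, documented_x, theta_event, x_event)

# Shapes as actually enforced: both arguments share the same batch prefix.
enforced_theta = (10, 5, 3)    # (*batch_shape, *theta_shape)
enforced_x = (10, 5, 2)        # (*batch_shape, *x_shape)
assert shares_batch_prefix(enforced_theta, enforced_x, theta_event, x_event)
```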
Inconsistent argument order in methods

While `CDE.log_prob` and `RE.unnormalized_log_ratio` are not equivalent, one would expect their order of arguments to be similar. However, the former takes the order `(input, condition)`:

sbi/sbi/neural_nets/estimators/base.py
Line 155 in a1d7555

while the latter takes `(theta, x)`:

sbi/sbi/neural_nets/ratio_estimators.py
Line 156 in a1d7555
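A side-by-side sketch of the two call conventions (hypothetical stand-ins with the documented parameter names, not sbi's actual methods) makes the flipped order visible:

```python
import inspect

# Stand-ins mirroring the documented signatures, not sbi's real code.
def log_prob(input, condition):
    """CDE convention: data-like argument first."""

def unnormalized_log_ratio(theta, x):
    """RE convention: parameter-like argument first."""

# Under the correspondence input ~ x and condition ~ theta, the data-like
# argument comes first in one method and second in the other:
assert list(inspect.signature(log_prob).parameters) == ["input", "condition"]
assert list(inspect.signature(unnormalized_log_ratio).parameters) == ["theta", "x"]
```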
📌 Additional Context

Torch distributions (and Pyro) implement both `sample` and `log_prob`, supporting arbitrary `batch_shape` and `sample_shape` (not just a single dimension). While not necessary, it would be nice if these methods supported the same shape conventions. This would in particular simplify #1491.
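For reference, this is the torch convention: shapes compose as `sample_shape + batch_shape + event_shape`, with both `sample_shape` and `batch_shape` allowed to have any number of dimensions. A small example with a batched `Normal` (a univariate distribution, so `event_shape` is empty):

```python
import torch
from torch.distributions import Normal

# batch_shape is (4, 5); Normal is univariate, so event_shape is ().
dist = Normal(loc=torch.zeros(4, 5), scale=torch.ones(4, 5))

# An arbitrary multi-dimensional sample_shape of (2, 3).
samples = dist.sample((2, 3))
assert samples.shape == (2, 3, 4, 5)    # sample_shape + batch_shape

# log_prob broadcasts over the sample dims and returns one value
# per sample per batch element.
log_probs = dist.log_prob(samples)
assert log_probs.shape == (2, 3, 4, 5)
```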