Inconsistency in documented/expected shapes for estimators #1500

@sethaxen


🐛 Bug Description

ConditionalDensityEstimator and RatioEstimator on sbi main use different attribute names, argument orders, and shape expectations. In addition, RatioEstimator's documented shape expectations are incompatible with the shapes it actually enforces.

Details

While ConditionalDensityEstimator (CDE) and RatioEstimator (RE) do not share a common parent type, ideally they would still be as consistent as possible. I assume here that x in RE roughly corresponds to input in CDE, and that theta in RE roughly corresponds to condition in CDE.

Inconsistent attributes

CDE uses attributes (input_shape, condition_shape):

self, net: nn.Module, input_shape: torch.Size, condition_shape: torch.Size

RE uses attributes (theta_shape, x_shape):

theta_shape: torch.Size | tuple[int, ...],
x_shape: torch.Size | tuple[int, ...],

Their order in the two constructors is reversed.

Inconsistent shapes

CDE documents that the shape of input is (sample_dim, batch_dim, *input_shape) and that the shape of condition is (batch_dim, *condition_shape). While the base class doesn't enforce this, at least some subclasses do, e.g. MixedDensityEstimator:

batch_dim = condition.shape[0]

input_sample_dim, input_batch_dim = input.shape[:2]
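To make the CDE convention concrete, here is a minimal sketch with plain tensors built to the documented shapes (the shape values and variable names are illustrative; no sbi classes are used):

```python
import torch

# Illustrative per-event shapes (made up for this example)
input_shape = torch.Size([2])
condition_shape = torch.Size([3])

sample_dim, batch_dim = 10, 4
inputs = torch.zeros(sample_dim, batch_dim, *input_shape)  # documented: (sample_dim, batch_dim, *input_shape)
condition = torch.zeros(batch_dim, *condition_shape)       # documented: (batch_dim, *condition_shape)

# The extractions quoted above from MixedDensityEstimator then line up:
batch_dim_from_condition = condition.shape[0]          # 4
input_sample_dim, input_batch_dim = inputs.shape[:2]   # 10, 4
print(batch_dim_from_condition, input_sample_dim, input_batch_dim)
```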

RE documents that the shape of x is (batch_dim, *x_shape) and that the shape of theta is (sample_dim, batch_dim, *theta_shape). Note that the two classes differ in which argument carries the sample_dim. However, RE actually enforces that x is (*batch_shape, *x_shape) and theta is (*batch_shape, *theta_shape), i.e. that the two arguments share the same batch prefix, which is incompatible with the documented shapes:

theta_prefix = theta.shape[: -len(self.theta_shape)]
x_prefix = x.shape[: -len(self.x_shape)]
if theta_prefix != x_prefix:
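To see the incompatibility concretely, the following sketch reproduces the quoted prefix check on plain tensors built to the documented shapes (shape values are illustrative; no sbi classes are used); the check fails:

```python
import torch

theta_shape = torch.Size([3])  # illustrative per-event shape of theta
x_shape = torch.Size([5])      # illustrative per-event shape of x

sample_dim, batch_dim = 10, 4
theta = torch.zeros(sample_dim, batch_dim, *theta_shape)  # documented: (sample_dim, batch_dim, *theta_shape)
x = torch.zeros(batch_dim, *x_shape)                      # documented: (batch_dim, *x_shape)

# Same prefix comparison as in RatioEstimator:
theta_prefix = theta.shape[: -len(theta_shape)]  # torch.Size([10, 4])
x_prefix = x.shape[: -len(x_shape)]              # torch.Size([4])
print(theta_prefix != x_prefix)  # True -> the documented shapes are rejected by the actual check
```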

Inconsistent argument order in methods

While CDE.log_prob and RE.unnormalized_log_ratio are not equivalent, one would expect their argument order to be analogous. However, the former takes (input, condition):

def log_prob(self, input: Tensor, condition: Tensor, **kwargs) -> Tensor:

while the latter takes (theta, x):

def unnormalized_log_ratio(self, theta: Tensor, x: Tensor) -> Tensor:

📌 Additional Context

Torch distributions (and Pyro) implement both sample and log_prob, supporting arbitrary batch_shape and sample_shape (not just a single dimension). While not necessary, it would be nice if these methods supported the same shape conventions. This would in particular simplify #1491.
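For reference, a minimal torch.distributions example of that convention, where batch_shape comes from the distribution's parameters, sample_shape is requested at sampling time, and both may have any number of dimensions:

```python
import torch
from torch.distributions import Normal

dist = Normal(loc=torch.zeros(4, 3), scale=torch.ones(4, 3))  # batch_shape = (4, 3)
samples = dist.sample(sample_shape=torch.Size([7, 2]))        # shape: (7, 2, 4, 3)
log_probs = dist.log_prob(samples)                            # shape: (7, 2, 4, 3)
print(samples.shape, log_probs.shape)
```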
