VI interface is too low-level #2783

@penelopeysm

Turing in general has been moving away from making people declare vectorised inputs and getting vectorised outputs. While recognising that vectorised formats can sometimes be useful, they are also unfortunately quite unhelpful because they lose meaning: they rely on internal DynamicPPL details of how variables are organised, they don't convey where values are linked or unlinked, etc.

It is almost always more useful to store samples as VarNamedTuples of raw values, and leave the vectorised samples for real experts who really need to know every detail.

In particular, MCMC sampling and optimisation now refuse to take vectors as initial parameters. Likewise, optimisation does not return a vectorised optimum but rather a VarNamedTuple; if people want to get a vector they have to call a separate function vector_names_and_params(::ModeResult).

Since the MCMC and optimisation interfaces have been updated, the remaining one that still needs some coaxing is VI. In particular, we have that

q, info, state = vi(...)

and q is right now a Bijectors.TransformedDistribution, from which you sample a vector of values. In light of the above changes, I think this is no longer appropriate. q should instead be a wrapper around the transformed distribution plus a LogDensityFunction: internally it samples a vector of values, then feeds that vector into the LDF to get back a VarNamedTuple of values. Calling rand(q) should give you a VarNamedTuple, or possibly a ParamsWithStats; I'm not entirely sure right now.

(Note that the requisite LDF is actually stored as state.prob, so all the info we need is already there.)
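
To make the shape of the proposal concrete, here is a rough, dependency-free sketch of such a wrapper. Everything in it is hypothetical: NamedVariationalSketch is not an existing Turing.jl type, sample_unconstrained stands in for the TransformedDistribution, and to_named stands in for evaluating the LDF (it returns a plain NamedTuple rather than a VarNamedTuple). The toy parameters mu and sigma are also made up for illustration.

```julia
using Random

# Hypothetical wrapper, NOT part of Turing.jl. It pairs a sampler over
# unconstrained vectors with a function that maps a vector to named values.
struct NamedVariationalSketch{S,F}
    sample_unconstrained::S  # rng -> Vector{Float64}; stand-in for the transformed distribution
    to_named::F              # Vector -> NamedTuple; stand-in for running the LDF
end

# rand on the wrapper returns named, constrained values instead of a raw vector.
function Base.rand(rng::AbstractRNG, q::NamedVariationalSketch)
    v = q.sample_unconstrained(rng)
    return q.to_named(v)
end

# Toy setup: mu is unconstrained, sigma must be positive (stored as log(sigma)).
q = NamedVariationalSketch(
    rng -> randn(rng, 2),
    v -> (mu = v[1], sigma = exp(v[2])),
)

draw = rand(MersenneTwister(1), q)
println(draw)
```

The point of the design is that callers never see the vector layout: they only see named values in constrained space.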

This would solve the issue in the docs where "extracting VarNames is a bit verbose at the moment" (https://turinglang.org/docs/tutorials/variational-inference/#obtaining-summary-statistics) (in fact, it's not only verbose -- it is also straight up wrong, for example, it will run into the same LKJCholesky bug that optimisation used to face #2734).

It would also solve a current thorny bit where VI expects that the bijector from constrained to unconstrained space is a constant (and can be obtained via bijector(model)). This works for many models, but is not really correct for models such as

@model function f()
    x ~ Normal()
    y ~ truncated(Normal(); lower=x)
end

because the bijector for y depends on the value of x.

By always transforming to unconstrained space, running VI on the unconstrained LDF, and then using the wrapper to convert back to constrained space, we ensure that VI behaves correctly even in cases like this.
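
The failure mode can be illustrated without any packages. The sketch below assumes a hand-written log-shift transform for a lower-bounded variable (similar in spirit to what a bijector for truncated(Normal(); lower=x) would do, but not the actual Bijectors.jl implementation). The names to_unconstrained, to_constrained, frozen_inverse and x_ref are all made up for this example.

```julia
# Assumed transform for y with lower bound x: unconstrained u = log(y - x).
to_unconstrained(y, x) = log(y - x)
to_constrained(u, x) = x + exp(u)

# "Constant bijector" approach: the lower bound is frozen at one reference value of x.
x_ref = 0.0
frozen_inverse(u) = to_constrained(u, x_ref)

# One draw in unconstrained space: x is already unconstrained, u_y is the transformed y.
x_new, u_y = 1.5, -2.0

y_frozen = frozen_inverse(u_y)          # uses the stale bound x_ref
y_correct = to_constrained(u_y, x_new)  # uses the bound implied by this draw of x

println("frozen bijector respects y > x:    ", y_frozen > x_new)
println("per-draw transform respects y > x: ", y_correct > x_new)
```

With the frozen bound, the resulting y can fall below x, which is outside the support of the model. Recomputing the transform for each draw, which is effectively what going through the LDF would do, avoids this.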

cc. @Red-Portal
