Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "low-rank" variational families #76

Merged
merged 46 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
03563ea
rename location scale source file
Red-Portal Aug 3, 2024
5ab7286
revert renaming of location_scale file
Red-Portal Aug 3, 2024
3e0bf3d
add location-low-rank-scale family (except `entropy` and `logpdf`)
Red-Portal Aug 3, 2024
0bd6e5c
add feature complete `MvLocationScaleLowRank` with tests
Red-Portal Aug 5, 2024
34546e1
fix remove misleading comment
Red-Portal Aug 5, 2024
e030f2d
fix add missing test files
Red-Portal Aug 5, 2024
c7f36d6
fix broadcasting error on Julia 1.6
Red-Portal Aug 5, 2024
1bb3e3e
fix bug in sampling from `LocationScaleLowRank`
Red-Portal Aug 7, 2024
ddd2122
fix missing squared bug in `LocationScaleLowRank`
Red-Portal Aug 7, 2024
b24737f
add documentation for low-rank families
Red-Portal Aug 9, 2024
1d56953
add convenience constructors for `LocationScaleLowRank`
Red-Portal Aug 9, 2024
6752c6b
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Aug 10, 2024
52568b5
fix mhauru's suggestions and run formatter
Red-Portal Aug 10, 2024
96eae86
run formatter
Red-Portal Aug 10, 2024
15556da
run formatter
Red-Portal Aug 10, 2024
f796154
fix bugs and improve comments in `MvLocationScale` and lowrank
Red-Portal Aug 11, 2024
6b1699c
promote families.md into a higher category
Red-Portal Aug 11, 2024
5187d76
add test for `MVLocationScale` with non-Gaussian
Red-Portal Aug 14, 2024
8821908
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Aug 27, 2024
6dfc919
tighten compat bound for `Distributions`
Red-Portal Aug 27, 2024
c3ce393
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Sep 4, 2024
5c04d50
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Sep 5, 2024
ba293e5
fix base distribution standardization bug in `LocationScale`
Red-Portal Sep 5, 2024
426d943
fix base distribution standardization bug in `LocationScaleLowRank`
Red-Portal Sep 5, 2024
3cc9e80
format weird indentation in test `for` loops
Red-Portal Sep 5, 2024
0481dda
update docs add example for `LocationScaleLowRank`
Red-Portal Sep 5, 2024
8449402
fix docs warn about divergence when using `MvLocationScaleLowRank`
Red-Portal Sep 6, 2024
ff14c4c
Merge branch 'master' of github.com:TuringLang/AdvancedVI.jl into low…
Red-Portal Sep 9, 2024
e48f231
Merge branch 'master' into lowrank
yebai Sep 10, 2024
aa8feee
Merge branch 'master' into lowrank
yebai Sep 10, 2024
5149869
Merge branch 'master' into lowrank
yebai Sep 10, 2024
e196da6
Update Benchmark.yml
yebai Sep 10, 2024
e4bff67
disable more features for PRs from forks
yebai Sep 10, 2024
894a849
fix `LocationScale` interfaces to only allow univariate base dist
Red-Portal Sep 11, 2024
f1cabba
Merge branch 'lowrank' of github.com:Red-Portal/AdvancedVI.jl into lo…
Red-Portal Sep 11, 2024
ce6793c
fix test comparison operator for families
Red-Portal Sep 11, 2024
71aeb5a
fix test comparison operator for families
Red-Portal Sep 11, 2024
77ace2b
fix test comparison operator for families
Red-Portal Sep 11, 2024
641de39
fix test comparison operator for families
Red-Portal Sep 11, 2024
a58f209
fix test comparison operator for families
Red-Portal Sep 11, 2024
846b259
fix test comparison operator for families
Red-Portal Sep 11, 2024
1116f68
fix test comparison operator for families
Red-Portal Sep 11, 2024
42d730d
fix formatting
Red-Portal Sep 11, 2024
99d08c5
fix formatting
Red-Portal Sep 11, 2024
4a90c5d
fix scale lower bound to `1e-4`
Red-Portal Sep 12, 2024
c41709b
fix docstring for `LowRankGaussian`
Red-Portal Sep 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/Benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ concurrency:
permissions:
contents: write
pull-requests: write
issues: write

jobs:
benchmark:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Accessors = "0.1"
Bijectors = "0.13"
ChainRulesCore = "1.16"
DiffResults = "1"
Distributions = "0.25.87"
Distributions = "0.25.111"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.12.32"
FillArrays = "1.3"
Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ makedocs(;
"ELBO Maximization" => [
"Overview" => "elbo/overview.md",
"Reparameterization Gradient Estimator" => "elbo/repgradelbo.md",
"Location-Scale Variational Family" => "locscale.md",
],
"Variational Families" => "families.md",
"Optimization" => "optimization.md",
],
)
Expand Down
267 changes: 267 additions & 0 deletions docs/src/families.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
# [Reparameterizable Variational Families](@id families)

The [RepGradELBO](@ref repgradelbo) objective assumes that the members of the variational family have a differentiable sampling path.
We provide multiple pre-packaged variational families that can be readily used.

## [The `LocationScale` Family](@id locscale)

The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as

```math
z \sim q_{\lambda} \qquad\Leftrightarrow\qquad
z \stackrel{d}{=} C u + m;\quad u \sim \varphi
```

where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*.
``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``.
The location-scale family encompases many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``.

The probability density is given by

```math
q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m)),
```

the covariance is given as

```math
\mathrm{Var}\left(q_{\lambda}\right) = C \mathrm{Var}(q_{\lambda}) C^{\top}
```

and the entropy is given as

```math
\mathbb{H}(q_{\lambda}) = \mathbb{H}(\varphi) + \log |C|,
```

where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution.
Notice the ``\mathbb{H}(\varphi)`` does not depend on ``\log |C|``.
The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution.

### API

!!! note

For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned.
Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities.

```@docs
MvLocationScale
```

The following are specialized constructors for convenience:

```@docs
FullRankGaussian
MeanFieldGaussian
```

### Gaussian Variational Families

```julia
using AdvancedVI, LinearAlgebra, Distributions;
μ = zeros(2);

L = LowerTriangular(diagm(ones(2)));
q = FullRankGaussian(μ, L)

L = Diagonal(ones(2));
q = MeanFieldGaussian(μ, L)
```

### Student-$$t$$ Variational Families

```julia
using AdvancedVI, LinearAlgebra, Distributions;
μ = zeros(2);
ν = 3;

# Full-Rank
L = LowerTriangular(diagm(ones(2)));
q = MvLocationScale(μ, L, TDist(ν))

# Mean-Field
L = Diagonal(ones(2));
q = MvLocationScale(μ, L, TDist(ν))
```

### Laplace Variational families

```julia
using AdvancedVI, LinearAlgebra, Distributions;
μ = zeros(2);

# Full-Rank
L = LowerTriangular(diagm(ones(2)));
q = MvLocationScale(μ, L, Laplace())

# Mean-Field
L = Diagonal(ones(2));
q = MvLocationScale(μ, L, Laplace())
```

## The `LocationScaleLowRank` Family

In practice, `LocationScale` families with full-rank scale matrices are known to converge slowly as they require a small SGD stepsize.
Low-rank variational families can be an effective alternative[^ONS2018].
`LocationScaleLowRank` generally represent any ``d``-dimensional distribution which its sampling path can be represented as

```math
z \sim q_{\lambda} \qquad\Leftrightarrow\qquad
z \stackrel{d}{=} D u_1 + U u_2 + m;\quad u_1, u_2 \sim \varphi
```

where ``D \in \mathbb{R}^{d \times d}`` is a diagonal matrix, ``U \in \mathbb{R}^{d \times r}`` is a dense low-rank matrix for the rank ``r > 0``, ``m \in \mathbb{R}^d`` is the location, and ``\varphi`` is the *base distribution*.
``m``, ``D``, and ``U`` form the variational parameters ``\lambda = (m, D, U)``.

The covariance of this distribution is given as

```math
\mathrm{Var}\left(q_{\lambda}\right) = D \mathrm{Var}(\varphi) D + U \mathrm{Var}(\varphi) U^{\top}
```

and the entropy is given by the matrix determinant lemma as

```math
\mathbb{H}(q_{\lambda})
= \mathbb{H}(\varphi) + \log |\Sigma|
= \mathbb{H}(\varphi) + 2 \log |D| + \log |I + U^{\top} D^{-2} U|,
```

where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution.

```@setup lowrank
using ADTypes
using AdvancedVI
using Distributions
using LinearAlgebra
using LogDensityProblems
using Optimisers
using Plots
using ReverseDiff

struct Target{D}
dist::D
end

function LogDensityProblems.logdensity(model::Target, θ)
logpdf(model.dist, θ)
end

function LogDensityProblems.dimension(model::Target)
return length(model.dist)
end

function LogDensityProblems.capabilities(::Type{<:Target})
return LogDensityProblems.LogDensityOrder{0}()
end

n_dims = 30
U_true = randn(n_dims, 3)
D_true = Diagonal(log.(1 .+ exp.(randn(n_dims))))
Σ_true = D_true + U_true*U_true'
Σsqrt_true = sqrt(Σ_true)
μ_true = randn(n_dims)
model = Target(MvNormal(μ_true, Σ_true));

d = LogDensityProblems.dimension(model);
μ = zeros(d);

L = Diagonal(ones(d));
q0_mf = MeanFieldGaussian(μ, L)

L = LowerTriangular(diagm(ones(d)));
q0_fr = FullRankGaussian(μ, L)

D = ones(n_dims)
U = zeros(n_dims, 3)
q0_lr = LowRankGaussian(μ, D, U)

obj = RepGradELBO(1);

max_iter = 10^4

function callback(; params, averaged_params, restructure, stat, kwargs...)
q = restructure(averaged_params)
μ, Σ = mean(q), cov(q)
(dist2 = sum(abs2, μ - μ_true) + tr(Σ + Σ_true - 2*sqrt(Σsqrt_true*Σ*Σsqrt_true)),)
end

_, _, stats_fr, _ = AdvancedVI.optimize(
model,
obj,
q0_fr,
max_iter;
show_progress = false,
adtype = AutoReverseDiff(),
optimizer = Adam(0.01),
averager = PolynomialAveraging(),
callback = callback,
);

_, _, stats_mf, _ = AdvancedVI.optimize(
model,
obj,
q0_mf,
max_iter;
show_progress = false,
adtype = AutoReverseDiff(),
optimizer = Adam(0.01),
averager = PolynomialAveraging(),
callback = callback,
);

_, _, stats_lr, _ = AdvancedVI.optimize(
model,
obj,
q0_lr,
max_iter;
show_progress = false,
adtype = AutoReverseDiff(),
optimizer = Adam(0.01),
averager = PolynomialAveraging(),
callback = callback,
);

t = [stat.iteration for stat in stats_fr]
dist_fr = [sqrt(stat.dist2) for stat in stats_fr]
dist_mf = [sqrt(stat.dist2) for stat in stats_mf]
dist_lr = [sqrt(stat.dist2) for stat in stats_lr]
plot( t, dist_mf , label="Mean-Field Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance")
plot!(t, dist_fr, label="Full-Rank Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance")
plot!(t, dist_lr, label="Low-Rank Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance")
savefig("lowrank_family_wasserstein.svg")
nothing
```

Consider a 30-dimensional Gaussian with a diagonal plus low-rank covariance structure, where the true rank is 3.
Then, we can compare the convergence speed of `LowRankGaussian` versus `FullRankGaussian`:

![](lowrank_family_wasserstein.svg)

As we can see, `LowRankGaussian` converges faster than `FullRankGaussian`.
While `FullRankGaussian` can converge to the true solution since it is a more expressive variational family, `LowRankGaussian` gets there faster.

!!! info
`MvLocationScaleLowRank` tend to work better with the `Optimisers.Adam` optimizer due to non-smoothness.
Other optimisers may experience divergences.


### API

```@docs
MvLocationScaleLowRank
```

The `logpdf` of `MvLocationScaleLowRank` has an optional argument `non_differentiable::Bool` (default: `false`).
If set as `true`, a more efficient ``O\left(r d^2\right)`` implementation is used to evaluate the density.
This, however, is not differentiable under most AD frameworks due to the use of Cholesky `lowrankupdate`.
The default value is `false`, which uses a ``O\left(d^3\right)`` implementation, is differentiable and therefore compatible with the `StickingTheLandingEntropy` estimator.

The following is a specialized constructor for convenience:

```@docs
LowRankGaussian
```

[^ONS2018]: Ong, V. M. H., Nott, D. J., & Smith, M. S. (2018). Gaussian variational approximation with a factor covariance structure. Journal of Computational and Graphical Statistics, 27(3), 465-478.
80 changes: 0 additions & 80 deletions docs/src/locscale.md

This file was deleted.

4 changes: 4 additions & 0 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ export MvLocationScale, MeanFieldGaussian, FullRankGaussian

include("families/location_scale.jl")

export MvLocationScaleLowRank, LowRankGaussian

include("families/location_scale_low_rank.jl")

# Optimization Rules

include("optimization/rules.jl")
Expand Down
Loading
Loading