Add "low-rank" variational families #192
Annotations
4 warnings
Run julia-actions/julia-docdeploy@v1:
../../../.julia/packages/Documenter/bYYzK/src/Utilities/Utilities.jl#L34
failed to run `@setup` block in src/families.md
```@setup lowrank
using ADTypes
using AdvancedVI
using Distributions
using LinearAlgebra
using LogDensityProblems
using Optimisers
using Plots
using ReverseDiff
"""
    Target{D}

Thin wrapper holding a distribution object in the field `dist`, so that the
`LogDensityProblems` interface methods can be attached to it.
"""
struct Target{D}
    dist::D
end
# Log-density of the wrapped distribution evaluated at the point `θ`.
LogDensityProblems.logdensity(model::Target, θ) = logpdf(model.dist, θ)
# The problem dimension is the length of the wrapped distribution.
LogDensityProblems.dimension(model::Target) = length(model.dist)
# Declare order-0 capability for every `Target` type, i.e. only log-density
# evaluation is advertised through the `LogDensityProblems` interface.
LogDensityProblems.capabilities(::Type{<:Target}) = LogDensityProblems.LogDensityOrder{0}()
# --- Synthetic Gaussian target -------------------------------------------
# Dimensionality of the target distribution.
n_dims = 30
# Target covariance = positive diagonal + U_true*U_true', where U_true has
# 3 columns (so the second term has rank at most 3).
U_true = randn(n_dims, 3)
# log(1 + exp(x)) (softplus) turns the random draws into positive diagonal entries.
D_true = Diagonal(log.(1 .+ exp.(randn(n_dims))))
Σ_true = D_true + U_true*U_true'
# Matrix square root of the target covariance, computed once here and read
# by `callback` on every iteration.
Σsqrt_true = sqrt(Σ_true)
μ_true = randn(n_dims)
model = Target(MvNormal(μ_true, Σ_true));
d = LogDensityProblems.dimension(model);
# --- Initial variational approximations ----------------------------------
# All three families start from a zero vector as their first argument.
μ = zeros(d);
# Mean-field family: identity diagonal matrix.
L = Diagonal(ones(d));
q0_mf = MeanFieldGaussian(μ, L)
# Full-rank family: `L` is rebound to an identity lower-triangular matrix.
L = LowerTriangular(diagm(ones(d)));
q0_fr = FullRankGaussian(μ, L)
# Low-rank family: all-ones vector `D` and an all-zero n_dims×3 matrix `U`.
D = ones(n_dims)
U = zeros(n_dims, 3)
q0_lr = LowRankGaussian(μ, D, U)
# Objective and iteration budget shared by the optimisation runs below.
# NOTE(review): the argument `1` is presumably the number of Monte Carlo
# samples per gradient estimate — confirm against the AdvancedVI docs.
obj = RepGradELBO(1);
max_iter = 10^4
"""
    callback(; params, averaged_params, restructure, stat, kwargs...)

Diagnostic hook passed to `AdvancedVI.optimize`. Rebuilds the variational
approximation from `averaged_params` via `restructure` and returns a named
tuple `(dist2 = ...,)` holding the squared Wasserstein-2 distance between that
approximation and the Gaussian target described by the globals `μ_true`,
`Σ_true` and `Σsqrt_true`:

    dist2 = ‖μ - μ_true‖² + tr(Σ + Σ_true - 2 (Σ_true^{1/2} Σ Σ_true^{1/2})^{1/2})

`params`, `stat` and any extra keyword arguments are accepted but not used.
"""
function callback(; params, averaged_params, restructure, stat, kwargs...)
    q = restructure(averaged_params)
    μ, Σ = mean(q), cov(q)
    # The sandwich product is symmetric positive semi-definite in exact
    # arithmetic, but floating-point round-off can leave it slightly
    # asymmetric. Wrapping it in `Symmetric` forces `sqrt` onto the
    # eigendecomposition path for real-symmetric matrices, which tolerates
    # round-off-sized negative eigenvalues and returns a real result.
    cross = sqrt(Symmetric(Σsqrt_true * Σ * Σsqrt_true))
    dist2 = sum(abs2, μ - μ_true) + tr(Σ + Σ_true - 2 * cross)
    return (dist2=dist2,)
end
# Fit the full-rank, mean-field and low-rank initialisations, in that order,
# with identical optimiser settings. Only the third return value of
# `AdvancedVI.optimize` (the per-iteration statistics) is kept for each run.
# The optimiser, averager and AD backend objects are constructed afresh for
# every run, exactly as in three separate calls.
stats_fr, stats_mf, stats_lr = map((q0_fr, q0_mf, q0_lr)) do q0
    _, _, stats, _ = AdvancedVI.optimize(
        model,
        obj,
        q0,
        max_iter;
        show_progress=false,
        adtype=AutoReverseDiff(),
        optimizer=Adam(0.01),
        averager=PolynomialAveraging(),
        callback=callback,
    )
    stats
end;
# Iteration indices, read from the full-rank run and reused as the x-axis
# for all three curves.
t = map(stat -> stat.iteration, stats_fr)

# `callback` stores squared distances; take square roots before plotting.
dist_fr = map(stat -> sqrt(stat.dist2), stats_fr)
dist_mf = map(stat -> sqrt(stat.dist2), stats_mf)
dist_lr = map(stat -> sqrt(stat.dist2), stats_lr)

# Draw the mean-field curve first, then overlay the other two families.
plot(
    t,
    dist_mf;
    label="Mean-Field Gaussian",
    xlabel="Iteration",
    ylabel="Wasserstein-2 Distance",
)
plot!(
    t,
    dist_fr;
    label="Full-Rank Gaussian",
    xlabel="Iteration",
    ylabel="Wasserstein-2 Distance",
)
plot!(
    t,
    dist_lr;
    label="Low-Rank Gaussian",
    xlabel="Iteration",
    ylabel="Wasserstein-2 Distance",
)
savefig("lowrank_family_wasserstein.svg")
# End the block with `nothing` so it evaluates to no value.
nothing
```
exception =
LoadError: ArgumentError: Package ReverseDiff not found in current path.
- Run `import Pkg; Pkg.add("ReverseDiff")` to install the ReverseDiff package.
in expression starting at string:8
|
Run julia-actions/julia-docdeploy@v1:
../../../.julia/packages/Documenter/bYYzK/src/Utilities/Utilities.jl#L34
7 docstrings not included in the manual:
AdvancedVI.value_and_gradient!
AdvancedVI.ClosedFormEntropy
AdvancedVI.value :: Tuple{AdvancedVI.AbstractAverager, Any}
AdvancedVI.estimate_entropy
AdvancedVI.restructure_ad_forward :: Tuple{ADTypes.AbstractADType, Any, Any}
AdvancedVI.apply :: Tuple{AdvancedVI.AbstractAverager, Any, Any}
AdvancedVI.update_variational_params!
These are docstrings in the checked modules (configured with the modules keyword)
that are not included in @docs or @autodocs blocks.
|
Run julia-actions/julia-docdeploy@v1:
../../../.julia/packages/Documenter/bYYzK/src/Writers/HTMLWriter.jl#L2103
invalid local image: unresolved path in families.md
link = "lowrank_family_wasserstein.svg"
|
The following actions use a deprecated Node.js version and will be forced to run on node20: julia-actions/setup-julia@v1. For more info: https://github.blog/changelog/2024-03-07-github-actions-all-actions-will-run-on-node20-instead-of-node16-by-default/
|
Loading