From 57c9e58d09e121e248794aa0e53d373274362728 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 14 Sep 2024 01:32:16 +0900 Subject: [PATCH] Add "low-rank" variational families (#76) * add feature complete `MvLocationScaleLowRank` with tests * fix bugs and improve comments in `MvLocationScale` and lowrank * promote families.md into a higher category * add test for `MVLocationScale` with non-Gaussian * tighten compat bound for `Distributions` * fix base distribution standardization bug in `LocationScale` and `LocationScaleLowRank` * fix `LocationScale` interfaces to only allow univariate base dist * fix test comparison operator for families * fix scale lower bound to `1e-4` --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: Markus Hauru --- .github/workflows/Benchmark.yml | 5 +- Project.toml | 2 +- docs/make.jl | 2 +- docs/src/families.md | 267 ++++++++++++++++++ docs/src/locscale.md | 80 ------ src/AdvancedVI.jl | 4 + src/families/location_scale.jl | 70 +++-- src/families/location_scale_low_rank.jl | 176 ++++++++++++ test/Project.toml | 2 +- .../{interface => families}/location_scale.jl | 123 ++++---- test/families/location_scale_low_rank.jl | 178 ++++++++++++ test/runtests.jl | 6 +- 12 files changed, 745 insertions(+), 170 deletions(-) create mode 100644 docs/src/families.md delete mode 100644 docs/src/locscale.md create mode 100644 src/families/location_scale_low_rank.jl rename test/{interface => families}/location_scale.jl (60%) create mode 100644 test/families/location_scale_low_rank.jl diff --git a/.github/workflows/Benchmark.yml b/.github/workflows/Benchmark.yml index 27f4091b..b161782d 100644 --- a/.github/workflows/Benchmark.yml +++ b/.github/workflows/Benchmark.yml @@ -13,6 +13,7 @@ concurrency: permissions: contents: write pull-requests: write + issues: write jobs: benchmark: @@ -47,10 +48,10 @@ jobs: name: Benchmark Results tool: 'julia' output-file-path: bench/benchmark_results.json - summary-always: true + summary-always: ${{ !github.event.pull_request.head.repo.fork }} # Disable summary for PRs from forks github-token: ${{ secrets.GITHUB_TOKEN }} - comment-always: true alert-threshold: "200%" fail-on-alert: true benchmark-data-dir-path: benchmarks + comment-always: ${{ !github.event.pull_request.head.repo.fork }} # Disable comments for PRs from forks auto-push: ${{ !github.event.pull_request.head.repo.fork }} # Disable push for PRs from forks diff --git a/Project.toml b/Project.toml index 62f7a550..3219992c 100644 --- a/Project.toml +++ b/Project.toml @@ -42,7 +42,7 @@ Accessors = "0.1" Bijectors = "0.13" ChainRulesCore = "1.16" DiffResults = "1" -Distributions = "0.25.87" +Distributions = "0.25.111" DocStringExtensions = "0.8, 0.9" Enzyme = "0.12.32" FillArrays = "1.3" diff --git a/docs/make.jl b/docs/make.jl index b71d9a4f..c70bf05f 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,8 +16,8 @@ makedocs(; "ELBO Maximization" => [ "Overview" => "elbo/overview.md", "Reparameterization Gradient Estimator" => "elbo/repgradelbo.md", - "Location-Scale Variational Family" => "locscale.md", ], + "Variational Families" => "families.md", "Optimization" => "optimization.md", ], ) diff --git a/docs/src/families.md b/docs/src/families.md new file mode 100644 index 00000000..e270acad --- /dev/null +++ b/docs/src/families.md @@ -0,0 +1,267 @@ +# [Reparameterizable Variational Families](@id families) + +The [RepGradELBO](@ref repgradelbo) objective assumes that the members of the variational family have a differentiable sampling path. +We provide multiple pre-packaged variational families that can be readily used. + +## [The `LocationScale` Family](@id locscale) + +The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as + +```math +z \sim q_{\lambda} \qquad\Leftrightarrow\qquad +z \stackrel{d}{=} C u + m;\quad u \sim \varphi +``` + +where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*. +``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. +The location-scale family encompases many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``. + +The probability density is given by + +```math + q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m)), +``` + +the covariance is given as + +```math + \mathrm{Var}\left(q_{\lambda}\right) = C \mathrm{Var}(q_{\lambda}) C^{\top} +``` + +and the entropy is given as + +```math + \mathbb{H}(q_{\lambda}) = \mathbb{H}(\varphi) + \log |C|, +``` + +where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution. +Notice the ``\mathbb{H}(\varphi)`` does not depend on ``\log |C|``. +The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution. + +### API + +!!! note + + For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned. + Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities. + +```@docs +MvLocationScale +``` + +The following are specialized constructors for convenience: + +```@docs +FullRankGaussian +MeanFieldGaussian +``` + +### Gaussian Variational Families + +```julia +using AdvancedVI, LinearAlgebra, Distributions; +μ = zeros(2); + +L = LowerTriangular(diagm(ones(2))); +q = FullRankGaussian(μ, L) + +L = Diagonal(ones(2)); +q = MeanFieldGaussian(μ, L) +``` + +### Student-$$t$$ Variational Families + +```julia +using AdvancedVI, LinearAlgebra, Distributions; +μ = zeros(2); +ν = 3; + +# Full-Rank +L = LowerTriangular(diagm(ones(2))); +q = MvLocationScale(μ, L, TDist(ν)) + +# Mean-Field +L = Diagonal(ones(2)); +q = MvLocationScale(μ, L, TDist(ν)) +``` + +### Laplace Variational families + +```julia +using AdvancedVI, LinearAlgebra, Distributions; +μ = zeros(2); + +# Full-Rank +L = LowerTriangular(diagm(ones(2))); +q = MvLocationScale(μ, L, Laplace()) + +# Mean-Field +L = Diagonal(ones(2)); +q = MvLocationScale(μ, L, Laplace()) +``` + +## The `LocationScaleLowRank` Family + +In practice, `LocationScale` families with full-rank scale matrices are known to converge slowly as they require a small SGD stepsize. +Low-rank variational families can be an effective alternative[^ONS2018]. +`LocationScaleLowRank` generally represent any ``d``-dimensional distribution which its sampling path can be represented as + +```math +z \sim q_{\lambda} \qquad\Leftrightarrow\qquad +z \stackrel{d}{=} D u_1 + U u_2 + m;\quad u_1, u_2 \sim \varphi +``` + +where ``D \in \mathbb{R}^{d \times d}`` is a diagonal matrix, ``U \in \mathbb{R}^{d \times r}`` is a dense low-rank matrix for the rank ``r > 0``, ``m \in \mathbb{R}^d`` is the location, and ``\varphi`` is the *base distribution*. +``m``, ``D``, and ``U`` form the variational parameters ``\lambda = (m, D, U)``. + +The covariance of this distribution is given as + +```math + \mathrm{Var}\left(q_{\lambda}\right) = D \mathrm{Var}(\varphi) D + U \mathrm{Var}(\varphi) U^{\top} +``` + +and the entropy is given by the matrix determinant lemma as + +```math + \mathbb{H}(q_{\lambda}) + = \mathbb{H}(\varphi) + \log |\Sigma| + = \mathbb{H}(\varphi) + 2 \log |D| + \log |I + U^{\top} D^{-2} U|, +``` + +where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution. + +```@setup lowrank +using ADTypes +using AdvancedVI +using Distributions +using LinearAlgebra +using LogDensityProblems +using Optimisers +using Plots +using ReverseDiff + +struct Target{D} + dist::D +end + +function LogDensityProblems.logdensity(model::Target, θ) + logpdf(model.dist, θ) +end + +function LogDensityProblems.dimension(model::Target) + return length(model.dist) +end + +function LogDensityProblems.capabilities(::Type{<:Target}) + return LogDensityProblems.LogDensityOrder{0}() +end + +n_dims = 30 +U_true = randn(n_dims, 3) +D_true = Diagonal(log.(1 .+ exp.(randn(n_dims)))) +Σ_true = D_true + U_true*U_true' +Σsqrt_true = sqrt(Σ_true) +μ_true = randn(n_dims) +model = Target(MvNormal(μ_true, Σ_true)); + +d = LogDensityProblems.dimension(model); +μ = zeros(d); + +L = Diagonal(ones(d)); +q0_mf = MeanFieldGaussian(μ, L) + +L = LowerTriangular(diagm(ones(d))); +q0_fr = FullRankGaussian(μ, L) + +D = ones(n_dims) +U = zeros(n_dims, 3) +q0_lr = LowRankGaussian(μ, D, U) + +obj = RepGradELBO(1); + +max_iter = 10^4 + +function callback(; params, averaged_params, restructure, stat, kwargs...) + q = restructure(averaged_params) + μ, Σ = mean(q), cov(q) + (dist2 = sum(abs2, μ - μ_true) + tr(Σ + Σ_true - 2*sqrt(Σsqrt_true*Σ*Σsqrt_true)),) +end + +_, _, stats_fr, _ = AdvancedVI.optimize( + model, + obj, + q0_fr, + max_iter; + show_progress = false, + adtype = AutoReverseDiff(), + optimizer = Adam(0.01), + averager = PolynomialAveraging(), + callback = callback, +); + +_, _, stats_mf, _ = AdvancedVI.optimize( + model, + obj, + q0_mf, + max_iter; + show_progress = false, + adtype = AutoReverseDiff(), + optimizer = Adam(0.01), + averager = PolynomialAveraging(), + callback = callback, +); + +_, _, stats_lr, _ = AdvancedVI.optimize( + model, + obj, + q0_lr, + max_iter; + show_progress = false, + adtype = AutoReverseDiff(), + optimizer = Adam(0.01), + averager = PolynomialAveraging(), + callback = callback, +); + +t = [stat.iteration for stat in stats_fr] +dist_fr = [sqrt(stat.dist2) for stat in stats_fr] +dist_mf = [sqrt(stat.dist2) for stat in stats_mf] +dist_lr = [sqrt(stat.dist2) for stat in stats_lr] +plot( t, dist_mf , label="Mean-Field Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance") +plot!(t, dist_fr, label="Full-Rank Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance") +plot!(t, dist_lr, label="Low-Rank Gaussian", xlabel="Iteration", ylabel="Wasserstein-2 Distance") +savefig("lowrank_family_wasserstein.svg") +nothing +``` + +Consider a 30-dimensional Gaussian with a diagonal plus low-rank covariance structure, where the true rank is 3. +Then, we can compare the convergence speed of `LowRankGaussian` versus `FullRankGaussian`: + +![](lowrank_family_wasserstein.svg) + +As we can see, `LowRankGaussian` converges faster than `FullRankGaussian`. +While `FullRankGaussian` can converge to the true solution since it is a more expressive variational family, `LowRankGaussian` gets there faster. + +!!! info + + `MvLocationScaleLowRank` tend to work better with the `Optimisers.Adam` optimizer due to non-smoothness. + Other optimisers may experience divergences. + +### API + +```@docs +MvLocationScaleLowRank +``` + +The `logpdf` of `MvLocationScaleLowRank` has an optional argument `non_differentiable::Bool` (default: `false`). +If set as `true`, a more efficient ``O\left(r d^2\right)`` implementation is used to evaluate the density. +This, however, is not differentiable under most AD frameworks due to the use of Cholesky `lowrankupdate`. +The default value is `false`, which uses a ``O\left(d^3\right)`` implementation, is differentiable and therefore compatible with the `StickingTheLandingEntropy` estimator. + +The following is a specialized constructor for convenience: + +```@docs +LowRankGaussian +``` + +[^ONS2018]: Ong, V. M. H., Nott, D. J., & Smith, M. S. (2018). Gaussian variational approximation with a factor covariance structure. Journal of Computational and Graphical Statistics, 27(3), 465-478. diff --git a/docs/src/locscale.md b/docs/src/locscale.md deleted file mode 100644 index 643c3a98..00000000 --- a/docs/src/locscale.md +++ /dev/null @@ -1,80 +0,0 @@ - -# [Location-Scale Variational Family](@id locscale) - -## Introduction -The [location-scale](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) variational family is a family of probability distributions, where their sampling process can be represented as -```math -z \sim q_{\lambda} \qquad\Leftrightarrow\qquad -z \stackrel{d}{=} C u + m;\quad u \sim \varphi -``` -where ``C`` is the *scale*, ``m`` is the location, and ``\varphi`` is the *base distribution*. -``m`` and ``C`` form the variational parameters ``\lambda = (m, C)`` of ``q_{\lambda}``. -The location-scale family encompases many practical variational families, which can be instantiated by setting the *base distribution* of ``u`` and the structure of ``C``. - -The probability density is given by -```math - q_{\lambda}(z) = {|C|}^{-1} \varphi(C^{-1}(z - m)) -``` -and the entropy is given as -```math - \mathbb{H}(q_{\lambda}) = \mathbb{H}(\varphi) + \log |C|, -``` -where ``\mathbb{H}(\varphi)`` is the entropy of the base distribution. -Notice the ``\mathbb{H}(\varphi)`` does not depend on ``\log |C|``. -The derivative of the entropy with respect to ``\lambda`` is thus independent of the base distribution. - -## Constructors - -!!! note - For stable convergence, the initial `scale` needs to be sufficiently large and well-conditioned. - Initializing `scale` to have small eigenvalues will often result in initial divergences and numerical instabilities. - -```@docs -MvLocationScale -``` - -```@docs -FullRankGaussian -MeanFieldGaussian -``` - -## Gaussian Variational Families -```julia -using AdvancedVI, LinearAlgebra, Distributions; -μ = zeros(2); - -L = diagm(ones(2)) |> LowerTriangular; -q = FullRankGaussian(μ, L) - -L = ones(2) |> Diagonal; -q = MeanFieldGaussian(μ, L) -``` - -## Sudent-$$t$$ Variational Families -```julia -using AdvancedVI, LinearAlgebra, Distributions; -μ = zeros(2); -ν = 3; - -# Full-Rank -L = diagm(ones(2)) |> LowerTriangular; -q = MvLocationScale(μ, L, TDist(ν)) - -# Mean-Field -L = ones(2) |> Diagonal; -q = MvLocationScale(μ, L, TDist(ν)) -``` - -## Laplace Variational families -```julia -using AdvancedVI, LinearAlgebra, Distributions; -μ = zeros(2); - -# Full-Rank -L = diagm(ones(2)) |> LowerTriangular; -q = MvLocationScale(μ, L, Laplace()) - -# Mean-Field -L = ones(2) |> Diagonal; -q = MvLocationScale(μ, L, Laplace()) -``` diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 8ac1b645..5402e075 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -180,6 +180,10 @@ export MvLocationScale, MeanFieldGaussian, FullRankGaussian include("families/location_scale.jl") +export MvLocationScaleLowRank, LowRankGaussian + +include("families/location_scale_low_rank.jl") + # Optimization Rules include("optimization/rules.jl") diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index 1aab2e71..22af4b4a 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -1,6 +1,14 @@ +struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <: + ContinuousMultivariateDistribution + location::L + scale::S + dist::D + scale_eps::E +end + """ - MvLocationScale(location, scale, dist) <: ContinuousMultivariateDistribution + MvLocationScale(location, scale, dist; scale_eps) The location scale variational family broadly represents various variational families using `location` and `scale` variational parameters. @@ -12,21 +20,20 @@ represented as follows: u = rand(dist, d) z = scale*u + location ``` -""" -struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <: - ContinuousMultivariateDistribution - location::L - scale::S - dist::D - scale_eps::E -end +`scale_eps` sets a constraint on the smallest value of `scale` to be enforced during optimization. +This is necessary to guarantee stable convergence. + +# Keyword Arguments +- `scale_eps`: Lower bound constraint for the diagonal of the scale. (default: `1e-4`). +""" function MvLocationScale( location::AbstractVector{T}, scale::AbstractMatrix{T}, - dist::ContinuousDistribution; - scale_eps::T=sqrt(eps(T)), + dist::ContinuousUnivariateDistribution; + scale_eps::T=T(1e-4), ) where {T<:Real} + @assert minimum(diag(scale)) ≥ scale_eps "Initial scale is too small (smallest diagonal value is $(minimum(diag(scale)))). This might result in unstable optimization behavior." return MvLocationScale(location, scale, dist, scale_eps) end @@ -37,8 +44,8 @@ Functors.@functor MvLocationScale (location, scale) # `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD # is very inefficient. # begin -struct RestructureMeanField{S<:Diagonal,D,L} - model::MvLocationScale{S,D,L} +struct RestructureMeanField{S<:Diagonal,D,L,E} + model::MvLocationScale{S,D,L,E} end function (re::RestructureMeanField)(flat::AbstractVector) @@ -48,7 +55,7 @@ function (re::RestructureMeanField)(flat::AbstractVector) return MvLocationScale(location, scale, re.model.dist, re.model.scale_eps) end -function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L} +function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E} @unpack location, scale, dist = q flat = vcat(location, diag(scale)) return flat, RestructureMeanField(q) @@ -59,7 +66,7 @@ Base.length(q::MvLocationScale) = length(q.location) Base.size(q::MvLocationScale) = size(q.location) -Base.eltype(::Type{<:MvLocationScale{S,D,L}}) where {S,D,L} = eltype(D) +Base.eltype(::Type{<:MvLocationScale{S,D,L,E}}) where {S,D,L,E} = eltype(D) function StatsBase.entropy(q::MvLocationScale) @unpack location, scale, dist = q @@ -106,54 +113,57 @@ function Distributions._rand!( return x .+= location end -Distributions.mean(q::MvLocationScale) = q.location +function Distributions.mean(q::MvLocationScale) + @unpack location, scale = q + return location + scale * Fill(mean(q.dist), length(location)) +end function Distributions.var(q::MvLocationScale) C = q.scale - return Diagonal(C * C') + σ2 = var(q.dist) + return σ2 * diag(C * C') end function Distributions.cov(q::MvLocationScale) C = q.scale - return Hermitian(C * C') + σ2 = var(q.dist) + return σ2 * Hermitian(C * C') end """ - FullRankGaussian(location, scale; check_args = true) + FullRankGaussian(μ, L; scale_eps) Construct a Gaussian variational approximation with a dense covariance matrix. # Arguments -- `location::AbstractVector{T}`: Mean of the Gaussian. -- `scale::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian. +- `μ::AbstractVector{T}`: Mean of the Gaussian. +- `L::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian. # Keyword Arguments -- `check_args`: Check the conditioning of the initial scale (default: `true`). +- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`). """ function FullRankGaussian( - μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=sqrt(eps(T)) + μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=T(1e-4) ) where {T<:Real} - @assert minimum(diag(L)) ≥ sqrt(scale_eps) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." q_base = Normal{T}(zero(T), one(T)) return MvLocationScale(μ, L, q_base, scale_eps) end """ - MeanFieldGaussian(location, scale; check_args = true) + MeanFieldGaussian(μ, L; scale_eps) Construct a Gaussian variational approximation with a diagonal covariance matrix. # Arguments -- `location::AbstractVector{T}`: Mean of the Gaussian. -- `scale::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian. +- `μ::AbstractVector{T}`: Mean of the Gaussian. +- `L::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian. # Keyword Arguments -- `check_args`: Check the conditioning of the initial scale (default: `true`). +- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`). """ function MeanFieldGaussian( - μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=sqrt(eps(T)) + μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=T(1e-4) ) where {T<:Real} - @assert minimum(diag(L)) ≥ sqrt(eps(eltype(L))) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior." q_base = Normal{T}(zero(T), one(T)) return MvLocationScale(μ, L, q_base, scale_eps) end diff --git a/src/families/location_scale_low_rank.jl b/src/families/location_scale_low_rank.jl new file mode 100644 index 00000000..e2044142 --- /dev/null +++ b/src/families/location_scale_low_rank.jl @@ -0,0 +1,176 @@ + +struct MvLocationScaleLowRank{ + L,SD<:AbstractVector,SF<:AbstractMatrix,D<:ContinuousDistribution,E<:Real +} <: ContinuousMultivariateDistribution + location::L + scale_diag::SD + scale_factors::SF + dist::D + scale_eps::E +end + +""" + MvLocationLowRankScale(location, scale_diag, scale_factors, dist; scale_eps) + +Variational family with a covariance in the form of a diagonal matrix plus a squared low-rank matrix. +The rank is given by `size(scale_factors, 2)`. + +It generally represents any distribution for which the sampling path can be +represented as follows: +```julia + d = length(location) + r = size(scale_factors, 2) + u_diag = rand(dist, d) + u_factors = rand(dist, r) + z = scale_diag.*u_diag + scale_factors*u_factors + location +``` + +`scale_eps` sets a constraint on the smallest value of `scale_diag` to be enforced during optimization. +This is necessary to guarantee stable convergence. + +# Keyword Arguments +- `scale_eps`: Lower bound constraint for the values of scale_diag. (default: `sqrt(eps(T))`). +""" +function MvLocationScaleLowRank( + location::AbstractVector{T}, + scale_diag::AbstractVector{T}, + scale_factors::AbstractMatrix{T}, + dist::ContinuousUnivariateDistribution; + scale_eps::T=T(1e-4), +) where {T<:Real} + @assert minimum(scale_diag) ≥ scale_eps "Initial scale is too small (smallest diagonal scale value is $(minimum(scale_diag)). This might result in unstable optimization behavior." + @assert size(scale_factors, 1) == length(scale_diag) + return MvLocationScaleLowRank(location, scale_diag, scale_factors, dist, scale_eps) +end + +Functors.@functor MvLocationScaleLowRank (location, scale_diag, scale_factors) + +Base.length(q::MvLocationScaleLowRank) = length(q.location) + +Base.size(q::MvLocationScaleLowRank) = size(q.location) + +Base.eltype(::Type{<:MvLocationScaleLowRank{L,SD,SF,D,E}}) where {L,SD,SF,D,E} = eltype(L) + +function StatsBase.entropy(q::MvLocationScaleLowRank) + @unpack location, scale_diag, scale_factors, dist = q + n_dims = length(location) + scale_diag2 = scale_diag .* scale_diag + UtDinvU = Hermitian(scale_factors' * (scale_factors ./ scale_diag2)) + logdetΣ = 2 * sum(log.(scale_diag)) + logdet(I + UtDinvU) + return n_dims * convert(eltype(location), entropy(dist)) + logdetΣ / 2 +end + +function Distributions.logpdf( + q::MvLocationScaleLowRank, z::AbstractVector{<:Real}; non_differntiable::Bool=false +) + @unpack location, scale_diag, scale_factors, dist = q + μ_base = mean(dist) + n_dims = length(location) + + scale2chol = if non_differntiable + # Fast O(kd^2) path (not supported by most current AD frameworks): + scale2chol = Cholesky(LowerTriangular(diagm(sqrt.(scale_diag)))) + n_factors = size(scale_factors, 2) + for k in 1:n_factors + factor = scale_factors[:, k] # copy necessary due to in-place mutation + lowrankupdate!(scale2chol, factor) + end + scale2chol + else + # Slow but differentiable O(d^3) path + scale2 = Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors' + cholesky(scale2) + end + z_std = z - mean(q) + scale2chol.L * Fill(μ_base, n_dims) + return sum(Base.Fix1(logpdf, dist), scale2chol.L \ z_std) - logdet(scale2chol.L) +end + +function Distributions.rand(q::MvLocationScaleLowRank) + @unpack location, scale_diag, scale_factors, dist = q + n_dims = length(location) + n_factors = size(scale_factors, 2) + u_diag = rand(dist, n_dims) + u_fact = rand(dist, n_factors) + return scale_diag .* u_diag + scale_factors * u_fact + location +end + +function Distributions.rand( + rng::AbstractRNG, q::MvLocationScaleLowRank{S,D,L}, num_samples::Int +) where {S,D,L} + @unpack location, scale_diag, scale_factors, dist = q + n_dims = length(location) + n_factors = size(scale_factors, 2) + u_diag = rand(rng, dist, n_dims, num_samples) + u_fact = rand(rng, dist, n_factors, num_samples) + return scale_diag .* u_diag + scale_factors * u_fact .+ location +end + +function Distributions._rand!( + rng::AbstractRNG, q::MvLocationScaleLowRank, x::AbstractVecOrMat{<:Real} +) + @unpack location, scale_diag, scale_factors, dist = q + + rand!(rng, dist, x) + x[:] = scale_diag .* x + + u_fact = rand(rng, dist, size(scale_factors, 2), size(x, 2)) + x[:, :] += scale_factors * u_fact + + return x .+= location +end + +function Distributions.mean(q::MvLocationScaleLowRank) + @unpack location, scale_diag, scale_factors = q + μ = mean(q.dist) + return location + + scale_diag .* Fill(μ, length(scale_diag)) + + scale_factors * Fill(μ, size(scale_factors, 2)) +end + +function Distributions.var(q::MvLocationScaleLowRank) + @unpack scale_diag, scale_factors = q + σ2 = var(q.dist) + return σ2 * + (scale_diag .* scale_diag + sum(scale_factors .* scale_factors; dims=2)[:, 1]) +end + +function Distributions.cov(q::MvLocationScaleLowRank) + @unpack scale_diag, scale_factors = q + σ2 = var(q.dist) + return σ2 * (Diagonal(scale_diag .* scale_diag) + scale_factors * scale_factors') +end + +function update_variational_params!( + ::Type{<:MvLocationScaleLowRank}, opt_st, params, restructure, grad +) + opt_st, params = Optimisers.update!(opt_st, params, grad) + q = restructure(params) + ϵ = q.scale_eps + + # Clip diagonal to guarantee positive definite covariance + @. q.scale_diag = max(q.scale_diag, ϵ) + + params, _ = Optimisers.destructure(q) + + return opt_st, params +end + +""" + LowRankGaussian(μ, D, U; scale_eps) + +Construct a Gaussian variational approximation with a diagonal plus low-rank covariance matrix. + +# Arguments +- `μ::AbstractVector{T}`: Mean of the Gaussian. +- `D::Vector{T}`: Diagonal of the scale. +- `U::Matrix{T}`: Low-rank factors of the scale, where `size(U,2)` is the rank. + +# Keyword Arguments +- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`). +""" +function LowRankGaussian( + μ::AbstractVector{T}, D::Vector{T}, U::Matrix{T}; scale_eps::T=T(1e-4) +) where {T<:Real} + q_base = Normal{T}(zero(T), one(T)) + return MvLocationScaleLowRank(μ, D, U, q_base; scale_eps) +end diff --git a/test/Project.toml b/test/Project.toml index 251869e7..018198d1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -27,7 +27,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" ADTypes = "0.2.1, 1" Bijectors = "0.13" DiffResults = "1.0" -Distributions = "0.25.100" +Distributions = "0.25.111" DistributionsAD = "0.6.45" Enzyme = "0.12.32" FillArrays = "1.6.1" diff --git a/test/interface/location_scale.jl b/test/families/location_scale.jl similarity index 60% rename from test/interface/location_scale.jl rename to test/families/location_scale.jl index dcc3369d..bd45458d 100644 --- a/test/interface/location_scale.jl +++ b/test/families/location_scale.jl @@ -1,27 +1,34 @@ @testset "interface LocationScale" begin - @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in [:gaussian], + @testset "$(string(covtype)) $(basedist) $(realtype)" for basedist in + [:gaussian, :gaussian_nonstd], covtype in [:meanfield, :fullrank], realtype in [Float32, Float64] n_dims = 10 n_montecarlo = 1000_000 - μ = randn(realtype, n_dims) - L = if covtype == :fullrank + location = randn(realtype, n_dims) + scale = if covtype == :fullrank LowerTriangular(tril(I + ones(realtype, n_dims, n_dims) / 2)) else Diagonal(ones(realtype, n_dims)) end - Σ = L * L' q = if covtype == :fullrank && basedist == :gaussian - FullRankGaussian(μ, L) + FullRankGaussian(location, scale) elseif covtype == :meanfield && basedist == :gaussian - MeanFieldGaussian(μ, L) + MeanFieldGaussian(location, scale) + elseif covtype == :fullrank && basedist == :gaussian_nonstd + MvLocationScale(location, scale, Normal(realtype(3), realtype(3))) + elseif covtype == :meanfield && basedist == :gaussian_nonstd + MvLocationScale(location, scale, Normal(realtype(3), realtype(3))) end + q_true = if basedist == :gaussian - MvNormal(μ, Σ) + MvNormal(location, scale * scale') + elseif basedist == :gaussian_nonstd + MvNormal(location + scale * fill(3, n_dims), 9 * scale * scale') end @testset "eltype" begin @@ -46,15 +53,15 @@ @testset "statistics" begin @testset "mean" begin @test eltype(mean(q)) == realtype - @test mean(q) == μ + @test mean(q) ≈ mean(q_true) end @testset "var" begin @test eltype(var(q)) == realtype - @test var(q) ≈ Diagonal(Σ) + @test var(q) ≈ var(q_true) end @testset "cov" begin @test eltype(cov(q)) == realtype - @test cov(q) ≈ Σ + @test cov(q) ≈ cov(q_true) end end @@ -62,11 +69,13 @@ @testset "rand" begin z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) - @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( + 1e-2 + ) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( 1e-2 ) - @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) z_sample_ref = rand(StableRNG(1), q) @test z_sample_ref == rand(StableRNG(1), q) @@ -75,11 +84,13 @@ @testset "rand batch" begin z_samples = rand(q, n_montecarlo) @test eltype(z_samples) == realtype - @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) - @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( 1e-2 ) - @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) samples_ref = rand(StableRNG(1), q, n_montecarlo) @test samples_ref == rand(StableRNG(1), q, n_montecarlo) @@ -94,11 +105,13 @@ z_samples = mapreduce(first, hcat, res) z_samples_ret = mapreduce(last, hcat, res) @test z_samples == z_samples_ret - @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) - @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( 1e-2 ) - @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) z_sample_ref = Array{realtype}(undef, n_dims) rand!(StableRNG(1), q, z_sample_ref) @@ -112,11 +125,13 @@ z_samples = Array{realtype}(undef, n_dims, n_montecarlo) z_samples_ret = rand!(q, z_samples) @test z_samples == z_samples_ret - @test dropdims(mean(z_samples; dims=2); dims=2) ≈ μ rtol = realtype(1e-2) - @test dropdims(var(z_samples; dims=2); dims=2) ≈ diag(Σ) rtol = realtype( + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( + 1e-2 + ) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( 1e-2 ) - @test cov(z_samples; dims=2) ≈ Σ rtol = realtype(1e-2) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo) rand!(StableRNG(1), q, z_samples_ref) @@ -128,6 +143,38 @@ end end + @testset "scale positive definite projection" begin + @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in + [:meanfield, :fullrank], + realtype in [Float32, Float64], + bijector in [nothing, :identity] + + d = 5 + μ = zeros(realtype, d) + ϵ = sqrt(realtype(0.5)) + q = if covtype == :fullrank + L = LowerTriangular(Matrix{realtype}(I, d, d)) + FullRankGaussian(μ, L; scale_eps=ϵ) + elseif covtype == :meanfield + L = Diagonal(ones(realtype, d)) + MeanFieldGaussian(μ, L; scale_eps=ϵ) + end + q_trans = if isnothing(bijector) + q + else + Bijectors.TransformedDistribution(q, identity) + end + g = deepcopy(q) + + λ, re = Optimisers.destructure(q) + grad, _ = Optimisers.destructure(g) + opt_st = Optimisers.setup(Descent(one(realtype)), λ) + _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) + q′ = re(λ′) + @test all(var(q′) .≥ ϵ^2) + end + end + @testset "Diagonal destructure" begin n_dims = 10 μ = zeros(n_dims) @@ -139,35 +186,3 @@ @test q == re(λ) end end - -@testset "scale positive definite projection" begin - @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in - [:meanfield, :fullrank], - realtype in [Float32, Float64], - bijector in [nothing, :identity] - - d = 5 - μ = zeros(realtype, d) - ϵ = sqrt(realtype(0.5)) - q = if covtype == :fullrank - L = LowerTriangular(Matrix{realtype}(I, d, d)) - FullRankGaussian(μ, L; scale_eps=ϵ) - elseif covtype == :meanfield - L = Diagonal(ones(realtype, d)) - MeanFieldGaussian(μ, L; scale_eps=ϵ) - end - q_trans = if isnothing(bijector) - q - else - Bijectors.TransformedDistribution(q, identity) - end - g = deepcopy(q) - - λ, re = Optimisers.destructure(q) - grad, _ = Optimisers.destructure(g) - opt_st = Optimisers.setup(Descent(one(realtype)), λ) - _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) - q′ = re(λ′) - @test all(diag(var(q′)) .≥ ϵ^2) - end -end diff --git a/test/families/location_scale_low_rank.jl b/test/families/location_scale_low_rank.jl new file mode 100644 index 00000000..2accb971 --- /dev/null +++ b/test/families/location_scale_low_rank.jl @@ -0,0 +1,178 @@ + +@testset "interface LocationScaleLowRank" begin + @testset "$(basedist) rank=$(rank) $(realtype)" for basedist in + [:gaussian, :gaussian_nonstd], + n_rank in [1, 2], + realtype in [Float32, Float64] + + n_dims = 10 + n_montecarlo = 1000_000 + + location = randn(realtype, n_dims) + scale_diag = ones(realtype, n_dims) + scale_factors = randn(realtype, n_dims, n_rank) + + q = if basedist == :gaussian + LowRankGaussian(location, scale_diag, scale_factors) + elseif basedist == :gaussian_nonstd + MvLocationScaleLowRank( + location, scale_diag, scale_factors, Normal(realtype(3), realtype(3)) + ) + end + + q_true = if basedist == :gaussian + μ = location + Σ = Diagonal(scale_diag .^ 2) + scale_factors * scale_factors' + MvNormal(location, Σ) + elseif basedist == :gaussian_nonstd + μ = location + scale_diag .* fill(3, n_dims) + scale_factors * fill(3, n_rank) + Σ = 3^2 * (Diagonal(scale_diag .^ 2) + scale_factors * scale_factors') + MvNormal(μ, Σ) + end + + @testset "eltype" begin + @test eltype(q) == realtype + end + + @testset "logpdf" begin + z = rand(q) + @test logpdf(q, z) ≈ logpdf(q_true, z) rtol = realtype(1e-2) + @test eltype(logpdf(q, z)) == realtype + + @test logpdf(q, z; non_differntiable=true) ≈ logpdf(q_true, z) rtol = realtype( + 1e-2 + ) + @test eltype(logpdf(q, z; non_differntiable=true)) == realtype + end + + @testset "entropy" begin + @test eltype(entropy(q)) == realtype + @test entropy(q) ≈ entropy(q_true) + end + + @testset "length" begin + @test length(q) == n_dims + end + + @testset "statistics" begin + @testset "mean" begin + @test eltype(mean(q)) == realtype + @test mean(q) ≈ mean(q_true) + end + @testset "var" begin + @test eltype(var(q)) == realtype + @test var(q) ≈ var(q_true) + end + @testset "cov" begin + @test eltype(cov(q)) == realtype + @test cov(q) ≈ cov(q_true) + end + end + + @testset "sampling" begin + @testset "rand" begin + z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( + 1e-2 + ) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) + + z_sample_ref = rand(StableRNG(1), q) + @test z_sample_ref ≈ rand(StableRNG(1), q) + end + + @testset "rand batch" begin + z_samples = rand(q, n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( + 1e-2 + ) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) + + samples_ref = rand(StableRNG(1), q, n_montecarlo) + @test samples_ref ≈ rand(StableRNG(1), q, n_montecarlo) + end + + @testset "rand! AbstractVector" begin + res = map(1:n_montecarlo) do _ + z_sample = Array{realtype}(undef, n_dims) + z_sample_ret = rand!(q, z_sample) + (z_sample, z_sample_ret) + end + z_samples = mapreduce(first, hcat, res) + z_samples_ret = mapreduce(last, hcat, res) + @test z_samples ≈ z_samples_ret + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( + 1e-2 + ) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) + + z_sample_ref = Array{realtype}(undef, n_dims) + rand!(StableRNG(1), q, z_sample_ref) + + z_sample = Array{realtype}(undef, n_dims) + rand!(StableRNG(1), q, z_sample) + @test z_sample_ref ≈ z_sample + end + + @testset "rand! AbstractMatrix" begin + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + z_samples_ret = rand!(q, z_samples) + @test z_samples ≈ z_samples_ret + @test dropdims(mean(z_samples; dims=2); dims=2) ≈ mean(q_true) rtol = realtype( + 1e-2 + ) + @test dropdims(var(z_samples; dims=2); dims=2) ≈ var(q_true) rtol = realtype( + 1e-2 + ) + @test cov(z_samples; dims=2) ≈ cov(q_true) rtol = realtype(1e-2) + + z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(StableRNG(1), q, z_samples_ref) + + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(StableRNG(1), q, z_samples) + @test z_samples_ref ≈ z_samples + end + end + end + + @testset "diagonal positive definite projection" begin + @testset "$(realtype) $(bijector)" for realtype in [Float32, Float64], + bijector in [nothing, :identity] + + n_rank = 2 + d = 5 + μ = zeros(realtype, d) + ϵ = sqrt(realtype(0.5)) + D = ones(realtype, d) + U = randn(realtype, d, n_rank) + q = MvLocationScaleLowRank( + μ, D, U, Normal{realtype}(zero(realtype), one(realtype)); scale_eps=ϵ + ) + q_trans = if isnothing(bijector) + q + else + Bijectors.TransformedDistribution(q, bijector) + end + g = deepcopy(q) + + λ, re = Optimisers.destructure(q) + grad, _ = Optimisers.destructure(g) + opt_st = Optimisers.setup(Descent(one(realtype)), λ) + _, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad) + q′ = re(λ′) + @test all(var(q′) .≥ ϵ^2) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 5d0d2c8d..43958e8e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,7 +52,11 @@ if GROUP == "All" || GROUP == "Interface" include("interface/repgradelbo.jl") include("interface/rules.jl") include("interface/averaging.jl") - include("interface/location_scale.jl") +end + +if GROUP == "All" || GROUP == "Families" + include("families/location_scale.jl") + include("families/location_scale_low_rank.jl") end const PROGRESS = haskey(ENV, "PROGRESS")