Skip to content

Commit

Permalink
formatting all files (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai committed Aug 9, 2024
1 parent 6b2eb57 commit 40fd15b
Show file tree
Hide file tree
Showing 17 changed files with 403 additions and 411 deletions.
2 changes: 2 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
style="blue"
format_markdown = true
29 changes: 6 additions & 23 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ Instead, the return values should be used.
"""
function update_variational_params! end

update_variational_params!(::Type, opt_st, params, restructure, grad) =
Optimisers.update!(opt_st, params, grad)
function update_variational_params!(::Type, opt_st, params, restructure, grad)
return Optimisers.update!(opt_st, params, grad)
end

# estimators
"""
Expand All @@ -106,13 +107,7 @@ This function needs to be implemented only if `obj` is stateful.
- `params`: Initial variational parameters.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
"""
init(
::Random.AbstractRNG,
::AbstractVariationalObjective,
::Any,
::Any,
::Any,
) = nothing
init(::Random.AbstractRNG, ::AbstractVariationalObjective, ::Any, ::Any, ::Any) = nothing

"""
estimate_objective([rng,] obj, q, prob; kwargs...)
Expand All @@ -136,7 +131,6 @@ function estimate_objective end

export estimate_objective


"""
estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state)
Expand Down Expand Up @@ -177,25 +171,16 @@ Estimate the entropy of `q`.
"""
function estimate_entropy end

export
RepGradELBO,
ClosedFormEntropy,
StickingTheLandingEntropy,
MonteCarloEntropy
export RepGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy

include("objectives/elbo/entropy.jl")
include("objectives/elbo/repgradelbo.jl")


# Variational Families
export
MvLocationScale,
MeanFieldGaussian,
FullRankGaussian
export MvLocationScale, MeanFieldGaussian, FullRankGaussian

include("families/location_scale.jl")


# Optimization Routine

function optimize end
Expand All @@ -205,7 +190,6 @@ export optimize
include("utils.jl")
include("optimize.jl")


# optional dependencies
if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
using Requires
Expand All @@ -232,4 +216,3 @@ end
end

end

93 changes: 44 additions & 49 deletions src/families/location_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,21 @@ represented as follows:
z = scale*u + location
```
"""
struct MvLocationScale{
S, D <: ContinuousDistribution, L, E <: Real
} <: ContinuousMultivariateDistribution
location ::L
scale ::S
dist ::D
struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <:
ContinuousMultivariateDistribution
location::L
scale::S
dist::D
scale_eps::E
end

function MvLocationScale(
location ::AbstractVector{T},
scale ::AbstractMatrix{T},
dist ::ContinuousDistribution;
scale_eps::T = sqrt(eps(T))
) where {T <: Real}
MvLocationScale(location, scale, dist, scale_eps)
location::AbstractVector{T},
scale::AbstractMatrix{T},
dist::ContinuousDistribution;
scale_eps::T=sqrt(eps(T)),
) where {T<:Real}
return MvLocationScale(location, scale, dist, scale_eps)
end

Functors.@functor MvLocationScale (location, scale)
Expand All @@ -38,85 +37,85 @@ Functors.@functor MvLocationScale (location, scale)
# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD
# is very inefficient.
# begin
struct RestructureMeanField{S <: Diagonal, D, L}
q::MvLocationScale{S, D, L}
struct RestructureMeanField{S<:Diagonal,D,L}
q::MvLocationScale{S,D,L}
end

function (re::RestructureMeanField)(flat::AbstractVector)
n_dims = div(length(flat), 2)
n_dims = div(length(flat), 2)
location = first(flat, n_dims)
scale = Diagonal(last(flat, n_dims))
MvLocationScale(location, scale, re.q.dist, re.q.scale_eps)
scale = Diagonal(last(flat, n_dims))
return MvLocationScale(location, scale, re.q.dist, re.q.scale_eps)
end

function Optimisers.destructure(
q::MvLocationScale{<:Diagonal, D, L}
) where {D, L}
function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L}
@unpack location, scale, dist = q
flat = vcat(location, diag(scale))
flat, RestructureMeanField(q)
return flat, RestructureMeanField(q)
end
# end

Base.length(q::MvLocationScale) = length(q.location)

Base.size(q::MvLocationScale) = size(q.location)

Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D)
Base.eltype(::Type{<:MvLocationScale{S,D,L}}) where {S,D,L} = eltype(D)

function StatsBase.entropy(q::MvLocationScale)
@unpack location, scale, dist = q
@unpack location, scale, dist = q
n_dims = length(location)
# `convert` is necessary because `entropy` is not type stable upstream
n_dims*convert(eltype(location), entropy(dist)) + logdet(scale)
return n_dims * convert(eltype(location), entropy(dist)) + logdet(scale)
end

function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale)
return sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale)
end

function Distributions.rand(q::MvLocationScale)
@unpack location, scale, dist = q
n_dims = length(location)
scale*rand(dist, n_dims) + location
return scale * rand(dist, n_dims) + location
end

function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int
) where {S, D, L}
rng::AbstractRNG, q::MvLocationScale{S,D,L}, num_samples::Int
) where {S,D,L}
@unpack location, scale, dist = q
n_dims = length(location)
scale*rand(rng, dist, n_dims, num_samples) .+ location
return scale * rand(rng, dist, n_dims, num_samples) .+ location
end

# This specialization improves AD performance of the sampling path
function Distributions.rand(
rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int
) where {L, D}
rng::AbstractRNG, q::MvLocationScale{<:Diagonal,D,L}, num_samples::Int
) where {L,D}
@unpack location, scale, dist = q
n_dims = length(location)
n_dims = length(location)
scale_diag = diag(scale)
scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location
return scale_diag .* rand(rng, dist, n_dims, num_samples) .+ location
end

function Distributions._rand!(rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real})
function Distributions._rand!(
rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real}
)
@unpack location, scale, dist = q
rand!(rng, dist, x)
x[:] = scale*x
x[:] = scale * x
return x .+= location
end

Distributions.mean(q::MvLocationScale) = q.location

function Distributions.var(q::MvLocationScale)
function Distributions.var(q::MvLocationScale)
C = q.scale
Diagonal(C*C')
return Diagonal(C * C')
end

function Distributions.cov(q::MvLocationScale)
C = q.scale
Hermitian(C*C')
return Hermitian(C * C')
end

"""
Expand All @@ -132,13 +131,11 @@ Construct a Gaussian variational approximation with a dense covariance matrix.
- `check_args`: Check the conditioning of the initial scale (default: `true`).
"""
function FullRankGaussian(
μ::AbstractVector{T},
L::LinearAlgebra.AbstractTriangular{T};
scale_eps::T = sqrt(eps(T))
) where {T <: Real}
μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=sqrt(eps(T))
) where {T<:Real}
@assert minimum(diag(L)) sqrt(scale_eps) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior."
q_base = Normal{T}(zero(T), one(T))
MvLocationScale(μ, L, q_base, scale_eps)
return MvLocationScale(μ, L, q_base, scale_eps)
end

"""
Expand All @@ -154,13 +151,11 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix
- `check_args`: Check the conditioning of the initial scale (default: `true`).
"""
function MeanFieldGaussian(
μ::AbstractVector{T},
L::Diagonal{T};
scale_eps::T = sqrt(eps(T)),
) where {T <: Real}
μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=sqrt(eps(T))
) where {T<:Real}
@assert minimum(diag(L)) sqrt(eps(eltype(L))) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior."
q_base = Normal{T}(zero(T), one(T))
MvLocationScale(μ, L, q_base, scale_eps)
return MvLocationScale(μ, L, q_base, scale_eps)
end

function update_variational_params!(
Expand All @@ -176,5 +171,5 @@ function update_variational_params!(

params, _ = Optimisers.destructure(q)

opt_st, params
return opt_st, params
end
6 changes: 2 additions & 4 deletions src/objectives/elbo/entropy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct ClosedFormEntropy <: AbstractEntropyEstimator end
maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q

function estimate_entropy(::ClosedFormEntropy, ::Any, q)
entropy(q)
return entropy(q)
end

"""
Expand All @@ -31,9 +31,7 @@ struct MonteCarloEntropy <: AbstractEntropyEstimator end
maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop

function estimate_entropy(
::Union{MonteCarloEntropy, StickingTheLandingEntropy},
mc_samples::AbstractMatrix,
q
::Union{MonteCarloEntropy,StickingTheLandingEntropy}, mc_samples::AbstractMatrix, q
)
mean(eachcol(mc_samples)) do mc_sample
-logpdf(q, mc_sample)
Expand Down
Loading

0 comments on commit 40fd15b

Please sign in to comment.