Merge pull request #192 from biaslab/conjugate
Conjugate variational inference
ThijsvdLaar authored Feb 1, 2022
2 parents 30933f2 + 7e78b57 commit dc808ac
Showing 15 changed files with 1,400 additions and 66 deletions.
805 changes: 805 additions & 0 deletions demo/conjugate_variational_inference.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/ForneyLab.jl
@@ -126,6 +126,7 @@ include("update_rules/point_mass_constraint.jl")
include("update_rules/nonlinear_unscented.jl")
include("update_rules/nonlinear_sampling.jl")
include("update_rules/nonlinear_extended.jl")
include("update_rules/nonlinear_conjugate.jl")
include("update_rules/sample_list.jl")

*(x::ProbabilityDistribution, y::ProbabilityDistribution) = prod!(x, y) # * operator for probability distributions
1 change: 1 addition & 0 deletions src/engines/julia/julia.jl
@@ -25,6 +25,7 @@ include("update_rules/softmax.jl")
include("update_rules/nonlinear_unscented.jl")
include("update_rules/nonlinear_sampling.jl")
include("update_rules/nonlinear_extended.jl")
include("update_rules/nonlinear_conjugate.jl")
include("update_rules/dot_product.jl")
include("update_rules/poisson.jl")
include("update_rules/moment_constraint.jl")
301 changes: 301 additions & 0 deletions src/engines/julia/update_rules/nonlinear_conjugate.jl
@@ -0,0 +1,301 @@
# Implementation based on Khan et al. (2017), "Conjugate-computation variational inference:
# Converting variational inference in non-conjugate models to inferences in conjugate models",
# and Akbayrak et al. (2021), "Extended Variational Message Passing for Automated Approximate
# Bayesian Inference"
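#
# In CVI, the approximate marginal q is updated by stochastic natural-gradient
# descent on the free energy, expressed in natural-parameter space. With η the
# natural parameters of the (conjugate) forward message, λ those of q, and μ_bw
# the (generally non-conjugate) backward message, each iteration of renderCVI
# below performs
#
#     λ ← λ - ρ*(λ - η - ∇̃ log μ_bw) ,
#
# where ∇̃ denotes a sampled natural gradient and ρ the optimizer step size.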

export
ruleSPNonlinearCOutNM,
ruleSPNonlinearCIn1MN,
ruleSPNonlinearCOutNMX,
ruleSPNonlinearCInMX,
ruleSPNonlinearCInGX,
ruleMNonlinearCInMGX

const default_optimizer = ForgetDelayDescent(200.0, 0.6) # Default optimizer
const default_n_iterations = 1000 # Default number of iterations for gradient descent


#-----------------------
# Conjugate Update Rules
#-----------------------

ruleSPNonlinearCOutNM(g::Function,
                      msg_out::Nothing,
                      msg_in1::Message;
                      dims::Any=nothing,
                      n_samples=default_n_samples,
                      n_iterations=default_n_iterations,
                      optimizer=default_optimizer) =
    ruleSPNonlinearSOutNM(g, nothing, msg_in1; dims=dims, n_samples=n_samples) # Reuse sampling update

function ruleSPNonlinearCIn1MN(g::Function,
                               msg_out::Message,
                               msg_in1::Message{F, V};
                               dims::Any=nothing,
                               n_samples=default_n_samples,
                               n_iterations=default_n_iterations,
                               optimizer=default_optimizer) where {F<:FactorNode, V<:VariateType}

    msg_s = ruleSPNonlinearSIn1MN(g, msg_out, nothing; dims=dims, n_samples=n_samples) # Returns Message{Function}
    η = naturalParams(msg_in1.dist)
    λ = renderCVI(msg_s.dist.params[:log_pdf], n_iterations, optimizer, η, msg_in1)

    return Message(standardDistribution(V, F, η=λ-η))
end

ruleSPNonlinearCOutNMX(g::Function,
                       msg_out::Nothing,
                       msgs_in::Vararg{Message};
                       dims::Any=nothing,
                       n_samples=default_n_samples,
                       n_iterations=default_n_iterations,
                       optimizer=default_optimizer) =
    ruleSPNonlinearSOutNMX(g, nothing, msgs_in...; dims=dims, n_samples=n_samples)

function ruleSPNonlinearCInGX(g::Function,
                              inx::Int64, # Index of inbound interface inx
                              msg_out::Message,
                              msgs_in::Vararg{Message{<:Gaussian}}; # Only Gaussian because of marginalization over inbounds
                              dims::Any=nothing,
                              n_samples=default_n_samples,
                              n_iterations=default_n_iterations,
                              optimizer=default_optimizer)

    # Extract joint statistics of inbound messages
    (ms_fw_in, Vs_fw_in) = collectStatistics(msgs_in...) # Return arrays with individual means and covariances
    (m_fw_in, V_fw_in, ds) = concatenateGaussianMV(ms_fw_in, Vs_fw_in) # Concatenate individual statistics into joint statistics
    msg_fw_in = Message(Multivariate, GaussianMeanVariance, m=m_fw_in, v=V_fw_in) # Joint forward message

    # log-pdf of joint backward message over inbounds
    log_pdf_s(z) = logPdf(msg_out.dist, g(split(z, ds)...))

    # Compute joint marginal belief
    η = naturalParams(msg_fw_in.dist)
    λ = renderCVI(log_pdf_s, n_iterations, optimizer, η, msg_fw_in)
    d_marg = standardDistribution(Multivariate, GaussianMeanVariance, η=λ)
    (m_in, V_in) = unsafeMeanCov(d_marg)

    # Marginalize joint belief on in's
    (m_inx, V_inx) = marginalizeGaussianMV(m_in, V_in, ds, inx)
    W_inx = cholinv(V_inx) # Convert to canonical statistics
    xi_inx = W_inx*m_inx

    # Divide marginal on inx by forward message
    (xi_fw_inx, W_fw_inx) = unsafeWeightedMeanPrecision(msgs_in[inx].dist)
    xi_bw_inx = xi_inx - xi_fw_inx
    W_bw_inx = W_inx - W_fw_inx # Note: subtraction might lead to posdef violations

    return Message(variateType(dims), GaussianWeightedMeanPrecision, xi=xi_bw_inx, w=W_bw_inx)
end
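
The division above exploits that exponential-family densities divide by subtracting natural parameters; for Gaussians in canonical (xi = W*m, W) form both statistics subtract elementwise. A minimal self-contained sketch in plain Julia, with hypothetical numbers and independent of ForneyLab:

using LinearAlgebra

W_in  = [2.0 0.3; 0.3 1.5]  # precision of the joint marginal on inx (hypothetical)
xi_in = W_in*[1.0, -0.5]    # weighted mean of the marginal
W_fw  = Matrix(1.0I, 2, 2)  # precision of the forward message
xi_fw = zeros(2)            # weighted mean of the forward message

W_bw  = W_in - W_fw         # backward precision (may lose positive definiteness, as noted)
xi_bw = xi_in - xi_fw       # backward weighted mean
m_bw  = W_bw \ xi_bw        # backward mean, whenever W_bw is invertible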

# Special case for two inputs with one PointMass (no inx required)
function ruleSPNonlinearCInMX(g::Function,
                              msg_out::Message,
                              msg_in1::Message{F, V},
                              msg_in2::Message{PointMass};
                              dims::Any=nothing,
                              n_samples=default_n_samples,
                              n_iterations=default_n_iterations,
                              optimizer=default_optimizer) where {F<:FactorNode, V<:VariateType}

    msg_s = ruleSPNonlinearSInMX(g, msg_out, nothing, msg_in2; dims=dims, n_samples=n_samples)
    η = naturalParams(msg_in1.dist)
    λ = renderCVI(msg_s.dist.params[:log_pdf], n_iterations, optimizer, η, msg_in1)

    return Message(standardDistribution(V, F, η=λ-η))
end

# Special case for two inputs with one PointMass (no inx required)
function ruleSPNonlinearCInMX(g::Function,
                              msg_out::Message,
                              msg_in1::Message{PointMass},
                              msg_in2::Message{F, V};
                              dims::Any=nothing,
                              n_samples=default_n_samples,
                              n_iterations=default_n_iterations,
                              optimizer=default_optimizer) where {F<:FactorNode, V<:VariateType}

    msg_s = ruleSPNonlinearSInMX(g, msg_out, msg_in1, nothing; dims=dims, n_samples=n_samples)
    η = naturalParams(msg_in2.dist)
    λ = renderCVI(msg_s.dist.params[:log_pdf], n_iterations, optimizer, η, msg_in2)

    return Message(standardDistribution(V, F, η=λ-η))
end

# Joint marginal belief over inbounds
function ruleMNonlinearCInMGX(g::Function,
                              msg_out::Message,
                              msgs_in::Vararg{Message{<:Gaussian}}) # Only Gaussian because of marginalization over inbounds

    # Extract joint statistics of inbound messages
    (ms_fw_in, Vs_fw_in) = collectStatistics(msgs_in...) # Return arrays with individual means and covariances
    (m_fw_in, V_fw_in, ds) = concatenateGaussianMV(ms_fw_in, Vs_fw_in) # Concatenate individual statistics into joint statistics
    msg_fw_in = Message(Multivariate, GaussianMeanVariance, m=m_fw_in, v=V_fw_in) # Joint forward message

    # log-pdf of joint backward message over inbounds
    log_pdf_s(z) = logPdf(msg_out.dist, g(split(z, ds)...))

    η = naturalParams(msg_fw_in.dist)
    λ = renderCVI(log_pdf_s, default_n_iterations, default_optimizer, η, msg_fw_in) # Natural statistics of marginal

    return standardDistribution(Multivariate, GaussianMeanVariance, η=λ)
end
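
For intuition, the joint statistics construction above stacks the inbound means and places the inbound covariances on the block diagonal, since the forward messages on distinct edges are independent. A hypothetical scalar stand-in for collectStatistics/concatenateGaussianMV with two univariate Gaussian inbounds N(m1, v1) and N(m2, v2):

m1, v1 = 0.0, 1.0            # first inbound message (hypothetical values)
m2, v2 = 2.0, 0.5            # second inbound message
m_fw_in = [m1, m2]           # stacked joint mean
V_fw_in = [v1 0.0; 0.0 v2]   # block-diagonal joint covariance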


#---------------------------
# Custom inbounds collectors
#---------------------------

# Conjugate approximation
function collectSumProductNodeInbounds(node::Nonlinear{Conjugate}, entry::ScheduleEntry)
    inbounds = Any[]

    # Push function to calling signature
    # This function needs to be defined in the scope of the user
    push!(inbounds, Dict{Symbol, Any}(:g => node.g,
                                      :keyword => false))

    multi_in = isMultiIn(node) # Boolean to indicate a nonlinear node with multiple stochastic inbounds
    inx = findfirst(isequal(entry.interface), node.interfaces) - 1 # Find number of inbound interface; 0 for outbound

    if (inx > 0) && multi_in # Multi-inbound backward rule
        push!(inbounds, Dict{Symbol, Any}(:inx => inx, # Push inbound identifier
                                          :keyword => false))
    end

    interface_to_schedule_entry = current_inference_algorithm.interface_to_schedule_entry
    for node_interface in node.interfaces
        inbound_interface = ultimatePartner(node_interface)
        if (node_interface == entry.interface != node.interfaces[1])
            # Collect the breaker message for a backward rule
            haskey(interface_to_schedule_entry, inbound_interface) || error("The nonlinear node's backward rule uses the incoming message on the input edge to determine the approximation point. Try altering the variable order in the scheduler to first perform a forward pass.")
            push!(inbounds, interface_to_schedule_entry[inbound_interface])
        elseif node_interface == entry.interface
            # Ignore inbound message on outbound interface
            push!(inbounds, nothing)
        elseif isClamped(inbound_interface)
            # Hard-code outbound message of constant node in schedule
            push!(inbounds, assembleClamp!(inbound_interface.node, Message))
        else
            # Collect message from previous result
            push!(inbounds, interface_to_schedule_entry[inbound_interface])
        end
    end

    # Push custom arguments if defined
    if (node.dims !== nothing)
        push!(inbounds, Dict{Symbol, Any}(:dims => node.dims[inx + 1],
                                          :keyword => true))
    end
    if (node.n_samples !== nothing)
        push!(inbounds, Dict{Symbol, Any}(:n_samples => node.n_samples,
                                          :keyword => true))
    end
    if (node.n_iterations !== nothing)
        push!(inbounds, Dict{Symbol, Any}(:n_iterations => node.n_iterations,
                                          :keyword => true))
    end
    if (node.optimizer !== nothing)
        push!(inbounds, Dict{Symbol, Any}(:optimizer => node.optimizer,
                                          :keyword => true))
    end
    return inbounds
end
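
For orientation, a hypothetical model specification that would exercise this collector; it assumes the Nonlinear{Conjugate} constructor accepts the same g, dims, n_samples, n_iterations and optimizer keywords that the collector reads from the node fields above, mirroring the existing Nonlinear{Sampling} interface (a sketch, not taken from the demo notebook):

using ForneyLab

fg = FactorGraph()
f(x) = exp(x)                                    # user-defined nonlinearity
@RV x ~ GaussianMeanVariance(0.0, 1.0)           # conjugate prior
@RV y ~ Nonlinear{Conjugate}(x, g=f, n_iterations=100)
placeholder(y, :y)
algo = messagePassingAlgorithm(x)                # schedules the rules defined above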


#-------------------------
# Optimization subroutines
#-------------------------

function renderCVI(log_μ_bw::Function,
                   n_iterations::Int,
                   optimizer::Any,
                   λ_0::Vector,
                   μ_fw::Message{F, V}) where {F<:FactorNode, V<:VariateType}

    # Natural parameters of forward message
    η = naturalParams(μ_fw.dist)

    # Fisher information matrix
    A = λ -> logNormalizer(V, F, η=λ)
    Fisher = λ -> ForwardDiff.hessian(A, λ)

    # Initialize q marginal
    λ_i = deepcopy(λ_0)
    q_i = standardDistribution(V, F, η=λ_i)

    for i=1:n_iterations
        # Store previous results for possible reset
        q_i_min = deepcopy(q_i)
        λ_i_min = deepcopy(λ_i)

        # Given the current sample, define natural gradient of q
        s_q_i = sample(q_i)
        log_q = λ -> logPdf(V, F, s_q_i, η=λ)
        ∇log_q = λ -> ForwardDiff.gradient(log_q, λ)

        # Compute current free energy gradient and update natural statistics
        ∇log_μ_bw_i = log_μ_bw(s_q_i)*cholinv(Fisher(λ_i))*∇log_q(λ_i) # Natural gradient of backward message
        ∇F_i = λ_i - η - ∇log_μ_bw_i # Natural gradient of free energy
        λ_i -= apply!(optimizer, λ_i, ∇F_i) # Update λ_i

        # Update q_i
        q_i = standardDistribution(V, F, η=λ_i)
        if !isProper(q_i) # Result is improper; reset statistics
            q_i = q_i_min
            λ_i = λ_i_min
        end
    end

    return λ_i
end

# Gaussian result that avoids Fisher information matrix construction
function renderCVI(log_μ_bw::Function,
                   n_iterations::Int,
                   optimizer::Any,
                   λ_0::Vector,
                   μ_fw::Message{F, V}) where {F<:Gaussian, V<:VariateType}

    # Natural parameters of forward message
    η = naturalParams(μ_fw.dist)

    # Gradients/derivatives of Gaussian moments
    if V == Univariate
        ∇m = s -> ForwardDiff.derivative(log_μ_bw, s)
        ∇v = s -> 0.5*ForwardDiff.derivative(∇m, s)
    else
        ∇m = s -> ForwardDiff.gradient(log_μ_bw, s)
        ∇v = s -> 0.5*ForwardDiff.jacobian(∇m, s)
    end

    # Initialize q marginal
    λ_i = deepcopy(λ_0)
    q_i = standardDistribution(V, F, η=λ_i)

    for i=1:n_iterations
        # Store previous results for possible reset
        q_i_min = deepcopy(q_i)
        λ_i_min = deepcopy(λ_i)

        # Given the current sample, define natural gradient of q
        m_q_i = unsafeMean(q_i)
        s_q_i = sample(q_i)
        ∇λ_i_1 = ∇m(s_q_i) - 2*∇v(s_q_i)*m_q_i
        ∇λ_i_2 = ∇v(s_q_i)

        # Compute current free energy gradient and update natural statistics
        ∇log_μ_bw_i = vcat(∇λ_i_1, vec(∇λ_i_2))
        ∇F_i = λ_i - η - ∇log_μ_bw_i # Natural gradient of free energy
        λ_i -= apply!(optimizer, λ_i, ∇F_i) # Update λ_i

        # Update q_i
        q_i = standardDistribution(V, F, η=λ_i)
        if !isProper(q_i) # Result is improper; reset statistics
            q_i = q_i_min
            λ_i = λ_i_min
        end
    end

    return λ_i
end
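
As a usage sketch of renderCVI in isolation, this hypothetical example fits a Gaussian q to a standard-normal forward message combined with a logistic backward message for an observation y = 1; it assumes the internal helpers naturalParams and standardDistribution are imported as used above:

using ForneyLab
import ForneyLab: naturalParams, standardDistribution, renderCVI, ForgetDelayDescent

μ_fw = Message(Univariate, GaussianMeanVariance, m=0.0, v=1.0)  # forward (prior) message
log_μ_bw(x) = -log(1 + exp(-x))                                 # log σ(x), log-likelihood of y = 1
λ_0 = naturalParams(μ_fw.dist)                                  # initialize q at the prior
λ = renderCVI(log_μ_bw, 1000, ForgetDelayDescent(200.0, 0.6), λ_0, μ_fw)
q = standardDistribution(Univariate, GaussianMeanVariance, η=λ) # approximate posterior
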
15 changes: 0 additions & 15 deletions src/engines/julia/update_rules/nonlinear_sampling.jl
@@ -1,7 +1,6 @@
export
ruleSPNonlinearSOutNM,
ruleSPNonlinearSIn1MN,
-ruleSPNonlinearSOutNGX,
ruleSPNonlinearSInGX,
ruleSPNonlinearSOutNMX,
ruleSPNonlinearSInMX,
@@ -48,20 +47,6 @@ function ruleSPNonlinearSIn1MN(g::Function,
    return Message(variateType(dims), Function, log_pdf = (z)->logPdf(msg_out.dist, g(z)))
end

-function ruleSPNonlinearSOutNGX(g::Function,
-                                msg_out::Nothing,
-                                msgs_in::Vararg{Message{<:Gaussian}};
-                                dims::Any=nothing,
-                                n_samples=default_n_samples)
-
-    samples_in = [sample(msg_in.dist, n_samples) for msg_in in msgs_in]
-
-    samples = g.(samples_in...)
-    weights = ones(n_samples)/n_samples
-
-    return Message(variateType(dims), SampleList, s=samples, w=weights)
-end

function ruleSPNonlinearSInGX(g::Function,
                              inx::Int64, # Index of inbound interface inx
                              msg_out::Message,
2 changes: 1 addition & 1 deletion src/factor_nodes/categorical.jl
@@ -96,7 +96,7 @@ standardDistribution(V::Type{Univariate}, F::Type{Categorical}; η::Vector) = Pr

logNormalizer(::Type{Univariate}, ::Type{Categorical}; η::Vector) = log(sum(exp.(η)))

-logPdf(V::Type{Univariate}, F::Type{Categorical}, x::Vector; η::Vector) = x'*η - logNormalizer(V, F; η=η)
+logPdf(V::Type{Univariate}, F::Type{Categorical}, x::AbstractVector; η::Vector) = x'*η - logNormalizer(V, F; η=η)

function prod!( x::ProbabilityDistribution{Univariate, Categorical},
y::ProbabilityDistribution{Univariate, Categorical},
2 changes: 1 addition & 1 deletion src/factor_nodes/gaussian.jl
@@ -138,7 +138,7 @@ function logNormalizer(::Type{Multivariate}, ::Type{<:Gaussian}; η::Vector)
    d = Int(-0.5 + 0.5*sqrt(1 + 4*length(η))) # Extract dimensionality
    η_1 = η[1:d]
    η_2 = reshape(η[d+1:end], d, d)
-    return η_1'*cholinv(-4*η_2)*η_1 - 0.5*logdet(-2*η_2)
+    return η_1'*pinv(-4*η_2)*η_1 - 0.5*logdet(-2*η_2)
end

logPdf(V::Type{Univariate}, ::Type{F}, x::Number; η::Vector) where F<:Gaussian = -0.5*log(2pi) + vcat(x, x^2)'*η - logNormalizer(V, F; η=η)
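
For reference, a quick numeric check in plain Julia (hypothetical values) that the natural-parameter form above agrees with the familiar moment form 0.5*m'*W*m - 0.5*logdet(W), under the convention η = (W*m, vec(-W/2)) implied by the logPdf definition:

using LinearAlgebra

m = [0.5, -1.0]          # mean (hypothetical)
W = [2.0 0.3; 0.3 1.5]   # precision matrix
η_1 = W*m                # first natural parameter
η_2 = -0.5*W             # second natural parameter

A_nat = η_1'*pinv(-4*η_2)*η_1 - 0.5*logdet(-2*η_2)  # natural form, as above
A_mom = 0.5*m'*W*m - 0.5*logdet(W)                  # moment form
@assert isapprox(A_nat, A_mom)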
