Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Nonlinear dimensionality specification #187

Merged
merged 5 commits into from
Dec 15, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
606 changes: 10 additions & 596 deletions demo/bootstrap_particle_filter.ipynb

Large diffs are not rendered by default.

31 changes: 15 additions & 16 deletions demo/nonlinear_kalman_filter.ipynb

Large diffs are not rendered by default.

80 changes: 9 additions & 71 deletions demo/nonlinear_online_estimation.ipynb

Large diffs are not rendered by default.

41 changes: 16 additions & 25 deletions demo/variational_laplace_and_sampling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,7 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: Precompiling ForneyLab [9fc3f58a-c2cc-5bff-9419-6a294fefdca9]\n",
"└ @ Base loading.jl:1273\n"
]
}
],
"outputs": [],
"source": [
"using ForneyLab, LinearAlgebra\n",
"\n",
Expand Down Expand Up @@ -206,7 +197,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"The marginal for l is a ProbabilityDistribution{Univariate,SampleList} with mean 0.636 and variance 0.014\n"
"The marginal for l is a ProbabilityDistribution{Univariate, SampleList} with mean 0.611 and variance 0.014\n"
]
}
],
Expand Down Expand Up @@ -320,7 +311,7 @@
{
"data": {
"text/plain": [
"𝒩(xi=1.40, w=2.33)\n"
"𝒩(xi=1.40, w=2.32)\n"
]
},
"execution_count": 14,
Expand All @@ -341,7 +332,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Free energy per iteration: 2.037, 1.952, 1.957, 1.942, 1.943"
"Free energy per iteration: 2.041, 1.946, 1.938, 1.943, 1.939"
]
}
],
Expand Down Expand Up @@ -437,7 +428,7 @@
{
"data": {
"text/plain": [
"Dir(a=[2.35, 5.15, 3.20])\n"
"Dir(a=[2.34, 5.16, 3.20])\n"
]
},
"execution_count": 19,
Expand All @@ -458,10 +449,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"The marginal for x is a ProbabilityDistribution{Univariate,SampleList} with mean vector entries\n",
" [1] = 0.227738\n",
" [2] = 0.772048\n",
" [3] = 0.000214287\n"
"The marginal for x is a ProbabilityDistribution{Univariate, SampleList} with mean vector entries\n",
" [1] = 0.225319\n",
" [2] = 0.774444\n",
" [3] = 0.000237294\n"
]
}
],
Expand All @@ -480,7 +471,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Free energy per iteration: 2.463, 2.459, 2.468, 2.416, 2.449"
"Free energy per iteration: 2.504, 2.43, 2.44, 2.426, 2.492"
]
}
],
Expand Down Expand Up @@ -553,7 +544,7 @@
{
"data": {
"text/plain": [
"𝒩(xi=0.91, w=0.99)\n"
"𝒩(xi=0.77, w=0.85)\n"
]
},
"execution_count": 25,
Expand All @@ -573,7 +564,7 @@
{
"data": {
"text/plain": [
"𝒩(xi=6.75, w=3.63)\n"
"𝒩(xi=6.10, w=3.29)\n"
]
},
"execution_count": 26,
Expand All @@ -594,7 +585,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"The marginal for m is a ProbabilityDistribution{Univariate,SampleList} with mean 4.229 and variance 0.979\n"
"The marginal for m is a ProbabilityDistribution{Univariate, SampleList} with mean 4.175 and variance 0.902\n"
]
}
],
Expand All @@ -617,15 +608,15 @@
"lastKernelId": null
},
"kernelspec": {
"display_name": "Julia 1.3.0",
"display_name": "Julia 1.6.4",
"language": "julia",
"name": "julia-1.3"
"name": "julia-1.6"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.3.0"
"version": "1.6.4"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion src/engines/julia/generators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ function vagueSourceCode(entry::ScheduleEntry)
family_code = removePrefix(entry.family)
dims = entry.dimensionality
if dims == ()
vague_code = "vague($family_code)"
vague_code = "vague($family_code)" # Default
else
vague_code = "vague($family_code, $dims)"
end
Expand Down
12 changes: 6 additions & 6 deletions src/engines/julia/update_rules/gaussian_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function ruleVBGaussianMeanPrecisionW( dist_out::ProbabilityDistribution{Multiv
(m_mean, v_mean) = unsafeMeanCov(dist_mean)
(m_out, v_out) = unsafeMeanCov(dist_out)

Message(MatrixVariate, Wishart, v=cholinv( v_mean + v_out + (m_mean - m_out)*(m_mean - m_out)' ), nu=dims(dist_out) + 2.0)
Message(MatrixVariate, Wishart, v=cholinv( v_mean + v_out + (m_mean - m_out)*(m_mean - m_out)' ), nu=dims(dist_out)[1] + 2.0)
end

ruleVBGaussianMeanPrecisionOut( dist_out::Any,
Expand All @@ -63,21 +63,21 @@ ruleVBGaussianMeanPrecisionOut( dist_out::Any,
Message(V, GaussianMeanPrecision, m=unsafeMean(dist_mean), w=unsafeMean(dist_prec))

# Structured variational update for the out interface: the outbound Gaussian
# takes the mean of the inbound mean-message, with covariance inflated by the
# inverse of the expected precision.
ruleSVBGaussianMeanPrecisionOutVGD(dist_out::Any,
                                   msg_mean::Message{<:Gaussian, V},
                                   dist_prec::ProbabilityDistribution) where V<:VariateType =
    Message(V, GaussianMeanVariance, m=unsafeMean(msg_mean.dist), v=unsafeCov(msg_mean.dist) + cholinv(unsafeMean(dist_prec)))

# Structured variational update for the precision interface. The joint belief
# over (out, mean) is converted to mean-covariance form; the statistics of the
# difference between the two halves then parameterize a Gamma message (scalar
# case) or a Wishart message (matrix case).
function ruleSVBGaussianMeanPrecisionW(
    dist_out_mean::ProbabilityDistribution{Multivariate, F},
    dist_prec::Any) where F<:Gaussian

    joint_d = dims(dist_out_mean)[1] # Total dimensionality of the joint (out, mean) belief
    d_out_mean = convert(ProbabilityDistribution{Multivariate, GaussianMeanVariance}, dist_out_mean)
    (m, V) = unsafeMeanCov(d_out_mean)
    if joint_d == 2 # Univariate out and mean: Gamma message
        return Message(Univariate, Gamma, a=1.5, b=0.5*(V[1,1] - V[1,2] - V[2,1] + V[2,2] + (m[1] - m[2])^2))
    else # Multivariate: split the joint into two halves of dimension d, Wishart message
        d = Int64(joint_d/2)
        return Message(MatrixVariate, Wishart, v=cholinv( V[1:d,1:d] - V[1:d,d+1:end] - V[d+1:end, 1:d] + V[d+1:end,d+1:end] + (m[1:d] - m[d+1:end])*(m[1:d] - m[d+1:end])' ), nu=d + 2.0)
    end
end
Expand Down
4 changes: 2 additions & 2 deletions src/engines/julia/update_rules/gaussian_mixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function ruleVBGaussianMixtureW(dist_out::ProbabilityDistribution,
(m_mean_k, v_mean_k) = unsafeMeanCov(dist_means[k])
(m_out, v_out) = unsafeMeanCov(dist_out)
z_bar = unsafeMeanVector(dist_switch)
d = dims(dist_means[1])
d = dims(dist_means[1])[1]

return Message(MatrixVariate, Wishart,
nu = 1.0 + z_bar[k] + d,
Expand Down Expand Up @@ -123,7 +123,7 @@ function ruleVBGaussianMixtureOut( dist_out::Any,
dist_means = collect(dist_factors[1:2:end])
dist_precs = collect(dist_factors[2:2:end])
z_bar = unsafeMeanVector(dist_switch)
d = dims(dist_means[1])
d = dims(dist_means[1])[1]

w = Diagonal(zeros(d))
xi = zeros(d)
Expand Down
1 change: 0 additions & 1 deletion src/engines/julia/update_rules/multiplication.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ function ruleSPMultiplicationIn1GNP(msg_out::Message{F, Multivariate},

dist_a_matr = convert(ProbabilityDistribution{MatrixVariate, PointMass}, msg_a.dist)
msg_in1_mult = ruleSPMultiplicationIn1GNP(msg_out, nothing, Message(dist_a_matr))
(dims(msg_in1_mult.dist) == 1) || error("Implicit conversion to Univariate failed for $(msg_in1_mult.dist)")

return Message(Univariate, GaussianWeightedMeanPrecision, xi=msg_in1_mult.dist.params[:xi][1], w=msg_in1_mult.dist.params[:w][1,1])
end
81 changes: 47 additions & 34 deletions src/engines/julia/update_rules/nonlinear_extended.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,46 @@ ruleSPNonlinearEInGX,
ruleMNonlinearEInGX

"""
Concatenate a vector (of vectors and floats) and return with original dimensions (for splitting)
"""
function concatenate(xs::Vector)
    ds = [size(x_k) for x_k in xs] # Record the original size of each entry; () for scalars
    x = vcat(xs...) # Stack all entries into a single flat vector

    return (x, ds)
end

"""
Return local linearization of g around expansion point x_hat
for Nonlinear node with single input interface
"""
function localLinearizationSingleIn(g::Function, x_hat::Float64)
    a = ForwardDiff.derivative(g, x_hat) # Slope of g at the expansion point
    b = g(x_hat) - a*x_hat # Offset such that g(x) ≈ a*x + b near x_hat

    return (a, b)
end

# Vector-input variant: linearize via the Jacobian of g.
function localLinearizationSingleIn(g::Function, x_hat::Vector{Float64})
    A = ForwardDiff.jacobian(g, x_hat) # Jacobian of g at the expansion point
    b = g(x_hat) - A*x_hat # Offset such that g(x) ≈ A*x + b near x_hat

    return (A, b)
end

"""
Return local linearization of g around expansion point x_hat
for Nonlinear node with multiple input interfaces
"""
function localLinearizationMultiIn(g::Function, x_hat::Vector{Float64})
    g_unpacked(x::Vector) = g(x...) # Wrapper accepting the stacked inputs as one vector
    A = ForwardDiff.gradient(g_unpacked, x_hat)' # Transposed gradient (row vector)
    b = g(x_hat...) - A*x_hat # Offset such that g(x) ≈ A*x + b near x_hat

    return (A, b)
end

function localLinearization(V::Type{Multivariate}, g::Function, x_hat::Vector{Vector{Float64}})
function localLinearizationMultiIn(g::Function, x_hat::Vector{Vector{Float64}})
(x_cat, ds) = concatenate(x_hat)
g_unpacked(x::Vector) = g(split(x, ds)...)
A = ForwardDiff.jacobian(g_unpacked, x_cat)
Expand All @@ -57,74 +62,82 @@ end
# Forward rule
function ruleSPNonlinearEOutNG(g::Function,
                               msg_out::Nothing,
                               msg_in1::Message{<:Gaussian})

    (m_in1, V_in1) = unsafeMeanCov(msg_in1.dist)
    (A, b) = localLinearizationSingleIn(g, m_in1) # Linearize g around the inbound mean
    m = A*m_in1 + b # Propagate mean through the linearized map
    V = A*V_in1*A' # Propagate covariance through the linearized map

    return Message(variateType(m), GaussianMeanVariance, m=m, v=V)
end

# Multi-argument forward rule
function ruleSPNonlinearEOutNGX(g::Function, # Needs to be in front of Vararg
                                msg_out::Nothing,
                                msgs_in::Vararg{Message{<:Gaussian}})

    (ms_fw_in, Vs_fw_in) = collectStatistics(msgs_in...) # Returns arrays with individual means and covariances
    (A, b) = localLinearizationMultiIn(g, ms_fw_in) # Linearize g around the inbound means
    (m_fw_in, V_fw_in, _) = concatenateGaussianMV(ms_fw_in, Vs_fw_in) # Stack inbounds into one Gaussian
    m = A*m_fw_in + b
    V = A*V_fw_in*A'

    return Message(variateType(m), GaussianMeanVariance, m=m, v=V)
end

# Backward rule with given inverse
function ruleSPNonlinearEIn1GG(g::Function,
                               g_inv::Function,
                               msg_out::Message{<:Gaussian},
                               msg_in1::Nothing)

    (m_out, V_out) = unsafeMeanCov(msg_out.dist)
    (A, b) = localLinearizationSingleIn(g_inv, m_out) # Linearize the given inverse around the outbound mean
    m = A*m_out + b
    V = A*V_out*A'

    return Message(variateType(m), GaussianMeanVariance, m=m, v=V)
end

# Multi-argument backward rule with given inverse
function ruleSPNonlinearEInGX(g::Function, # Needs to be in front of Vararg
                              g_inv::Function,
                              msg_out::Message{<:Gaussian},
                              msgs_in::Vararg{Union{Message{<:Gaussian}, Nothing}})

    (ms, Vs) = collectStatistics(msg_out, msgs_in...) # Returns arrays with individual means and covariances
    (A, b) = localLinearizationMultiIn(g_inv, ms) # Linearize the given inverse around the collected means
    (mc, Vc) = concatenateGaussianMV(ms, Vs) # Stack statistics into one Gaussian
    m = A*mc + b # NOTE: include the offset b (missing in the pre-refactor return)
    V = A*Vc*A'

    return Message(variateType(m), GaussianMeanVariance, m=m, v=V)
end

# Backward rule with unknown inverse
function ruleSPNonlinearEIn1GG(g::Function,
                               msg_out::Message{<:Gaussian},
                               msg_in1::Message{<:Gaussian})

    m_in1 = unsafeMean(msg_in1.dist)
    (m_out, W_out) = unsafeMeanPrecision(msg_out.dist)
    (A, b) = localLinearizationSingleIn(g, m_in1) # Linearize g around the inbound mean
    xi = A'*W_out*(m_out - b) # Weighted-mean parameter of the backward message
    W = A'*W_out*A # Precision parameter of the backward message

    return Message(variateType(xi), GaussianWeightedMeanPrecision, xi=xi, w=W)
end

# Multi-argument backward rule with unknown inverse
function ruleSPNonlinearEInGX(g::Function,
inx::Int64, # Index of inbound interface inx
msg_out::Message{<:Gaussian},
msgs_in::Vararg{Message{<:Gaussian, V}}) where V<:VariateType
msgs_in::Vararg{Message{<:Gaussian}})

# Approximate joint inbounds
(ms_fw_in, Vs_fw_in) = collectStatistics(msgs_in...) # Returns arrays with individual means and covariances
(A, b) = localLinearization(V, g, ms_fw_in)
(A, b) = localLinearizationMultiIn(g, ms_fw_in)

(m_fw_in, V_fw_in, ds) = concatenateGaussianMV(ms_fw_in, Vs_fw_in)
m_fw_out = A*m_fw_in + b
Expand All @@ -136,27 +149,27 @@ function ruleSPNonlinearEInGX(g::Function,
(m_in, V_in) = smoothRTS(m_fw_out, V_fw_out, C_fw, m_fw_in, V_fw_in, m_bw_out, V_bw_out)

# Marginalize joint belief on in's
(m_inx, V_inx) = marginalizeGaussianMV(V, m_in, V_in, ds, inx) # Marginalization is overloaded on VariateType V
(m_inx, V_inx) = marginalizeGaussianMV(m_in, V_in, ds, inx)
W_inx = cholinv(V_inx) # Convert to canonical statistics
xi_inx = W_inx*m_inx

# Divide marginal on inx by forward message
(xi_fw_inx, W_fw_inx) = unsafeWeightedMeanPrecision(msgs_in[inx].dist)
xi_bw_inx = xi_inx - xi_fw_inx
W_bw_inx = W_inx - W_fw_inx # Note: subtraction might lead to posdef inconsistencies
W_bw_inx = W_inx - W_fw_inx # Note: subtraction might lead to posdef violations

return Message(V, GaussianWeightedMeanPrecision, xi=xi_bw_inx, w=W_bw_inx)
return Message(variateType(xi_bw_inx), GaussianWeightedMeanPrecision, xi=xi_bw_inx, w=W_bw_inx)
end

function ruleMNonlinearEInGX(g::Function,
msg_out::Message{<:Gaussian},
msgs_in::Vararg{Message{<:Gaussian, V}}) where V<:VariateType
msgs_in::Vararg{Message{<:Gaussian}})

# Approximate joint inbounds
(ms_fw_in, Vs_fw_in) = collectStatistics(msgs_in...) # Returns arrays with individual means and covariances
(A, b) = localLinearization(V, g, ms_fw_in)
(A, b) = localLinearizationMultiIn(g, ms_fw_in)

(m_fw_in, V_fw_in, ds) = concatenateGaussianMV(ms_fw_in, Vs_fw_in)
(m_fw_in, V_fw_in, _) = concatenateGaussianMV(ms_fw_in, Vs_fw_in)
m_fw_out = A*m_fw_in + b
V_fw_out = A*V_fw_in*A'
C_fw = V_fw_in*A'
Expand Down
Loading