Skip to content

Commit

Permalink
Merge pull request #8 from biaslab/nonlinear_inverse
Browse files Browse the repository at this point in the history
Nonlinear inverse
  • Loading branch information
ThijsvdLaar authored Oct 31, 2018
2 parents 8831278 + 76aed96 commit e21ade3
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 11 deletions.
31 changes: 25 additions & 6 deletions src/engines/julia/update_rules/nonlinear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ end
function ruleSPNonlinearOutVG( msg_out::Nothing,
msg_in1::Message{F, Multivariate},
g::Function,
J_g::Function) where F<:Gaussian
J_g::Function,
g_inv::Union{Function, Nothing}=nothing) where F<:Gaussian

d_in1 = convert(ProbabilityDistribution{Multivariate, GaussianMeanVariance}, msg_in1.dist)

Expand All @@ -30,7 +31,8 @@ end
function ruleSPNonlinearOutVG( msg_out::Nothing,
msg_in1::Message{F, Univariate},
g::Function,
J_g::Function) where F<:Gaussian
J_g::Function,
g_inv::Union{Function, Nothing}=nothing) where F<:Gaussian

d_in1 = convert(ProbabilityDistribution{Univariate, GaussianMeanVariance}, msg_in1.dist)

Expand All @@ -43,11 +45,18 @@ end
function ruleSPNonlinearIn1GV( msg_out::Message{F, Multivariate},
msg_in1::Message, # Any type of message, used for determining the approximation point
g::Function,
J_g::Function) where F<:Gaussian
J_g::Function,
g_inv::Union{Function, Nothing}=nothing) where F<:Gaussian

d_out = convert(ProbabilityDistribution{Multivariate, GaussianMeanPrecision}, msg_out.dist)

(A, b) = approximate(unsafeMean(msg_in1.dist), g, J_g)
if (g_inv != nothing)
x_hat = g_inv(unsafeMean(msg_out.dist))
else
x_hat = unsafeMean(msg_in1.dist)
end

(A, b) = approximate(x_hat, g, J_g)
A_inv = pinv(A)
W_q = A'*d_out.params[:w]*A
W_q = W_q + tiny*diageye(size(W_q)[1]) # Ensure W_q is invertible
Expand All @@ -58,11 +67,18 @@ end
function ruleSPNonlinearIn1GV( msg_out::Message{F, Univariate},
msg_in1::Message, # Any type of message, used for determining the approximation point
g::Function,
J_g::Function) where F<:Gaussian
J_g::Function,
g_inv::Union{Function, Nothing}=nothing) where F<:Gaussian

d_out = convert(ProbabilityDistribution{Univariate, GaussianMeanPrecision}, msg_out.dist)

(a, b) = approximate(unsafeMean(msg_in1.dist), g, J_g)
if (g_inv != nothing)
x_hat = g_inv(unsafeMean(msg_out.dist))
else
x_hat = unsafeMean(msg_in1.dist)
end

(a, b) = approximate(x_hat, g, J_g)
w_q = clamp(d_out.params[:w]*a^2, tiny, huge)

Message(Univariate, GaussianMeanPrecision, m=(d_out.params[:m] - b)/a, w=w_q)
Expand Down Expand Up @@ -100,6 +116,9 @@ function collectSumProductNodeInbounds(node::Nonlinear, entry::ScheduleEntry, in
# These functions need to be defined in the scope of the user
push!(inbound_messages, "$(node.g)")
push!(inbound_messages, "$(node.J_g)")
if (node.g_inv != nothing)
push!(inbound_messages, "$(node.g_inv)")
end

return inbound_messages
end
2 changes: 2 additions & 0 deletions src/factor_nodes/gaussian_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ vague(::Type{GaussianMeanPrecision}, dims::Int64) = ProbabilityDistribution(Mult

unsafeMean(dist::ProbabilityDistribution{V, GaussianMeanPrecision}) where V<:VariateType = deepcopy(dist.params[:m]) # unsafe mean

unsafeMode(dist::ProbabilityDistribution{V, GaussianMeanPrecision}) where V<:VariateType = deepcopy(dist.params[:m]) # unsafe mode

unsafeVar(dist::ProbabilityDistribution{Univariate, GaussianMeanPrecision}) = 1.0/dist.params[:w] # unsafe variance
unsafeVar(dist::ProbabilityDistribution{Multivariate, GaussianMeanPrecision}) = diag(cholinv(dist.params[:w]))

Expand Down
2 changes: 2 additions & 0 deletions src/factor_nodes/gaussian_mean_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ vague(::Type{GaussianMeanVariance}, dims::Int64) = ProbabilityDistribution(Multi

unsafeMean(dist::ProbabilityDistribution{V, GaussianMeanVariance}) where V<:VariateType = deepcopy(dist.params[:m]) # unsafe mean

unsafeMode(dist::ProbabilityDistribution{V, GaussianMeanVariance}) where V<:VariateType = deepcopy(dist.params[:m]) # unsafe mode

unsafeVar(dist::ProbabilityDistribution{Univariate, GaussianMeanVariance}) = dist.params[:v] # unsafe variance
unsafeVar(dist::ProbabilityDistribution{Multivariate, GaussianMeanVariance}) = diag(dist.params[:v])

Expand Down
2 changes: 2 additions & 0 deletions src/factor_nodes/gaussian_weighted_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ vague(::Type{GaussianWeightedMeanPrecision}, dims::Int64) = ProbabilityDistribut

unsafeMean(dist::ProbabilityDistribution{V, GaussianWeightedMeanPrecision}) where V<:VariateType = cholinv(dist.params[:w])*dist.params[:xi]

unsafeMode(dist::ProbabilityDistribution{V, GaussianWeightedMeanPrecision}) where V<:VariateType = cholinv(dist.params[:w])*dist.params[:xi]

unsafeVar(dist::ProbabilityDistribution{Univariate, GaussianWeightedMeanPrecision}) = 1.0/dist.params[:w] # unsafe variance
unsafeVar(dist::ProbabilityDistribution{Multivariate, GaussianWeightedMeanPrecision}) = diag(cholinv(dist.params[:w]))

Expand Down
5 changes: 3 additions & 2 deletions src/factor_nodes/nonlinear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ mutable struct Nonlinear <: DeltaFactor

g::Function # Vector function that expresses the output vector as a function of the input vector; reduces to scalar for 1-d
J_g::Function # Jacobi matrix of g, as a function of the input vector; in the 1-d case this reduces to the first derivative of g
g_inv::Union{Function, Nothing} # When g is invertible, g_inv might provide a more accurate estimate of the working point for the backward message
dims::Tuple # Dimension of breaker message on input interface

function Nonlinear(out, in1, g::Function, J_g::Function; dims=(1,), id=ForneyLab.generateId(Nonlinear))
function Nonlinear(out, in1, g::Function, J_g::Function, g_inv::Union{Function, Nothing}=nothing; dims=(1,), id=ForneyLab.generateId(Nonlinear))
@ensureVariables(out, in1)
self = new(id, Vector{Interface}(undef, 2), Dict{Symbol,Interface}(), g, J_g, dims)
self = new(id, Vector{Interface}(undef, 2), Dict{Symbol,Interface}(), g, J_g, g_inv, dims)
ForneyLab.addNode!(currentGraph(), self)
self.i[:out] = self.interfaces[1] = associate!(Interface(self), out)
self.i[:in1] = self.interfaces[2] = associate!(Interface(self), in1)
Expand Down
2 changes: 2 additions & 0 deletions src/probability_distribution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ MatrixVariate,
PointMass,
@RV,
mean,
mode,
var,
cov,
differentialEntropy,
Expand Down Expand Up @@ -35,6 +36,7 @@ matches(Ta::Type{Pa}, Tb::Type{Pb}) where {Pa<:ProbabilityDistribution, Pb<:Prob
matches(::Type{Nothing}, ::Type{T}) where T<:ProbabilityDistribution = false

mean(dist::ProbabilityDistribution) = isProper(dist) ? unsafeMean(dist) : error("mean($(dist)) is undefined because the distribution is improper.")
mode(dist::ProbabilityDistribution) = isProper(dist) ? unsafeMode(dist) : error("mode($(dist)) is undefined because the distribution is improper.")
var(dist::ProbabilityDistribution) = isProper(dist) ? unsafeVar(dist) : error("var($(dist)) is undefined because the distribution is improper.")
cov(dist::ProbabilityDistribution) = isProper(dist) ? unsafeCov(dist) : error("cov($(dist)) is undefined because the distribution is improper.")

Expand Down
4 changes: 3 additions & 1 deletion test/factor_nodes/test_gaussian_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module GaussianMeanPrecisionTest

using Test
using ForneyLab
import ForneyLab: outboundType, isApplicable, isProper, unsafeMean, unsafeVar, unsafeCov, unsafeMeanCov, unsafePrecision, unsafeWeightedMean, unsafeWeightedMeanPrecision
import ForneyLab: outboundType, isApplicable, isProper, unsafeMean, unsafeMode, unsafeVar, unsafeCov, unsafeMeanCov, unsafePrecision, unsafeWeightedMean, unsafeWeightedMeanPrecision
import ForneyLab: SPGaussianMeanPrecisionOutVPP, SPGaussianMeanPrecisionMPVP, SPGaussianMeanPrecisionOutVGP, SPGaussianMeanPrecisionMGVP, VBGaussianMeanPrecisionOut, VBGaussianMeanPrecisionM, VBGaussianMeanPrecisionW, SVBGaussianMeanPrecisionOutVGD, SVBGaussianMeanPrecisionMGVD, SVBGaussianMeanPrecisionW, MGaussianMeanPrecisionGGD

@testset "dims" begin
Expand Down Expand Up @@ -39,6 +39,7 @@ end
@testset "unsafe statistics" begin
# Univariate
@test unsafeMean(ProbabilityDistribution(Univariate, GaussianMeanPrecision, m=2.0, w=4.0)) == 2.0
@test unsafeMode(ProbabilityDistribution(Univariate, GaussianMeanPrecision, m=2.0, w=4.0)) == 2.0
@test unsafeVar(ProbabilityDistribution(Univariate, GaussianMeanPrecision, m=2.0, w=4.0)) == 0.25
@test unsafeCov(ProbabilityDistribution(Univariate, GaussianMeanPrecision, m=2.0, w=4.0)) == 0.25
@test unsafeMeanCov(ProbabilityDistribution(Univariate, GaussianMeanPrecision, m=2.0, w=4.0)) == (2.0, 0.25)
Expand All @@ -48,6 +49,7 @@ end

# Multivariate
@test unsafeMean(ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=[2.0], w=mat(4.0))) == [2.0]
@test unsafeMode(ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=[2.0], w=mat(4.0))) == [2.0]
@test unsafeVar(ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=[2.0], w=mat(4.0))) == [0.25]
@test unsafeCov(ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=[2.0], w=mat(4.0))) == mat(0.25)
@test unsafeMeanCov(ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=[2.0], w=mat(4.0))) == ([2.0], mat(0.25))
Expand Down
4 changes: 3 additions & 1 deletion test/factor_nodes/test_gaussian_mean_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module GaussianMeanVarianceTest

using Test
using ForneyLab
import ForneyLab: outboundType, isApplicable, isProper, unsafeMean, unsafeVar, unsafeCov, unsafeMeanCov, unsafePrecision, unsafeWeightedMean, unsafeWeightedMeanPrecision
import ForneyLab: outboundType, isApplicable, isProper, unsafeMean, unsafeMode, unsafeVar, unsafeCov, unsafeMeanCov, unsafePrecision, unsafeWeightedMean, unsafeWeightedMeanPrecision
import ForneyLab: SPGaussianMeanVarianceOutVPP, SPGaussianMeanVarianceMPVP, SPGaussianMeanVarianceOutVGP, SPGaussianMeanVarianceMGVP, VBGaussianMeanVarianceM, VBGaussianMeanVarianceOut
import LinearAlgebra: det, diag

Expand Down Expand Up @@ -40,6 +40,7 @@ end
@testset "unsafe statistics" begin
# Univariate
@test unsafeMean(ProbabilityDistribution(Univariate, GaussianMeanVariance, m=2.0, v=4.0)) == 2.0
@test unsafeMode(ProbabilityDistribution(Univariate, GaussianMeanVariance, m=2.0, v=4.0)) == 2.0
@test unsafeVar(ProbabilityDistribution(Univariate, GaussianMeanVariance, m=2.0, v=4.0)) == 4.0
@test unsafeCov(ProbabilityDistribution(Univariate, GaussianMeanVariance, m=2.0, v=4.0)) == 4.0
@test unsafeMeanCov(ProbabilityDistribution(Univariate, GaussianMeanVariance, m=2.0, v=4.0)) == (2.0, 4.0)
Expand All @@ -49,6 +50,7 @@ end

# Multivariate
@test unsafeMean(ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[2.0], v=mat(4.0))) == [2.0]
@test unsafeMode(ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[2.0], v=mat(4.0))) == [2.0]
@test unsafeVar(ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[2.0], v=mat(4.0))) == [4.0]
@test unsafeCov(ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[2.0], v=mat(4.0))) == mat(4.0)
@test unsafeMeanCov(ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[2.0], v=mat(4.0))) == ([2.0], mat(4.0))
Expand Down
4 changes: 3 additions & 1 deletion test/factor_nodes/test_gaussian_weighted_mean_precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module GaussianWeightedMeanPrecisionTest

using Test
using ForneyLab
import ForneyLab: isProper, unsafeMean, unsafeVar, unsafeCov, unsafeMeanCov, unsafePrecision, unsafeWeightedMean, unsafeWeightedMeanPrecision
import ForneyLab: isProper, unsafeMean, unsafeMode, unsafeVar, unsafeCov, unsafeMeanCov, unsafePrecision, unsafeWeightedMean, unsafeWeightedMeanPrecision

@testset "dims" begin
@test dims(ProbabilityDistribution(Univariate, GaussianWeightedMeanPrecision, xi=0.0, w=1.0)) == 1
Expand Down Expand Up @@ -38,6 +38,7 @@ end
@testset "unsafe statistics" begin
# Univariate
@test unsafeMean(ProbabilityDistribution(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=4.0)) == 0.5
@test unsafeMode(ProbabilityDistribution(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=4.0)) == 0.5
@test unsafeVar(ProbabilityDistribution(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=4.0)) == 0.25
@test unsafeCov(ProbabilityDistribution(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=4.0)) == 0.25
@test unsafeMeanCov(ProbabilityDistribution(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=4.0)) == (0.5, 0.25)
Expand All @@ -47,6 +48,7 @@ end

# Multivariate
@test unsafeMean(ProbabilityDistribution(Multivariate, GaussianWeightedMeanPrecision, xi=[2.0], w=mat(4.0))) == [0.5]
@test unsafeMode(ProbabilityDistribution(Multivariate, GaussianWeightedMeanPrecision, xi=[2.0], w=mat(4.0))) == [0.5]
@test unsafeVar(ProbabilityDistribution(Multivariate, GaussianWeightedMeanPrecision, xi=[2.0], w=mat(4.0))) == [0.25]
@test unsafeCov(ProbabilityDistribution(Multivariate, GaussianWeightedMeanPrecision, xi=[2.0], w=mat(4.0))) == mat(0.25)
@test unsafeMeanCov(ProbabilityDistribution(Multivariate, GaussianWeightedMeanPrecision, xi=[2.0], w=mat(4.0))) == ([0.5], mat(0.25))
Expand Down

0 comments on commit e21ade3

Please sign in to comment.