diff --git a/src/engines/julia/update_rules/nonlinear.jl b/src/engines/julia/update_rules/nonlinear.jl
index 370b4889..a467211b 100644
--- a/src/engines/julia/update_rules/nonlinear.jl
+++ b/src/engines/julia/update_rules/nonlinear.jl
@@ -16,7 +16,8 @@ end
 function ruleSPNonlinearOutVG(  msg_out::Nothing,
                                 msg_in1::Message{F, Multivariate},
                                 g::Function,
-                                J_g::Function) where F<:Gaussian
+                                J_g::Function,
+                                g_inv::Union{Function, Nothing}=nothing) where F<:Gaussian # g_inv: optional inverse of g
 
     d_in1 = convert(ProbabilityDistribution{Multivariate, GaussianMeanVariance}, msg_in1.dist)
 
@@ -30,7 +31,8 @@ end
 function ruleSPNonlinearOutVG(  msg_out::Nothing,
                                 msg_in1::Message{F, Univariate},
                                 g::Function,
-                                J_g::Function) where F<:Gaussian
+                                J_g::Function,
+                                g_inv::Union{Function, Nothing}=nothing) where F<:Gaussian # g_inv: optional inverse of g
 
     d_in1 = convert(ProbabilityDistribution{Univariate, GaussianMeanVariance}, msg_in1.dist)
 
@@ -43,11 +45,18 @@ end
 function ruleSPNonlinearIn1GV(  msg_out::Message{F, Multivariate},
                                 msg_in1::Message, # Any type of message, used for determining the approximation point
                                 g::Function,
-                                J_g::Function) where F<:Gaussian
+                                J_g::Function,
+                                g_inv::Union{Function, Nothing}=nothing) where F<:Gaussian # g_inv: optional inverse of g
 
     d_out = convert(ProbabilityDistribution{Multivariate, GaussianMeanPrecision}, msg_out.dist)
 
-    (A, b) = approximate(unsafeMean(msg_in1.dist), g, J_g)
+    if g_inv !== nothing # Inverse supplied: approximate around g_inv applied to the mean of msg_out
+        x_hat = g_inv(unsafeMean(msg_out.dist))
+    else # No inverse supplied: approximate around the mean of msg_in1
+        x_hat = unsafeMean(msg_in1.dist)
+    end
+
+    (A, b) = approximate(x_hat, g, J_g)
     A_inv = pinv(A)
     W_q = A'*d_out.params[:w]*A
     W_q = W_q + tiny*diageye(size(W_q)[1]) # Ensure W_q is invertible
@@ -58,11 +67,18 @@ end
 function ruleSPNonlinearIn1GV(  msg_out::Message{F, Univariate},
                                 msg_in1::Message, # Any type of message, used for determining the approximation point
                                 g::Function,
-                                J_g::Function) where F<:Gaussian
+                                J_g::Function,
+                                g_inv::Union{Function, Nothing}=nothing) where F<:Gaussian # g_inv: optional inverse of g
 
     d_out = convert(ProbabilityDistribution{Univariate, GaussianMeanPrecision}, msg_out.dist)
 
-    (a, b) = approximate(unsafeMean(msg_in1.dist), g, J_g)
+    if g_inv !== nothing # Inverse supplied: approximate around g_inv applied to the mean of msg_out
+        x_hat = g_inv(unsafeMean(msg_out.dist))
+    else # No inverse supplied: approximate around the mean of msg_in1
+        x_hat = unsafeMean(msg_in1.dist)
+    end
+
+    (a, b) = approximate(x_hat, g, J_g)
     w_q = clamp(d_out.params[:w]*a^2, tiny, huge)
 
     Message(Univariate, GaussianMeanPrecision, m=(d_out.params[:m] - b)/a, w=w_q)
@@ -100,6 +116,9 @@ function collectSumProductNodeInbounds(node::Nonlinear, entry::ScheduleEntry, in
     # These functions need to be defined in the scope of the user
     push!(inbound_messages, "$(node.g)")
     push!(inbound_messages, "$(node.J_g)")
+    if node.g_inv !== nothing # Only pass the inverse along when the node defines one
+        push!(inbound_messages, "$(node.g_inv)")
+    end
 
     return inbound_messages
 end
diff --git a/src/factor_nodes/gaussian_mean_precision.jl b/src/factor_nodes/gaussian_mean_precision.jl
index b9afe270..68c17e49 100644
--- a/src/factor_nodes/gaussian_mean_precision.jl
+++ b/src/factor_nodes/gaussian_mean_precision.jl
@@ -49,6 +49,8 @@ vague(::Type{GaussianMeanPrecision}, dims::Int64) = ProbabilityDistribution(Mult
 
 unsafeMean(dist::ProbabilityDistribution{V, GaussianMeanPrecision}) where V<:VariateType = deepcopy(dist.params[:m]) # unsafe mean
 
+unsafeMode(dist::ProbabilityDistribution{V, GaussianMeanPrecision}) where V<:VariateType = deepcopy(dist.params[:m]) # unsafe mode
+
 unsafeVar(dist::ProbabilityDistribution{Univariate, GaussianMeanPrecision}) = 1.0/dist.params[:w] # unsafe variance
 unsafeVar(dist::ProbabilityDistribution{Multivariate, GaussianMeanPrecision}) = diag(cholinv(dist.params[:w]))
 
diff --git a/src/factor_nodes/gaussian_mean_variance.jl b/src/factor_nodes/gaussian_mean_variance.jl
index 3a3104d0..a8c1075f 100644
--- a/src/factor_nodes/gaussian_mean_variance.jl
+++ b/src/factor_nodes/gaussian_mean_variance.jl
@@ -51,6 +51,8 @@ vague(::Type{GaussianMeanVariance}, dims::Int64) = ProbabilityDistribution(Multi
 
 unsafeMean(dist::ProbabilityDistribution{V, GaussianMeanVariance}) where V<:VariateType = deepcopy(dist.params[:m]) # unsafe mean
 
+unsafeMode(dist::ProbabilityDistribution{V, GaussianMeanVariance}) where V<:VariateType = deepcopy(dist.params[:m]) # unsafe mode
+
 unsafeVar(dist::ProbabilityDistribution{Univariate, GaussianMeanVariance}) = dist.params[:v] # unsafe variance
 unsafeVar(dist::ProbabilityDistribution{Multivariate, GaussianMeanVariance}) = diag(dist.params[:v])
 
diff --git a/src/factor_nodes/gaussian_weighted_mean_precision.jl b/src/factor_nodes/gaussian_weighted_mean_precision.jl
index 5a102b8e..7b6f1b41 100644
--- a/src/factor_nodes/gaussian_weighted_mean_precision.jl
+++ b/src/factor_nodes/gaussian_weighted_mean_precision.jl
@@ -17,6 +17,8 @@ vague(::Type{GaussianWeightedMeanPrecision}, dims::Int64) = ProbabilityDistribut
 
 unsafeMean(dist::ProbabilityDistribution{V, GaussianWeightedMeanPrecision}) where V<:VariateType = cholinv(dist.params[:w])*dist.params[:xi]
 
+unsafeMode(dist::ProbabilityDistribution{V, GaussianWeightedMeanPrecision}) where V<:VariateType = cholinv(dist.params[:w])*dist.params[:xi]
+
 unsafeVar(dist::ProbabilityDistribution{Univariate, GaussianWeightedMeanPrecision}) = 1.0/dist.params[:w] # unsafe variance
 unsafeVar(dist::ProbabilityDistribution{Multivariate, GaussianWeightedMeanPrecision}) = diag(cholinv(dist.params[:w]))
 
diff --git a/src/factor_nodes/nonlinear.jl b/src/factor_nodes/nonlinear.jl
index 097396f4..96db4e43 100644
--- a/src/factor_nodes/nonlinear.jl
+++ b/src/factor_nodes/nonlinear.jl
@@ -24,11 +24,12 @@ mutable struct Nonlinear <: DeltaFactor
 
     g::Function # Vector function that expresses the output vector as a function of the input vector; reduces to scalar for 1-d
     J_g::Function # Jacobi matrix of g, as a function of the input vector; in the 1-d case this reduces to the first derivative of g
+    g_inv::Union{Function, Nothing} # When g is invertible, g_inv might provide a more accurate estimate of the working point for the backward message
     dims::Tuple # Dimension of breaker message on input interface
 
-    function Nonlinear(out, in1, g::Function, J_g::Function; dims=(1,), id=ForneyLab.generateId(Nonlinear))
+    function Nonlinear(out, in1, g::Function, J_g::Function, g_inv::Union{Function, Nothing}=nothing; dims=(1,), id=ForneyLab.generateId(Nonlinear))
         @ensureVariables(out, in1)
-        self = new(id, Vector{Interface}(undef, 2), Dict{Symbol,Interface}(), g, J_g, dims)
+        self = new(id, Vector{Interface}(undef, 2), Dict{Symbol,Interface}(), g, J_g, g_inv, dims)
         ForneyLab.addNode!(currentGraph(), self)
         self.i[:out] = self.interfaces[1] = associate!(Interface(self), out)
         self.i[:in1] = self.interfaces[2] = associate!(Interface(self), in1)
diff --git a/src/probability_distribution.jl b/src/probability_distribution.jl
index 1e973fd4..88fecadb 100644
--- a/src/probability_distribution.jl
+++ b/src/probability_distribution.jl
@@ -6,6 +6,7 @@ MatrixVariate,
 PointMass,
 @RV,
 mean,
+mode,
 var,
 cov,
 differentialEntropy,
@@ -35,6 +36,7 @@ matches(Ta::Type{Pa}, Tb::Type{Pb}) where {Pa<:ProbabilityDistribution, Pb<:Prob
 matches(::Type{Nothing}, ::Type{T}) where T<:ProbabilityDistribution = false
 
 mean(dist::ProbabilityDistribution) = isProper(dist) ? unsafeMean(dist) : error("mean($(dist)) is undefined because the distribution is improper.")
+mode(dist::ProbabilityDistribution) = isProper(dist) ? unsafeMode(dist) : error("mode($(dist)) is undefined because the distribution is improper.")
 var(dist::ProbabilityDistribution) = isProper(dist) ? unsafeVar(dist) : error("var($(dist)) is undefined because the distribution is improper.")
 cov(dist::ProbabilityDistribution) = isProper(dist) ? unsafeCov(dist) : error("cov($(dist)) is undefined because the distribution is improper.")
 
diff --git a/test/factor_nodes/test_gaussian_mean_precision.jl b/test/factor_nodes/test_gaussian_mean_precision.jl
index 78954abb..55499f6f 100644
--- a/test/factor_nodes/test_gaussian_mean_precision.jl
+++ b/test/factor_nodes/test_gaussian_mean_precision.jl
@@ -2,7 +2,7 @@ module GaussianMeanPrecisionTest
 
 using Test
 using ForneyLab
-import ForneyLab: outboundType, isApplicable, isProper, unsafeMean, unsafeVar, unsafeCov, unsafeMeanCov, unsafePrecision, unsafeWeightedMean, unsafeWeightedMeanPrecision
+import ForneyLab: outboundType, isApplicable, isProper, unsafeMean, unsafeMode, unsafeVar, unsafeCov, unsafeMeanCov, unsafePrecision, unsafeWeightedMean, unsafeWeightedMeanPrecision
 import ForneyLab: SPGaussianMeanPrecisionOutVPP, SPGaussianMeanPrecisionMPVP, SPGaussianMeanPrecisionOutVGP, SPGaussianMeanPrecisionMGVP, VBGaussianMeanPrecisionOut, VBGaussianMeanPrecisionM, VBGaussianMeanPrecisionW, SVBGaussianMeanPrecisionOutVGD, SVBGaussianMeanPrecisionMGVD, SVBGaussianMeanPrecisionW, MGaussianMeanPrecisionGGD
 
 @testset "dims" begin
@@ -39,6 +39,7 @@ end
 @testset "unsafe statistics" begin
     # Univariate
     @test unsafeMean(ProbabilityDistribution(Univariate, GaussianMeanPrecision, m=2.0, w=4.0)) == 2.0
+    @test unsafeMode(ProbabilityDistribution(Univariate, GaussianMeanPrecision, m=2.0, w=4.0)) == 2.0
     @test unsafeVar(ProbabilityDistribution(Univariate, GaussianMeanPrecision, m=2.0, w=4.0)) == 0.25
     @test unsafeCov(ProbabilityDistribution(Univariate, GaussianMeanPrecision, m=2.0, w=4.0)) == 0.25
     @test unsafeMeanCov(ProbabilityDistribution(Univariate, GaussianMeanPrecision, m=2.0, w=4.0)) == (2.0, 0.25)
@@ -48,6 +49,7 @@ end
 
     # Multivariate
     @test unsafeMean(ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=[2.0], w=mat(4.0))) == [2.0]
+    @test unsafeMode(ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=[2.0], w=mat(4.0))) == [2.0]
     @test unsafeVar(ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=[2.0], w=mat(4.0))) == [0.25]
     @test unsafeCov(ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=[2.0], w=mat(4.0))) == mat(0.25)
     @test unsafeMeanCov(ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=[2.0], w=mat(4.0))) == ([2.0], mat(0.25))
diff --git a/test/factor_nodes/test_gaussian_mean_variance.jl b/test/factor_nodes/test_gaussian_mean_variance.jl
index 83f7faed..1e33e0b1 100644
--- a/test/factor_nodes/test_gaussian_mean_variance.jl
+++ b/test/factor_nodes/test_gaussian_mean_variance.jl
@@ -2,7 +2,7 @@ module GaussianMeanVarianceTest
 
 using Test
 using ForneyLab
-import ForneyLab: outboundType, isApplicable, isProper, unsafeMean, unsafeVar, unsafeCov, unsafeMeanCov, unsafePrecision, unsafeWeightedMean, unsafeWeightedMeanPrecision
+import ForneyLab: outboundType, isApplicable, isProper, unsafeMean, unsafeMode, unsafeVar, unsafeCov, unsafeMeanCov, unsafePrecision, unsafeWeightedMean, unsafeWeightedMeanPrecision
 import ForneyLab: SPGaussianMeanVarianceOutVPP, SPGaussianMeanVarianceMPVP, SPGaussianMeanVarianceOutVGP, SPGaussianMeanVarianceMGVP, VBGaussianMeanVarianceM, VBGaussianMeanVarianceOut
 import LinearAlgebra: det, diag
 
@@ -40,6 +40,7 @@ end
 @testset "unsafe statistics" begin
     # Univariate
     @test unsafeMean(ProbabilityDistribution(Univariate, GaussianMeanVariance, m=2.0, v=4.0)) == 2.0
+    @test unsafeMode(ProbabilityDistribution(Univariate, GaussianMeanVariance, m=2.0, v=4.0)) == 2.0
     @test unsafeVar(ProbabilityDistribution(Univariate, GaussianMeanVariance, m=2.0, v=4.0)) == 4.0
     @test unsafeCov(ProbabilityDistribution(Univariate, GaussianMeanVariance, m=2.0, v=4.0)) == 4.0
     @test unsafeMeanCov(ProbabilityDistribution(Univariate, GaussianMeanVariance, m=2.0, v=4.0)) == (2.0, 4.0)
@@ -49,6 +50,7 @@ end
 
     # Multivariate
     @test unsafeMean(ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[2.0], v=mat(4.0))) == [2.0]
+    @test unsafeMode(ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[2.0], v=mat(4.0))) == [2.0]
     @test unsafeVar(ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[2.0], v=mat(4.0))) == [4.0]
     @test unsafeCov(ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[2.0], v=mat(4.0))) == mat(4.0)
     @test unsafeMeanCov(ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[2.0], v=mat(4.0))) == ([2.0], mat(4.0))
diff --git a/test/factor_nodes/test_gaussian_weighted_mean_precision.jl b/test/factor_nodes/test_gaussian_weighted_mean_precision.jl
index cc243ee0..11bf5db0 100644
--- a/test/factor_nodes/test_gaussian_weighted_mean_precision.jl
+++ b/test/factor_nodes/test_gaussian_weighted_mean_precision.jl
@@ -2,7 +2,7 @@ module GaussianWeightedMeanPrecisionTest
 
 using Test
 using ForneyLab
-import ForneyLab: isProper, unsafeMean, unsafeVar, unsafeCov, unsafeMeanCov, unsafePrecision, unsafeWeightedMean, unsafeWeightedMeanPrecision
+import ForneyLab: isProper, unsafeMean, unsafeMode, unsafeVar, unsafeCov, unsafeMeanCov, unsafePrecision, unsafeWeightedMean, unsafeWeightedMeanPrecision
 
 @testset "dims" begin
     @test dims(ProbabilityDistribution(Univariate, GaussianWeightedMeanPrecision, xi=0.0, w=1.0)) == 1
@@ -38,6 +38,7 @@ end
 @testset "unsafe statistics" begin
     # Univariate
     @test unsafeMean(ProbabilityDistribution(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=4.0)) == 0.5
+    @test unsafeMode(ProbabilityDistribution(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=4.0)) == 0.5
     @test unsafeVar(ProbabilityDistribution(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=4.0)) == 0.25
     @test unsafeCov(ProbabilityDistribution(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=4.0)) == 0.25
     @test unsafeMeanCov(ProbabilityDistribution(Univariate, GaussianWeightedMeanPrecision, xi=2.0, w=4.0)) == (0.5, 0.25)
@@ -47,6 +48,7 @@ end
 
     # Multivariate
     @test unsafeMean(ProbabilityDistribution(Multivariate, GaussianWeightedMeanPrecision, xi=[2.0], w=mat(4.0))) == [0.5]
+    @test unsafeMode(ProbabilityDistribution(Multivariate, GaussianWeightedMeanPrecision, xi=[2.0], w=mat(4.0))) == [0.5]
     @test unsafeVar(ProbabilityDistribution(Multivariate, GaussianWeightedMeanPrecision, xi=[2.0], w=mat(4.0))) == [0.25]
     @test unsafeCov(ProbabilityDistribution(Multivariate, GaussianWeightedMeanPrecision, xi=[2.0], w=mat(4.0))) == mat(0.25)
     @test unsafeMeanCov(ProbabilityDistribution(Multivariate, GaussianWeightedMeanPrecision, xi=[2.0], w=mat(4.0))) == ([0.5], mat(0.25))