A more universal entropy calculation method for sampling based inference #151

Merged · 11 commits · Feb 23, 2021
3 changes: 2 additions & 1 deletion Project.toml
@@ -4,6 +4,7 @@ version = "0.11.2"

[deps]
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -18,9 +19,9 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
julia = "1"
Documenter = "0.25.2"
ForwardDiff = "0.10.12"
SpecialFunctions = "0.8.0, 0.10.3"
StatsBase = "0.32.2, 0.33.1"
StatsFuns = "0.9.5"
julia = "1"
1 change: 1 addition & 0 deletions src/ForneyLab.jl
@@ -11,6 +11,7 @@ using Printf: @sprintf
using StatsFuns: logmvgamma, betainvcdf, gammainvcdf, poisinvcdf
using ForwardDiff
using StatsBase: Weights
using DataStructures: Queue, enqueue!, dequeue!

import Statistics: mean, var, cov
import Base: +, -, *, ^, ==, exp, convert, show, prod!
16 changes: 9 additions & 7 deletions src/engines/julia/update_rules/nonlinear_sampling.jl
@@ -46,14 +46,14 @@ function ruleSPNonlinearSOutNM(g::Function,
weights = msg_in1.dist.params[:w]

return Message(variate, SampleList, s=samples, w=weights)
end
end

function msgSPNonlinearSOutNGX(g::Function,
msg_out::Nothing,
msgs_in::Vararg{Message{<:Gaussian, <:VariateType}};
n_samples=default_n_samples,
variate)

samples_in = [sample(msg_in.dist, n_samples) for msg_in in msgs_in]

samples = g.(samples_in...)
@@ -83,7 +83,7 @@ function msgSPNonlinearSInGX(g::Function,
msgs_in::Vararg{Message{<:Gaussian, <:VariateType}};
n_samples=default_n_samples,
variate)

# Extract joint statistics of inbound messages
(ms_fw_in, Vs_fw_in) = collectStatistics(msgs_in...) # Return arrays with individual means and covariances
(m_fw_in, V_fw_in, ds) = concatenateGaussianMV(ms_fw_in, Vs_fw_in) # Concatenate individual statistics into joint statistics
@@ -124,7 +124,7 @@ function ruleSPNonlinearSInGX(g::Function,
msgs_in::Vararg{Message{<:Gaussian, <:VariateType}};
n_samples=default_n_samples,
variate)
msgSPNonlinearSInGX(g, inx, msg_out, msg_in..., n_samples=n_samples, variate=variate)
msgSPNonlinearSInGX(g, inx, msg_out, msgs_in..., n_samples=n_samples, variate=variate)
end

function msgSPNonlinearSOutNFactorX(g::Function,
@@ -172,7 +172,7 @@ function msgSPNonlinearSInFactorX(g::Function,
push!(samples_in,sample(msgs_in[i].dist, n_samples))
end
end

return samples_in
end

@@ -345,6 +345,7 @@ function gradientOptimization(log_joint::Function, d_log_joint::Function, m_init
m_old = m_initial
satisfied = false
step_count = 0
m_latests = if dim_tot == 1 Queue{Float64}() else Queue{Vector}() end

while !satisfied
m_new = m_old .+ step_size.*d_log_joint(m_old)
@@ -360,12 +361,13 @@
m_new = m_old .+ step_size.*d_log_joint(m_old)
end
step_count += 1
m_total .+= m_old
m_average = m_total ./ step_count
enqueue!(m_latests, m_old)
if step_count > 10
m_average = sum(x for x in m_latests)./10
if sum(sqrt.(((m_new.-m_average)./m_average).^2)) < dim_tot*0.1
satisfied = true
end
dequeue!(m_latests);
end
if step_count > dim_tot*250
satisfied = true
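For context, the stopping rule added above keeps the ten most recent iterates in a FIFO queue and compares the newest iterate against their running average. Below is a minimal stand-alone sketch of that rule; the toy objective, step size, and function name are illustrative and not part of this PR.

using DataStructures: Queue, enqueue!, dequeue!

# Sketch of the moving-average convergence check used in gradientOptimization.
# d_log_joint and the constants below are illustrative only.
function toy_gradient_ascent(d_log_joint::Function, m_initial::Float64; step_size=0.1)
    m_old = m_initial
    m_latests = Queue{Float64}()                        # ten most recent iterates
    step_count = 0
    while true
        m_new = m_old + step_size*d_log_joint(m_old)
        step_count += 1
        enqueue!(m_latests, m_old)
        length(m_latests) > 10 && dequeue!(m_latests)   # keep the window at size 10
        if length(m_latests) == 10
            m_average = sum(x for x in m_latests)/10
            # stop when the newest iterate is close to the moving average
            abs((m_new - m_average)/m_average) < 0.1 && return m_new
        end
        step_count > 250 && return m_new                # hard iteration cap, as in the rule above
        m_old = m_new
    end
end

toy_gradient_ascent(m -> 2.0 - m, 0.5)                  # gradient of -(m - 2)^2/2; climbs toward m = 2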
27 changes: 27 additions & 0 deletions src/engines/julia/update_rules/nonlinear_unscented.jl
@@ -421,6 +421,33 @@ function concatenateGaussianMV(ms::Vector{Vector{Float64}}, Vs::Vector{<:Abstrac
return (m, V, ds) # Return concatenated mean and covariance with original dimensions (for splitting)
end

# Concatenate multiple mixed statistics
function concatenateGaussianMV(ms::Vector{Any}, Vs::Vector{Any})
# Extract dimensions
ds = [length(m_k) for m_k in ms]
d_in_tot = sum(ds)

# Initialize concatenated statistics
m = zeros(d_in_tot)
V = zeros(d_in_tot, d_in_tot)

# Construct concatenated statistics
d_start = 1
for k = 1:length(ms) # For each inbound statistic
d_end = d_start + ds[k] - 1
if ds[k] == 1
m[d_start] = ms[k]
V[d_start, d_start] = Vs[k]
else
m[d_start:d_end] = ms[k]
V[d_start:d_end, d_start:d_end] = Vs[k]
end
d_start = d_end + 1
end

return (m, V, ds) # Return concatenated mean and covariance with original dimensions (for splitting)
end
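For illustration, the new method accepts a heterogeneous list in which univariate statistics appear as scalars and multivariate ones as vectors and matrices. A hypothetical call, mirroring the test added below, would look like:

# Hypothetical usage; assumes concatenateGaussianMV is in scope (e.g. inside the ForneyLab test suite).
ms = Any[1.0, [2.0, 3.0]]                   # scalar mean followed by a 2-dimensional mean
Vs = Any[4.0, [5.0 0.0; 0.0 6.0]]           # scalar variance followed by a 2x2 covariance
(m, V, ds) = concatenateGaussianMV(ms, Vs)
# m  == [1.0, 2.0, 3.0]
# V  == [4.0 0.0 0.0; 0.0 5.0 0.0; 0.0 0.0 6.0]
# ds == [1, 2]                              # original dimensions, used later for splitting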

"""
Split a vector in chunks of lengths specified by ds.
"""
62 changes: 51 additions & 11 deletions src/factor_nodes/sample_list.jl
@@ -127,18 +127,36 @@ unsafeMeanVector(dist::ProbabilityDistribution{V, SampleList}) where V<:VariateT

isProper(dist::ProbabilityDistribution{V, SampleList}) where V<:VariateType = abs(sum(dist.params[:w]) - 1) < 0.001

# prod of a pdf (or distribution) message and a SampleList message
# this function makes it possible to calculate the entropy of SampleList messages in a VMP setting
@symmetrical function prod!(
x::ProbabilityDistribution{V}, # Includes function distributions
y::ProbabilityDistribution{V, SampleList},
z::ProbabilityDistribution{V, SampleList}=ProbabilityDistribution(V, SampleList, s=[0.0], w=[1.0])) where V<:VariateType

samples = y.params[:s]
n_samples = length(samples)
log_samples_x = logPdf.([x], samples)
# Suppose that in the previous time step two pdf messages m1 and m2 collided.
# The resulting collision m3 (SampleList) = m1*m2 is supposed to carry the
# proposal (m1) and integrand (m2) distributions: m1 is the message from which
# the samples are drawn, and m2 is the message on which the samples are evaluated
# to calculate the weights. In the case of particle filtering (BP), entropy will
# not be calculated, and in the first step there is no integrand information yet.
if haskey(y.params, :logintegrand)
# Recall that we are calculating m3*m4. If m3 carries integrand information,
# update it: new_integrand = m2*m4. This allows us to collide an arbitrary number
# of beliefs to approximate the posterior and still estimate the entropy.
logIntegrand = (samples) -> y.params[:logintegrand](samples) .+ logPdf.([x], samples)
else
# If there was no integrand information before, set it to m4
logIntegrand = (samples) -> logPdf.([x], samples)
end

samples = y.params[:s] # samples come from proposal (m1)
n_samples = length(samples) # number of samples
log_samples_x = logPdf.([x], samples) # evaluate samples in logm4, i.e. logm4(s)

# Compute sample weights
w_raw_x = exp.(log_samples_x)
w_prod = w_raw_x.*y.params[:w]
w_raw_x = exp.(log_samples_x) # m4(s)
w_prod = w_raw_x.*y.params[:w] # update the weights of posterior w_unnormalized = m4(s)*w_prev
weights = w_prod./sum(w_prod) # Normalize weights

# Resample if required
@@ -148,9 +166,21 @@ isProper(dist::ProbabilityDistribution{V, SampleList}) where V<:VariateType = ab
weights = ones(n_samples)./n_samples
end

# TODO: no entropy is computed here; include computation?
z.params[:w] = weights
z.params[:s] = samples
# resulting posterior or message
z.params[:w] = weights # set adjusted weights
z.params[:s] = samples # samples are still coming from the same proposal
z.params[:logintegrand] = logIntegrand # set integrand
if haskey(y.params, :logproposal) && haskey(y.params, :unnormalizedweights)
z.params[:unnormalizedweights] = w_raw_x.*y.params[:unnormalizedweights] # m4(s)*m2(s)
logProposal = y.params[:logproposal] # m1
z.params[:logproposal] = logProposal # m1
# calculate entropy
H_y = log(sum(w_raw_x.*y.params[:unnormalizedweights])) - log(n_samples) # log(sum_i(m4(s_i)*m2(s_i))/N)
# -sum_i(w_i*log(m1(s_i)*m2(s_i)*m4(s_i)))
H_x = -sum( weights.*(logProposal(samples) + log.(y.params[:unnormalizedweights]) + log_samples_x) )
entropy = H_x + H_y
z.params[:entropy] = entropy
end

return z
end
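As a plausibility check on the estimator above (not part of this PR): drawing samples from a Gaussian proposal m1 and weighting them by a single Gaussian factor m4 gives a Gaussian posterior whose differential entropy is known in closed form, so the H_x + H_y construction can be verified directly. The distributions and sample size below are illustrative.

# Stand-alone sketch: posterior p(x) ∝ m1(x)*m4(x) with proposal m1 = N(0,1) and factor m4 = N(1,1),
# so p = N(0.5, 0.5) and its differential entropy is 0.5*(log(2π*0.5) + 1) ≈ 1.07.
logm1(x) = -0.5*log(2π) - 0.5*x^2
logm4(x) = -0.5*log(2π) - 0.5*(x - 1.0)^2

N = 200_000
s = randn(N)                                  # samples drawn from the proposal m1
w_raw = exp.(logm4.(s))                       # unnormalized weights m4(s_i)
w = w_raw ./ sum(w_raw)                       # normalized weights

H_y = log(sum(w_raw)) - log(N)                # ≈ log of the normalization constant
H_x = -sum(w .* (logm1.(s) .+ logm4.(s)))     # ≈ -E_p[log(m1*m4)]
H_est = H_x + H_y                             # ≈ 0.5*(log(2π*0.5) + 1), up to Monte Carlo error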
@@ -197,7 +227,11 @@ function sampleWeightsAndEntropy(x::ProbabilityDistribution, y::ProbabilityDistr
H_x = -sum( weights.*(log_samples_x + log_samples_y) )
entropy = H_x + H_y

return (samples, weights, entropy)
# Inform the next step about the proposal and integrand, to be used in the entropy calculation during smoothing
logproposal = (samples) -> logPdf.([x], samples)
logintegrand = (samples) -> logPdf.([y], samples)

return (samples, weights, w_raw, logproposal, logintegrand, entropy)
end
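To make the bookkeeping concrete: sampleWeightsAndEntropy now hands the proposal and integrand forward as closures, and each subsequent collision (the @symmetrical prod! above) extends the integrand by one more factor. A minimal sketch with plain closures, using stand-in log-pdfs rather than ForneyLab distributions:

# Stand-in log-pdfs for the factors; any callables mapping samples to log-densities would do.
logpdf_m2(s) = -0.5 .* (s .- 1.0).^2          # first integrand factor
logpdf_m4(s) = -0.5 .* (s .+ 2.0).^2          # factor arriving at the next collision

logintegrand = logpdf_m2                      # set when the first SampleList is formed
logintegrand = let prev = logintegrand        # updated at the next collision: m2*m4 in the log-domain
    s -> prev(s) .+ logpdf_m4(s)
end

logintegrand([0.0, 1.0]) == logpdf_m2([0.0, 1.0]) .+ logpdf_m4([0.0, 1.0])   # true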

# General product definition that returns a SampleList
@@ -206,10 +240,13 @@ function prod!(
y::ProbabilityDistribution{V},
z::ProbabilityDistribution{V, SampleList} = ProbabilityDistribution(V, SampleList, s=[0.0], w=[1.0])) where {V<:VariateType}

(samples, weights, entropy) = sampleWeightsAndEntropy(x, y)
(samples, weights, unnormalizedweights, logproposal, logintegrand, entropy) = sampleWeightsAndEntropy(x, y)

z.params[:s] = samples
z.params[:w] = weights
z.params[:unnormalizedweights] = unnormalizedweights
z.params[:logproposal] = logproposal
z.params[:logintegrand] = logintegrand
z.params[:entropy] = entropy

return z
@@ -221,10 +258,13 @@ end
y::ProbabilityDistribution{Multivariate, Function},
z::ProbabilityDistribution{Univariate, SampleList} = ProbabilityDistribution(Univariate, SampleList, s=[0.0], w=[1.0]))

(samples, weights, entropy) = sampleWeightsAndEntropy(x, y)
(samples, weights, unnormalizedweights, logproposal, logintegrand, entropy) = sampleWeightsAndEntropy(x, y)

z.params[:s] = samples
z.params[:w] = weights
z.params[:unnormalizedweights] = unnormalizedweights
z.params[:logproposal] = logproposal
z.params[:logintegrand] = logintegrand
z.params[:entropy] = entropy

return z
1 change: 1 addition & 0 deletions test/factor_nodes/test_nonlinear_unscented.jl
@@ -90,6 +90,7 @@ end
@testset "concatenateGaussianMV" begin
@test concatenateGaussianMV([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]) == ([1.0, 2.0, 3.0], Diagonal([4.0, 5.0, 6.0]), ones(Int64, 3))
@test concatenateGaussianMV([[1.0], [2.0, 3.0]], [mat(4.0), Diagonal([5.0, 6.0])]) == ([1.0, 2.0, 3.0], [4.0 0.0 0.0; 0.0 5.0 0.0; 0.0 0.0 6.0], [1, 2])
@test concatenateGaussianMV([1.0, [2.0, 3.0]], [4.0, Diagonal([5.0, 6.0])]) == ([1.0, 2.0, 3.0], [4.0 0.0 0.0; 0.0 5.0 0.0; 0.0 0.0 6.0], [1, 2])
end

@testset "split" begin
82 changes: 78 additions & 4 deletions test/factor_nodes/test_sample_list.jl
@@ -28,9 +28,45 @@ end
@test unsafeCov(ProbabilityDistribution(MatrixVariate, SampleList, s=[eye(2), eye(2)], w=[0.5, 0.5])) == zeros(4,4)
end

@testset "prod!" begin
@test ProbabilityDistribution(Univariate, SampleList, s=[0.0, 1.0], w=[0.5, 0.5]) * ProbabilityDistribution(Univariate, GaussianMeanVariance, m=0.0, v=1.0) == ProbabilityDistribution(Univariate, SampleList, s=[0.0, 1.0], w=[0.6224593312018546, 0.37754066879814546])
@test ProbabilityDistribution(Multivariate, SampleList, s=[[0.0], [1.0]], w=[0.5, 0.5]) * ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[0.0], v=mat(1.0)) == ProbabilityDistribution(Multivariate, SampleList, s=[[0.0], [1.0]], w=[0.6224593312018546, 0.37754066879814546])
@testset "Univariate SampleList prod!" begin
dist_gaussian = ProbabilityDistribution(Univariate, GaussianMeanVariance, m=0.0, v=1.0)
dist_prod = dist_gaussian * ProbabilityDistribution(Univariate, SampleList, s=[0.0, 1.0], w=[0.5, 0.5])
dist_true_prod = ProbabilityDistribution(Univariate, SampleList, s=[0.0, 1.0], w=[0.6224593312018546, 0.37754066879814546])

@test dist_prod.params[:s] == dist_true_prod.params[:s]
@test dist_prod.params[:w] == dist_true_prod.params[:w]
@test dist_prod.params[:logintegrand]([-0.7, 1.5]) == logPdf.([dist_gaussian],[-0.7, 1.5])
end

@testset "Multivariate SampleList prod!" begin
dist_gaussian = ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[0.0], v=mat(1.0))
dist_prod = dist_gaussian * ProbabilityDistribution(Multivariate, SampleList, s=[[0.0], [1.0]], w=[0.5, 0.5])
dist_true_prod = ProbabilityDistribution(Multivariate, SampleList, s=[[0.0], [1.0]], w=[0.6224593312018546, 0.37754066879814546])

@test dist_prod.params[:s] == dist_true_prod.params[:s]
@test dist_prod.params[:w] == dist_true_prod.params[:w]
@test dist_prod.params[:logintegrand]([[-0.7], [1.5]]) == logPdf.([dist_gaussian],[[-0.7], [1.5]])
end

@testset "Initialization of logproposal, logintegrand, unnormalizedweights" begin
dist_beta = ProbabilityDistribution(Univariate, Beta, a=2.0, b=1.0)
dist_gamma = ProbabilityDistribution(Univariate, Gamma, a=2.0, b=1.0)
dist_sample = ProbabilityDistribution(Univariate, SampleList, s=[0.0, 1.0], w=[0.5, 0.5])
dist_betagamma = dist_beta*dist_gamma
dist_gammasample = dist_gamma*dist_sample
dist_betagg = dist_betagamma*dist_gamma

@test haskey(dist_betagamma.params, :logproposal)
@test haskey(dist_betagamma.params, :logintegrand)
@test haskey(dist_betagamma.params, :unnormalizedweights)

@test !haskey(dist_gammasample.params, :logproposal)
@test haskey(dist_gammasample.params, :logintegrand)
@test !haskey(dist_gammasample.params, :unnormalizedweights)

@test haskey(dist_betagg.params, :logproposal)
@test haskey(dist_betagg.params, :logintegrand)
@test haskey(dist_betagg.params, :unnormalizedweights)
end

@testset "bootstrap" begin
@@ -48,7 +84,7 @@

p1 = ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[2.0], v=mat(0.0))
p2 = ProbabilityDistribution(MatrixVariate, SampleList, s=[mat(tiny)], w=[1.0])
@test isapprox(bootstrap(p1, p2)[1][1], 2.0, atol=1e-4)
@test isapprox(bootstrap(p1, p2)[1][1], 2.0, atol=1e-4)
end


@@ -109,4 +145,42 @@ end
@test marginals[:y] == ProbabilityDistribution(Univariate, SampleList, s=collect(1.0:4.0), w=ones(4)/4)
end

@testset "Differential Entropy estimate" begin
fg = FactorGraph()

@RV x ~ Gamma(3.4, 1.2)
@RV s1 ~ Nonlinear{Sampling}(x, g=g)
@RV y1 ~ Poisson(s1)
@RV y2 ~ Poisson(x)

@RV z ~ Gamma(3.4, 1.2)
@RV s2 ~ Nonlinear{Sampling}(z, g=g)
@RV y3 ~ Poisson(z)
@RV y4 ~ Poisson(s2)

@RV w ~ Gamma(3.4, 1.2)
@RV y5 ~ Poisson(w)
@RV y6 ~ Poisson(w)

placeholder(y1,:y1)
placeholder(y2,:y2)
placeholder(y3,:y3)
placeholder(y4,:y4)
placeholder(y5,:y5)
placeholder(y6,:y6)

pfz = PosteriorFactorization(fg)
algo = messagePassingAlgorithm([x,z,w])
source_code = algorithmSourceCode(algo)
eval(Meta.parse(source_code));

v_1, v_2 = 7, 5
data = Dict(:y1 => v_1, :y2 => v_2, :y3 => v_1, :y4 => v_2, :y5 => v_1, :y6 => v_2)
marginals = step!(data)

@test isapprox(marginals[:x].params[:entropy], differentialEntropy(marginals[:w]), atol=0.1)
@test isapprox(marginals[:z].params[:entropy], differentialEntropy(marginals[:w]), atol=0.1)

end

end