Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A more universal entropy calculation method for sampling based inference #151

Merged
merged 11 commits into from
Feb 23, 2021
Merged
Prev Previous commit
Next Next commit
Initial tests for entropy calculation and new SampleList features
  • Loading branch information
semihakbayrak committed Jan 29, 2021
commit 5b4918737314abd1f3b980dec774803bdf3cb94e
11 changes: 5 additions & 6 deletions src/factor_nodes/sample_list.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,9 @@ isProper(dist::ProbabilityDistribution{V, SampleList}) where V<:VariateType = ab
y::ProbabilityDistribution{V, SampleList},
z::ProbabilityDistribution{V, SampleList}=ProbabilityDistribution(V, SampleList, s=[0.0], w=[1.0])) where V<:VariateType

logIntegrand = (samples) -> logPdf.([x], samples)
if haskey(y.params, :logintegrand)
logIntegrand(samples) = y.params[:logintegrand](samples) .+ logPdf.([x], samples)
else
logIntegrand(samples) = logPdf.([x], samples)
logIntegrand = (samples) -> y.params[:logintegrand](samples) .+ logPdf.([x], samples)
end

samples = y.params[:s]
Expand All @@ -156,11 +155,11 @@ isProper(dist::ProbabilityDistribution{V, SampleList}) where V<:VariateType = ab

z.params[:w] = weights
z.params[:s] = samples
z.params[:logintegrand] = (s) -> logIntegrand(s)
z.params[:logintegrand] = logIntegrand
if haskey(y.params, :logproposal) && haskey(y.params, :unnormalizedweights)
z.params[:unnormalizedweights] = w_raw_x.*y.params[:unnormalizedweights]
logProposal(s) = y.params[:logproposal](s)
z.params[:logproposal] = (s) -> logProposal(s)
logProposal = y.params[:logproposal]
z.params[:logproposal] = logProposal
# calculate entropy
H_y = log(sum(w_raw_x.*y.params[:unnormalizedweights])) - log(n_samples)
H_x = -sum( weights.*(logProposal(samples) + log.(y.params[:unnormalizedweights]) + log_samples_x) )
Expand Down
38 changes: 35 additions & 3 deletions test/factor_nodes/test_sample_list.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,40 @@ end
end

@testset "prod!" begin
@test ProbabilityDistribution(Univariate, SampleList, s=[0.0, 1.0], w=[0.5, 0.5]) * ProbabilityDistribution(Univariate, GaussianMeanVariance, m=0.0, v=1.0) == ProbabilityDistribution(Univariate, SampleList, s=[0.0, 1.0], w=[0.6224593312018546, 0.37754066879814546])
@test ProbabilityDistribution(Multivariate, SampleList, s=[[0.0], [1.0]], w=[0.5, 0.5]) * ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[0.0], v=mat(1.0)) == ProbabilityDistribution(Multivariate, SampleList, s=[[0.0], [1.0]], w=[0.6224593312018546, 0.37754066879814546])
dist1 = ProbabilityDistribution(Univariate, GaussianMeanVariance, m=0.0, v=1.0)
dist2 = dist1 * ProbabilityDistribution(Univariate, SampleList, s=[0.0, 1.0], w=[0.5, 0.5])
dist3 = ProbabilityDistribution(Univariate, SampleList, s=[0.0, 1.0], w=[0.6224593312018546, 0.37754066879814546])
@test dist2.params[:s] == dist3.params[:s]
@test dist2.params[:w] == dist3.params[:w]
@test dist2.params[:logintegrand]([-0.7, 1.5]) == logPdf.([dist1],[-0.7, 1.5])
#@test ProbabilityDistribution(Multivariate, SampleList, s=[[0.0], [1.0]], w=[0.5, 0.5]) * ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[0.0], v=mat(1.0)) == ProbabilityDistribution(Multivariate, SampleList, s=[[0.0], [1.0]], w=[0.6224593312018546, 0.37754066879814546])
dist1 = ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[0.0], v=mat(1.0))
dist2 = ProbabilityDistribution(Multivariate, SampleList, s=[[0.0], [1.0]], w=[0.5, 0.5]) * dist1
dist3 = ProbabilityDistribution(Multivariate, SampleList, s=[[0.0], [1.0]], w=[0.6224593312018546, 0.37754066879814546])
@test dist2.params[:s] == dist3.params[:s]
@test dist2.params[:w] == dist3.params[:w]
@test dist2.params[:logintegrand]([[-0.7], [1.5]]) == logPdf.([dist1],[[-0.7], [1.5]])
end

# Verify which bookkeeping fields (:logproposal, :logintegrand,
# :unnormalizedweights) are attached to the SampleList produced by `*`,
# depending on the operand types involved in the product.
@testset "Initialization of logproposal, logintegrand, unnormalizedweights" begin
dist1 = ProbabilityDistribution(Univariate, Beta, a=2.0, b=1.0)
dist2 = ProbabilityDistribution(Univariate, Gamma, a=2.0, b=1.0)
dist3 = ProbabilityDistribution(Univariate, SampleList, s=[0.0, 1.0], w=[0.5, 0.5])
dist4 = dist1*dist2 # parametric * parametric -> SampleList result
dist5 = dist2*dist3 # parametric * pre-existing SampleList
dist6 = dist4*dist2 # SampleList from a product * parametric

# Product of two parametric distributions: all three fields are expected.
# NOTE(review): presumably the product is realized by importance sampling
# from a proposal, which records these — confirm against the prod! rules.
@test haskey(dist4.params, :logproposal)
@test haskey(dist4.params, :logintegrand)
@test haskey(dist4.params, :unnormalizedweights)

# Product with a hand-constructed SampleList (built with only :s and :w):
# only :logintegrand is expected. This matches the prod! code in
# sample_list.jl, which copies :logproposal/:unnormalizedweights onto the
# result only when the SampleList operand already carries them.
@test !haskey(dist5.params, :logproposal)
@test haskey(dist5.params, :logintegrand)
@test !haskey(dist5.params, :unnormalizedweights)

# A SampleList that came from a product (dist4) does carry the fields, so
# a further product propagates all three.
@test haskey(dist6.params, :logproposal)
@test haskey(dist6.params, :logintegrand)
@test haskey(dist6.params, :unnormalizedweights)
end

@testset "bootstrap" begin
Expand All @@ -48,7 +80,7 @@ end

p1 = ProbabilityDistribution(Multivariate, GaussianMeanVariance, m=[2.0], v=mat(0.0))
p2 = ProbabilityDistribution(MatrixVariate, SampleList, s=[mat(tiny)], w=[1.0])
@test isapprox(bootstrap(p1, p2)[1][1], 2.0, atol=1e-4)
@test isapprox(bootstrap(p1, p2)[1][1], 2.0, atol=1e-4)
end


Expand Down