Skip to content

Commit

Permalink
dispatch fisher_divergence on Integrator type
Browse files Browse the repository at this point in the history
  • Loading branch information
joannajzou committed May 21, 2024
1 parent 215e250 commit 8b6bafb
Showing 1 changed file with 39 additions and 17 deletions.
56 changes: 39 additions & 17 deletions src/PotentialLearningExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,47 @@ function compute_divergence(
p::Gibbs,
q::Gibbs,
d::FisherDivergence,
)
return fisher_divergence(p, q, d.int)
end

function fisher_divergence(
p::Gibbs,
q::Gibbs,
int::QuadIntegrator,
)
sp(x) = gradlogpdf(p, x)
sq(x) = gradlogpdf(q, x)

if typeof(d.int) <: QuadIntegrator
Zp = normconst(p, d.int)
h(x) = updf(p, x)/Zp .* norm(sp(x) - sq(x))^2
return sum(d.int.w .* h.(d.int.ξ))

elseif typeof(d.int) <: MCMC
xsamp = rand(p, d.int.n, d.int.sampler, d.int.ρ0)
h = x -> norm(sp(x) - sq(x))^2
return sum(h.(xsamp)) / length(xsamp)

elseif typeof(d.int) <: MCSamples
h = x -> norm(sp(x) - sq(x))^2
return sum(h.(d.int.xsamp)) / length(d.int.xsamp)

end
end
Zp = normconst(p, int)
h(x) = updf(p, x)/Zp .* norm(sp(x) - sq(x))^2
return sum(int.w .* h.(int.ξ))
end

function fisher_divergence(
p::Gibbs,
q::Gibbs,
int::MCMC,
)
sp(x) = gradlogpdf(p, x)
sq(x) = gradlogpdf(q, x)

xsamp = rand(p, int.n, int.sampler, int.ρ0)
h = x -> norm(sp(x) - sq(x))^2
return sum(h.(xsamp)) / length(xsamp)
end

function fisher_divergence(
p::Gibbs,
q::Gibbs,
int::MCSamples,
)
sp(x) = gradlogpdf(p, x)
sq(x) = gradlogpdf(q, x)

h = x -> norm(sp(x) - sq(x))^2
return sum(h.(int.xsamp)) / length(int.xsamp)
end


export FisherDivergence
export FisherDivergence, fisher_divergence

0 comments on commit 8b6bafb

Please sign in to comment.