Skip to content

Commit

Permalink
optionally check convergence on beliefs
Browse files Browse the repository at this point in the history
  • Loading branch information
stecrotti committed Oct 20, 2023
1 parent 9c63711 commit 93da510
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/BeliefPropagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export BPFactor, TabulatedBPFactor, rand_factor
export FactorGraph, variables, factors, nvariables, nfactors
export BP, reset!, nstates, evaluate, energy
export iterate!, beliefs, factor_beliefs, avg_energy, bethe_free_energy
export damp!
export message_convergence, belief_convergence
export update_f_bp!, update_v_bp!, beliefs_bp, factor_beliefs_bp, avg_energy_bp
export update_f_ms!, update_v_ms!, beliefs_ms, factor_beliefs_ms, iterate_ms!,
avg_energy_ms, bethe_free_energy_ms
Expand Down
28 changes: 16 additions & 12 deletions src/Models/ising.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,24 +67,26 @@ const BPIsing = BP{<:IsingCoupling, <:IsingField, <:Real, <:Real}
BeliefPropagation.nstates(bp::BPIsing, ::Integer) = 2

# BP update of variable `i` for the Ising specialization (`BPIsing`), where messages and
# beliefs are scalar fields.
# NOTE(review): this hunk is rendered from a diff; several statements appear twice, as the
# removed line immediately followed by its replacement (e.g. `err` -> `errv`, and
# `b[i]` -> `bnew[i]` as the target of `cavity`).
# In the replacement lines the function returns `(errv, errb)`: the largest absolute change
# among outgoing messages (measured before damping) and the absolute change of the belief.
function BeliefPropagation.update_v_bp!(bp::BPIsing,
i::Integer, hnew, damp::Real, rein::Real,
i::Integer, hnew, bnew, damp::Real, rein::Real,
f::AtomicVector{<:Real}; extra_kwargs...)
(; g, ϕ, u, h, b) = bp
∂i = outedges(g, variable(i))
# local field plus a reinforcement term proportional to the current belief field
hᵢ = ϕ[i].βh + b[i]*rein
# presumably `cavity` returns the leave-one-out reductions and the full reduction
# (helper defined elsewhere — TODO confirm)
hnew[idx.(∂i)], b[i] = cavity(u[idx.(∂i)], +, hᵢ)
hnew[idx.(∂i)], bnew[i] = cavity(u[idx.(∂i)], +, hᵢ)
# same cavity construction on 2cosh(u), used below in the normalizations zᵢ₂ₐ and zᵢ
cout, cfull = cavity(2cosh.(u[idx.(∂i)]), *, 1.0)
# degrees of the factors neighboring variable i (lazy generator, iterated twice below)
d = (degree(g, factor(a)) for a in neighbors(g, variable(i)))
err = -Inf
errv = -Inf
for ((_,_,id), dₐ, c) in zip(∂i, d, cout)
zᵢ₂ₐ = 2cosh(hnew[id]) / c
# contribution of the outgoing message to f[i], weighted by (1 - 1/dₐ)
f[i] -= log(zᵢ₂ₐ) * (1 - 1/dₐ)
# error is measured against the stored message before damping is applied
err = max(err, abs(hnew[id] - h[id]))
errv = max(errv, abs(hnew[id] - h[id]))
h[id] = damp!(h[id], hnew[id], damp)
end
zᵢ = 2cosh(b[i]) / cfull
# belief error: new field vs. the field stored in the previous sweep
errb = abs(bnew[i] - b[i])
zᵢ = 2cosh(bnew[i]) / cfull
f[i] -= log(zᵢ) * (1 - degree(g, variable(i)) + sum(1/dₐ for dₐ in d; init=0.0))
return err
# commit the new belief only after errb has been computed
b[i] = bnew[i]
return errv, errb
end

function BeliefPropagation.update_f_bp!(bp::BPIsing, a::Integer,
Expand Down Expand Up @@ -129,24 +131,26 @@ function BeliefPropagation.factor_beliefs_bp(bp::BPIsing)
end

# Max-sum update of variable `i` for the Ising specialization (`BPIsing`), where messages
# and beliefs are scalar fields.
# NOTE(review): this hunk is rendered from a diff; several statements appear twice, as the
# removed line immediately followed by its replacement (e.g. `err` -> `errv`, and
# `b[i]` -> `bnew[i]` as the target of `cavity`).
# In the replacement lines the function returns `(errv, errb)`: the largest absolute change
# among outgoing messages (measured before damping) and the absolute change of the belief.
function BeliefPropagation.update_v_ms!(bp::BPIsing,
i::Integer, hnew, damp::Real, rein::Real,
i::Integer, hnew, bnew, damp::Real, rein::Real,
f::AtomicVector{<:Real}; extra_kwargs...)
(; g, ϕ, u, h, b) = bp
∂i = outedges(g, variable(i))
# local field plus a reinforcement term proportional to the current belief field
hᵢ = ϕ[i].βh + b[i]*rein
# presumably `cavity` returns the leave-one-out reductions and the full reduction
# (helper defined elsewhere — TODO confirm)
hnew[idx.(∂i)], b[i] = cavity(u[idx.(∂i)], +, hᵢ)
hnew[idx.(∂i)], bnew[i] = cavity(u[idx.(∂i)], +, hᵢ)
# same cavity construction on |u|, used below in the terms fᵢ₂ₐ and fᵢ
cout, cfull = cavity(abs.(u[idx.(∂i)]), +, 0.0)
# degrees of the factors neighboring variable i (lazy generator, iterated twice below)
d = (degree(g, factor(a)) for a in neighbors(g, variable(i)))
err = -Inf
errv = -Inf
for ((_,_,id), dₐ, c) in zip(∂i, d, cout)
fᵢ₂ₐ = abs(hnew[id]) - c
# contribution of the outgoing message to f[i], weighted by (1 - 1/dₐ)
f[i] -= fᵢ₂ₐ * (1 - 1/dₐ)
# error is measured against the stored message before damping is applied
err = max(err, abs(hnew[id] - h[id]))
errv = max(errv, abs(hnew[id] - h[id]))
h[id] = damp!(h[id], hnew[id], damp)
end
fᵢ = abs(b[i]) - cfull
# belief error: new field vs. the field stored in the previous sweep
errb = abs(bnew[i] - b[i])
fᵢ = abs(bnew[i]) - cfull
f[i] -= fᵢ * (1 - degree(g, variable(i)) + sum(1/dₐ for dₐ in d; init=0.0))
return err
# commit the new belief only after errb has been computed
b[i] = bnew[i]
return errv, errb
end

function BeliefPropagation.update_f_ms!(bp::BPIsing, a::Integer,
Expand Down
43 changes: 27 additions & 16 deletions src/bp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,21 @@ abstract type ConvergenceChecker end
# Convergence checker based on message errors: holds the tolerance `tol` that the
# per-variable and per-factor message errors are compared against.
# NOTE(review): this hunk is rendered from a diff; the constructor and the functor
# signature each appear twice (removed line followed by its replacement).
struct MessageConvergence{T<:Real} <: ConvergenceChecker
tol :: T
end
# convenience constructor (renamed from `msg_convergence` to `message_convergence`)
msg_convergence(tol::Real) = MessageConvergence(tol)
message_convergence(tol::Real) = MessageConvergence(tol)

# Returns `true` when the largest entry of `errv` and of `errf` is below `tol`.
# In the replacement signature `errb` is accepted but not used by this checker.
function (check_convergence::MessageConvergence)(bp::BP, errv, errf)
function (check_convergence::MessageConvergence)(::BP, errv, errf, errb)
max(maximum(errv), maximum(errf)) < check_convergence.tol
end

"""
    BeliefConvergence(tol)

Convergence checker that looks only at the belief errors. An instance is callable with
the same arguments `iterate!` passes to `check_convergence`, i.e.
`(bp, errv, errf, errb)`, and reports convergence when every entry of `errb` is below
`tol`.
"""
struct BeliefConvergence{T<:Real} <: ConvergenceChecker
    tol::T
end

"""
    belief_convergence(tol::Real)

Build a [`BeliefConvergence`](@ref) checker with tolerance `tol`.
"""
belief_convergence(tol::Real) = BeliefConvergence(tol)

# The message errors `errv` and `errf` are accepted to match the common checker call
# signature, but are deliberately ignored: only the largest belief error matters here.
function (checker::BeliefConvergence)(::BP, errv, errf, errb)
    largest_belief_error = maximum(errb)
    return largest_belief_error < checker.tol
end

"""
iterate!(bp::BP; kwargs...)
Expand All @@ -224,24 +233,24 @@ Optional arguments
function iterate!(bp::BP; update_variable! = update_v_bp!, update_factor! = update_f_bp!,
maxiter=100, tol=1e-6, damp::Real=0.0, rein::Real=0.0,
f::AbstractVector{<:Real} = zeros(nvariables(bp.g)),
callback = (bp, errv, errf, it, f) -> nothing,
check_convergence=msg_convergence(tol),
callback = (bp, errv, errf, errb, it, f) -> nothing,
check_convergence=message_convergence(tol),
extra_kwargs...
)
# NOTE(review): this hunk is rendered from a diff; the `callback`/`check_convergence`
# defaults above and several statements below appear twice (removed line followed by its
# replacement, which additionally threads the belief buffers `bnew`/`errb` through).
# `tol` is only used to build the default `check_convergence`; it is not read again when
# a custom checker is passed.
(; g, u, h) = bp
unew = copy(u); hnew = copy(h)
errv = zeros(nvariables(g)); errf = zeros(nfactors(g))
(; g, u, h, b) = bp
# scratch buffers for the new messages/beliefs and per-node error vectors
unew = copy(u); hnew = copy(h); bnew = copy(b)
errv = zeros(nvariables(g)); errf = zeros(nfactors(g)); errb = zeros(nvariables(g))
ff = AtomicVector(f)
for it in 1:maxiter
# the accumulator is reset at the start of every sweep
ff .= 0
# variable updates; the reinforcement passed down is `rein*it`, i.e. it grows linearly
# with the iteration index
@threads for i in variables(bp.g)
errv[i] = update_variable!(bp, i, hnew, damp, rein*it, ff; extra_kwargs...)
errv[i], errb[i] = update_variable!(bp, i, hnew, bnew, damp, rein*it, ff; extra_kwargs...)
end
# factor updates, run after all variable updates of this sweep
@threads for a in factors(bp.g)
errf[a] = update_factor!(bp, a, unew, damp, ff; extra_kwargs...)
end
# the callback receives the caller-supplied `f`, not the `AtomicVector` wrapper `ff`
callback(bp, errv, errf, it, f)
check_convergence(bp, errv, errf) && return it
callback(bp, errv, errf, errb, it, f)
# on convergence, return the index of the sweep at which the check passed
check_convergence(bp, errv, errf, errb) && return it
end
# NOTE(review): returning `maxiter` here is indistinguishable from converging exactly on
# the last sweep
return maxiter
end
Expand All @@ -263,27 +272,29 @@ function damp!(x::T, xnew::T, damp::Real) where {T<:AbstractVector}
return x
end

"""
    update_v_bp!(bp, i, hnew, bnew, damp, rein, f; extra_kwargs...)

BP update of variable `i` for generic (vector-valued) messages.

Computes the new outgoing messages of `i` into `hnew` and the new belief into `bnew[i]`
from the incoming messages `bp.u` and the local factor `bp.ϕ[i]` (reweighted by the
current belief raised to the power `rein`). Outgoing messages are normalized, damped
into `bp.h`, and the corresponding free-energy contributions are subtracted from `f[i]`.
The normalized new belief is then stored in `bp.b[i]`.

Returns `(errv, errb)`:
- `errv`: largest mean absolute change among the normalized outgoing messages, measured
  before damping (`-Inf` if `i` has no outgoing edges);
- `errb`: mean absolute change between the new and the previously stored belief, both
  normalized.
"""
function update_v_bp!(bp::BP{F,FV,M,MB}, i::Integer, hnew, bnew, damp::Real, rein::Real,
        f::AtomicVector{<:Real}; extra_kwargs...) where {
        F<:BPFactor, FV<:BPFactor, M<:AbstractVector{<:Real}, MB<:AbstractVector{<:Real}}
    (; g, ϕ, u, h, b) = bp
    ∂i = outedges(g, variable(i))
    # local factor, reinforced by the current belief
    ϕᵢ = [ϕ[i](x) * b[i][x]^rein for x in 1:nstates(bp, i)]
    msg_mult(m1, m2) = m1 .* m2
    # presumably the leave-one-out products and the full product of incoming messages
    # (helper defined elsewhere — TODO confirm)
    hnew[idx.(∂i)], bnew[i] = cavity(u[idx.(∂i)], msg_mult, ϕᵢ)
    # degrees of the factors neighboring variable i (lazy generator, iterated twice)
    d = (degree(g, factor(a)) for a in neighbors(g, variable(i)))
    errv = -Inf
    for ((_,_,id), dₐ) in zip(∂i, d)
        zᵢ₂ₐ = sum(hnew[id])
        f[i] -= log(zᵢ₂ₐ) * (1 - 1/dₐ)
        hnew[id] ./= zᵢ₂ₐ
        # error is measured on normalized messages, before damping
        errv = max(errv, mean(abs, hnew[id] - h[id]))
        h[id] = damp!(h[id], hnew[id], damp)
    end
    zᵢ = sum(bnew[i])
    f[i] -= log(zᵢ) * (1 - degree(g, variable(i)) + sum(1/dₐ for dₐ in d; init=0.0))
    # Normalize BEFORE measuring the belief error: `b[i]` is stored normalized, so
    # comparing it with the raw product would leave a residual (zᵢ - 1) * b[i] even at a
    # fixed point and a belief-based convergence check would never trigger.
    bnew[i] ./= zᵢ
    errb = mean(abs, bnew[i] - b[i])
    b[i] = bnew[i]
    return errv, errb
end

function update_f_bp!(bp::BP{F,FV,M,MB}, a::Integer, unew, damp::Real,
Expand Down
14 changes: 8 additions & 6 deletions src/maxsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,28 @@ function iterate_ms!(bp::BP; kwargs...)
kwargs...)
end

"""
    update_v_ms!(bp, i, hnew, bnew, damp, rein, f; extra_kwargs...)

Max-sum update of variable `i` for generic (vector-valued) messages in the log domain.

Computes the new outgoing messages of `i` into `hnew` and the new belief into `bnew[i]`
from the incoming messages `bp.u` and the log of the local factor `bp.ϕ[i]` (plus `rein`
times the current belief). Outgoing messages are shifted so that their maximum is zero,
damped into `bp.h`, and the corresponding contributions are subtracted from `f[i]`.
The shifted new belief is then stored in `bp.b[i]`.

Returns `(errv, errb)`:
- `errv`: largest mean absolute change among the shifted outgoing messages, measured
  before damping (`-Inf` if `i` has no outgoing edges);
- `errb`: mean absolute change between the new and the previously stored belief, both
  shifted to have maximum zero.
"""
function update_v_ms!(bp::BP, i::Integer, hnew, bnew, damp::Real, rein::Real,
        f::AtomicVector{<:Real}; extra_kwargs...)
    (; g, ϕ, u, h, b) = bp
    ∂i = outedges(g, variable(i))
    # log of the local factor, reinforced by the current belief
    logϕᵢ = [log(ϕ[i](x)) + b[i][x]*rein for x in 1:nstates(bp, i)]
    msg_sum(m1, m2) = m1 .+ m2
    # presumably the leave-one-out sums and the full sum of incoming messages
    # (helper defined elsewhere — TODO confirm)
    hnew[idx.(∂i)], bnew[i] = cavity(u[idx.(∂i)], msg_sum, logϕᵢ)
    # degrees of the factors neighboring variable i (lazy generator, iterated twice)
    d = (degree(g, factor(a)) for a in neighbors(g, variable(i)))
    errv = -Inf
    for ((_,_,id), dₐ) in zip(∂i, d)
        fᵢ₂ₐ = maximum(hnew[id])
        f[i] -= fᵢ₂ₐ * (1 - 1/dₐ)
        hnew[id] .-= fᵢ₂ₐ
        # error is measured on shifted messages, before damping
        errv = max(errv, mean(abs, hnew[id] - h[id]))
        h[id] = damp!(h[id], hnew[id], damp)
    end
    fᵢ = maximum(bnew[i])
    f[i] -= fᵢ * (1 - degree(g, variable(i)) + sum(1/dₐ for dₐ in d; init=0.0))
    # Shift BEFORE measuring the belief error: `b[i]` is stored with its maximum
    # subtracted, so comparing it with the unshifted sum would leave a residual |fᵢ| even
    # at a fixed point and a belief-based convergence check would never trigger.
    bnew[i] .-= fᵢ
    errb = mean(abs, bnew[i] - b[i])
    b[i] = bnew[i]
    return errv, errb
end

function update_f_ms!(bp::BP, a::Integer, unew, damp::Real, f::AtomicVector{<:Real};
Expand Down
10 changes: 10 additions & 0 deletions test/bp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,14 @@ end
iterate!(bp; maxiter=100, tol=0.0)
b = beliefs(bp)
test_observables(bp)
end

# Runs BP on a random tree factor graph using the belief-based convergence criterion
# instead of the default message-based one, then checks the resulting observables.
@testset "Convergence of beliefs" begin
n = 10
g = rand_tree_factor_graph(n)
# number of states per variable, drawn between 2 and 4
qs = rand(rng, 2:4, nvariables(g))
bp = rand_bp(rng, g, qs)
# NOTE(review): the iteration count returned by `iterate!` is discarded, so this test
# does not verify that the belief criterion actually triggers before `maxiter`.
iterate!(bp; maxiter=100, check_convergence=belief_convergence(1e-12))
b = beliefs(bp)
test_observables(bp)
end

0 comments on commit 93da510

Please sign in to comment.