Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion src/GraphicalModelLearning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module GraphicalModelLearning

export learn, inverse_ising

export GMLFormulation, RISE, logRISE, RPLE, RISEA
export GMLFormulation, RISE, logRISE, RPLE, RISEA, multiRISE
export GMLMethod, NLP

using JuMP
Expand All @@ -19,6 +19,14 @@ include("sampling.jl")

@compat abstract type GMLFormulation end

# Formulation settings for the multi-body `learn` method (interaction terms beyond pairwise).
type multiRISE <: GMLFormulation
# scale factor multiplied into the penalty weight `lambda` that `learn` computes
regularizer::Real
# NOTE(review): the multiRISE `learn` method in this file does not read this flag
# (its symmetrization step is commented out) -- confirm before relying on it
symmetrization::Bool
# largest number of spins in a single interaction term that `learn` builds
interaction_order::Integer
end
# default values: regularizer 0.4, symmetrization on, interactions up to order 2
multiRISE() = multiRISE(0.4, true, 2)

type RISE <: GMLFormulation
regularizer::Real
symmetrization::Bool
Expand Down Expand Up @@ -69,6 +77,64 @@ function data_info{T <: Real}(samples::Array{T,2})
return num_conf, num_spins, num_samples
end

# Learn a multi-body model from `samples` with the RISE objective, one spin at a time.
#
# Arguments:
#   samples     - 2-d array; column 1 is used as the weight of configuration k (divided by
#                 num_samples) and column 1 + i is read as the value of spin i
#   formulation - multiRISE settings; `regularizer` and `interaction_order` are used here
#   method      - NLP wrapper whose `solver` is handed to the JuMP model
#
# Returns a Dict mapping interaction tuples `(current_spin, others...)` to learned weights.
# Every spin contributes its own entries, so a coupling between two spins appears once per
# ordering of the tuple.
#
# Fails with an AssertionError when the solver does not report :Optimal for some spin.
function learn{T <: Real}(samples::Array{T,2}, formulation::multiRISE, method::NLP)
    num_conf, num_spins, num_samples = data_info(samples)

    lambda = formulation.regularizer*sqrt(log((num_spins^2)/0.05)/num_samples)
    inter_order = formulation.interaction_order

    reconstruction = Dict{Tuple,Real}()

    for current_spin = 1:num_spins
        nodal_stat = Dict{Tuple,Array{Real,1}}()

        # every other spin; identical for all interaction orders, so build it once
        neighbours = [i for i=1:num_spins if i!=current_spin]

        for p = 1:inter_order
            if p == 1
                nodal_keys = [(current_spin,)]
            else
                # an empty list of partner tuples simply yields no keys for this order
                nodal_keys = [(current_spin, partners...)
                              for partners in permutations(neighbours, p - 1)]
            end

            # per-configuration product of the spins taking part in each interaction
            for key in nodal_keys
                nodal_stat[key] = [ prod(samples[k, 1 + i] for i=key) for k=1:num_conf]
            end
        end

        m = Model(solver = method.solver)

        @variable(m, x[keys(nodal_stat)])
        @variable(m, z[keys(nodal_stat)])

        # the penalty skips single-spin terms (length(inter) == 1)
        @NLobjective(m, Min,
            sum((samples[k,1]/num_samples)*exp(-sum(x[inter]*stat[k] for (inter,stat) = nodal_stat)) for k=1:num_conf) +
            lambda*sum(z[inter] for inter = keys(nodal_stat) if length(inter)>1)
        )

        # z bounds |x| from above, which makes the penalty an l1 term
        for inter in keys(nodal_stat)
            @constraint(m, z[inter] >= x[inter]) #z_plus
            @constraint(m, z[inter] >= -x[inter]) #z_minus
        end

        status = solve(m)
        @assert status == :Optimal

        nodal_reconstruction = getvalue(x)
        for inter = keys(nodal_stat)
            # values are plain numbers, so no copy is needed
            reconstruction[inter] = nodal_reconstruction[inter]
        end
    end

    # NOTE: formulation.symmetrization is not applied to the result; entries for
    # different orderings of the same spins are returned exactly as learned.

    return reconstruction
end

function learn{T <: Real}(samples::Array{T,2}, formulation::RISE, method::NLP)
num_conf, num_spins, num_samples = data_info(samples)
Expand Down
63 changes: 51 additions & 12 deletions src/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ FactorGraph{T <: Real}(matrix::Array{T,2}) = convert(FactorGraph{T}, matrix)
function check_model_data{T <: Real}(order::Int, varible_count::Int, alphabet::Symbol, terms::Dict{Tuple,T}, variable_names::Nullable{Vector{String}})
if !in(alphabet, alphabets)
error("alphabet $(alphabet) is not supported")
return false
return false
end
if !isnull(variable_names) && length(variable_names) != varible_count
error("expected $(varible_count) but only given $(length(variable_names))")
return false
return false
end
for (k,v) in terms
if length(k) != order
if length(k) > order
error("a term has $(length(k)) indices but should have $(order) indices")
return false
end
Expand Down Expand Up @@ -95,6 +95,14 @@ function Base.convert{T <: Real}(::Type{FactorGraph{T}}, m::Array{T,2})
varible_count = size(m,1)

terms = Dict{Tuple,T}()

for key in permutations(1:varible_count, 1)
weight = m[key..., key...]
if !isapprox(weight, 0.0)
terms[key] = weight
end
end

for key in permutations(1:varible_count, 2)
weight = m[key...]
if !isapprox(weight, 0.0)
Expand All @@ -118,33 +126,64 @@ function Base.convert{T <: Real}(::Type{Array{T,2}}, gm::FactorGraph{T})

matrix = zeros(gm.varible_count, gm.varible_count)
for (k,v) in gm
matrix[k...] = v
r = reverse(k)
matrix[r...] = v
if length(k) == 1
matrix[k..., k...] = v
else
matrix[k...] = v
r = reverse(k)
matrix[r...] = v
end
end

return matrix
end


# Convenience method: picks the Dict value type from the matrix element type.
Base.convert{T <: Real}(::Type{Dict}, m::Array{T,2}) = convert(Dict{Tuple,T}, m)

# Turn a square weight matrix into a Dict of interaction terms.
#
# Diagonal entries m[i,i] are stored under the 1-tuple (i,); off-diagonal entries m[i,j]
# are stored under (i,j), with both orderings kept as separate keys. Entries for which
# isapprox(weight, 0.0) holds are left out of the result.
function Base.convert{T <: Real}(::Type{Dict{Tuple,T}}, m::Array{T,2})
    @assert size(m,1) == size(m,2) #check matrix is square

    n = size(m,1)
    result = Dict{Tuple,T}()

    # single-site terms come from the diagonal
    for site in permutations(1:n, 1)
        i = site[1]
        value = m[i, i]
        isapprox(value, 0.0) && continue
        result[site] = value
    end

    # pair terms come from every ordered pair of distinct indices
    for pair in permutations(1:n, 2, asymmetric=true)
        pair[1] == pair[2] && continue
        value = m[pair...]
        isapprox(value, 0.0) && continue
        result[pair] = value
    end

    return result
end



# Enumerate index tuples of length `order` drawn from `items`, returned in sorted order.
#
# With `asymmetric = false` (the default) only strictly increasing tuples are produced, so
# each unordered selection appears once. With `asymmetric = true` every ordered tuple is
# produced, including ones that repeat an item.
function permutations(items, order::Int; asymmetric::Bool = false)
    return sort(permutations([], items, order, asymmetric))
end

# Recursive worker: extends `partial_perm` by `order` further items and returns the
# completed tuples (unsorted).
function permutations(partial_perm::Array{Any,1}, items, order::Int, asymmetric::Bool)
    if order == 0
        return [tuple(partial_perm...)]
    end

    perms = []
    for item in items
        # unless asymmetric, only extend with items strictly greater than the last one
        if !asymmetric && !isempty(partial_perm) && partial_perm[end] >= item
            continue
        end
        # append so that tuples read in the order the items were chosen
        append!(perms, permutations(vcat(partial_perm, item), items, order-1, asymmetric))
    end
    return perms
end

2 changes: 1 addition & 1 deletion src/sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ bool_to_spin(bool::Int) = 2*bool-1
# Unnormalized weight of the spin configuration encoded by `int_representation`.
#
# Arguments:
#   int_representation - integer whose base-2 digits select the configuration
#   adj                - coupling matrix used in the quadratic term
#   prior              - vector used in the linear term
#   spins              - scratch buffer; overwritten with the decoded configuration
#
# Returns exp(0.5 * s' * adj * s + prior' * s) for the decoded spin vector s.
function weigh_proba{T <: Real}(int_representation::Int, adj::Array{T,2}, prior::Array{T,1}, spins::Array{Int,1})
    # fill the buffer with base-2 digits, then map each digit through bool_to_spin
    digits!(spins, int_representation, 2)
    spins .= bool_to_spin.(spins)
    return exp((0.5 * spins' * adj * spins + prior' * spins)[1])
end


Expand Down
39 changes: 38 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ include("common.jl")
gm2 = FactorGraph(matrix)
for key in keys(gm)
@test isapprox(gm[key], gm2[key])
@test isapprox(gm[key], matrix[key...])
if length(key) == 1
@test isapprox(gm[key], matrix[key..., key...])
else
@test isapprox(gm[key], matrix[key...])
end
end
end
end
Expand All @@ -22,6 +26,7 @@ end
srand(0) # fix random number generator
samples = sample(gm, gibbs_test_samples)
base_samples = readcsv("data/$(name)_samples.csv")
#println(name)
#println(base_samples)
#println(samples)
#println(abs.(base_samples-samples))
Expand Down Expand Up @@ -110,6 +115,37 @@ srand(0) # fix random number generator
end


# With interaction order 2, multiRISE is expected to reproduce the RISE result entry by
# entry once the RISE output is converted to a Dict.
@testset "inverse multi-body formulations" begin

    # default solver, every reference model
    for (name, gm) in gms
        samples = readcsv("data/$(name)_samples.csv")
        srand(0) # fix random number generator
        learned_ising = learn(samples, RISE(0.2, false))
        learned_two_body = learn(samples, multiRISE(0.2, false, 2))

        learned_ising_dict = convert(Dict, learned_ising)

        @test length(learned_ising_dict) == length(learned_two_body)
        for (key, value) in learned_ising_dict
            @test isapprox(learned_two_body[key], value)
        end
    end

    # explicit Ipopt solver on the mvt sample set
    samples = readcsv("data/mvt_samples.csv")

    srand(0) # fix random number generator (rand(0) would only draw an empty vector)
    learned_ising = learn(samples, RISE(0.2, false), NLP(IpoptSolver(print_level=0)))
    learned_two_body = learn(samples, multiRISE(0.2, false, 2), NLP(IpoptSolver(print_level=0)))

    learned_ising_dict = convert(Dict, learned_ising)
    @test length(learned_ising_dict) == length(learned_two_body)
    for (key, value) in learned_ising_dict
        @test isapprox(learned_two_body[key], value)
    end
end


srand(0) # fix random number generator
@testset "docs example" begin
Expand All @@ -120,3 +156,4 @@ srand(0) # fix random number generator
err = abs.(convert(Array{Float64,2}, model) - learned)
@test maximum(err) <= 0.01
end