Skip to content

Commit 3ed16a4

Browse files
authored
Merge pull request #15 from lanl-ansi/multi-interactions
multiRISE passed reconstruction tests on synthetic examples
2 parents 081705d + beb73c5 commit 3ed16a4

File tree

4 files changed

+157
-15
lines changed

4 files changed

+157
-15
lines changed

src/GraphicalModelLearning.jl

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module GraphicalModelLearning
44

55
export learn, inverse_ising
66

7-
export GMLFormulation, RISE, logRISE, RPLE, RISEA
7+
export GMLFormulation, RISE, logRISE, RPLE, RISEA, multiRISE
88
export GMLMethod, NLP
99

1010
using JuMP
@@ -19,6 +19,14 @@ include("sampling.jl")
1919

2020
@compat abstract type GMLFormulation end
2121

22+
# Formulation for multi-body RISE: reconstructs interaction terms of
# every order up to `interaction_order`, not just pairwise couplings.
type multiRISE <: GMLFormulation
    regularizer::Real        # prefactor of the l1 penalty used by learn()
    symmetrization::Bool     # NOTE(review): currently unused by the multiRISE learn() method -- confirm intent
    interaction_order::Integer  # maximum size of interaction tuples to recover
end
# default values
multiRISE() = multiRISE(0.4, true, 2)
29+
2230
type RISE <: GMLFormulation
2331
regularizer::Real
2432
symmetrization::Bool
@@ -69,6 +77,64 @@ function data_info{T <: Real}(samples::Array{T,2})
6977
return num_conf, num_spins, num_samples
7078
end
7179

80+
# Multi-body RISE (Regularized Interaction Screening Estimator).
#
# For each spin, builds the interaction-screening objective over all
# interaction tuples (up to `formulation.interaction_order`) containing
# that spin, minimizes it with an l1 penalty on multi-spin terms, and
# collects the recovered coefficients into one dictionary.
#
# `samples` is a histogram matrix: column 1 holds the count of each spin
# configuration, columns 2:end hold the spin values of that configuration.
# Returns a Dict{Tuple,Real} mapping interaction tuples to coefficients.
# Raises an error if the solver does not report an optimal solution.
function learn{T <: Real}(samples::Array{T,2}, formulation::multiRISE, method::NLP)
    num_conf, num_spins, num_samples = data_info(samples)

    # l1 regularization strength, scaled with the sample size
    lambda = formulation.regularizer*sqrt(log((num_spins^2)/0.05)/num_samples)
    inter_order = formulation.interaction_order

    reconstruction = Dict{Tuple,Real}()

    for current_spin = 1:num_spins
        # interaction tuple => per-configuration product of its spin values
        nodal_stat = Dict{Tuple,Array{Real,1}}()

        for p = 1:inter_order
            # was Array{Tuple{},1}(): the element type Tuple{} only admits the
            # empty tuple, so the initial container had the wrong element type
            nodal_keys = Vector{Tuple}()
            neighbours = [i for i=1:num_spins if i!=current_spin]
            if p == 1
                nodal_keys = [(current_spin,)]
            else
                perm = permutations(neighbours, p - 1)
                if length(perm) > 0
                    nodal_keys = [(current_spin, perm[i]...) for i=1:length(perm)]
                end
            end

            for index = 1:length(nodal_keys)
                # spin values start at column 2; column 1 is the sample count
                nodal_stat[nodal_keys[index]] = [ prod(samples[k, 1 + i] for i=nodal_keys[index]) for k=1:num_conf]
            end
        end

        m = Model(solver = method.solver)

        @variable(m, x[keys(nodal_stat)])  # interaction coefficients
        @variable(m, z[keys(nodal_stat)])  # absolute-value surrogates for the l1 term

        @NLobjective(m, Min,
            sum((samples[k,1]/num_samples)*exp(-sum(x[inter]*stat[k] for (inter,stat) = nodal_stat)) for k=1:num_conf) +
            lambda*sum(z[inter] for inter = keys(nodal_stat) if length(inter)>1)
        )

        # enforce z[inter] >= |x[inter]|
        for inter in keys(nodal_stat)
            @constraint(m, z[inter] >= x[inter]) #z_plus
            @constraint(m, z[inter] >= -x[inter]) #z_minus
        end

        status = solve(m)
        # was `@assert status == :Optimal` -- asserts may be compiled out and
        # must not be used for runtime validation; fail loudly instead
        status == :Optimal || error("learn(multiRISE) failed for spin $(current_spin), solver status: $(status)")

        nodal_reconstruction = getvalue(x)
        for inter = keys(nodal_stat)
            # getvalue yields plain numbers; the former deepcopy was a no-op cost
            reconstruction[inter] = nodal_reconstruction[inter]
        end
    end
    # NOTE(review): formulation.symmetrization is currently ignored here;
    # the matrix-transpose symmetrization below does not apply to a Dict of
    # interaction tuples -- confirm intended multi-body symmetrization.
    #if formulation.symmetrization
    #    reconstruction = 0.5*(reconstruction + transpose(reconstruction))
    #end

    return reconstruction
end
72138

73139
function learn{T <: Real}(samples::Array{T,2}, formulation::RISE, method::NLP)
74140
num_conf, num_spins, num_samples = data_info(samples)

src/models.jl

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ FactorGraph{T <: Real}(matrix::Array{T,2}) = convert(FactorGraph{T}, matrix)
1818
function check_model_data{T <: Real}(order::Int, varible_count::Int, alphabet::Symbol, terms::Dict{Tuple,T}, variable_names::Nullable{Vector{String}})
1919
if !in(alphabet, alphabets)
2020
error("alphabet $(alphabet) is not supported")
21-
return false
21+
return false
2222
end
2323
if !isnull(variable_names) && length(variable_names) != varible_count
2424
error("expected $(varible_count) but only given $(length(variable_names))")
25-
return false
25+
return false
2626
end
2727
for (k,v) in terms
28-
if length(k) != order
28+
if length(k) > order
2929
error("a term has $(length(k)) indices but should have $(order) indices")
3030
return false
3131
end
@@ -95,6 +95,14 @@ function Base.convert{T <: Real}(::Type{FactorGraph{T}}, m::Array{T,2})
9595
varible_count = size(m,1)
9696

9797
terms = Dict{Tuple,T}()
98+
99+
for key in permutations(1:varible_count, 1)
100+
weight = m[key..., key...]
101+
if !isapprox(weight, 0.0)
102+
terms[key] = weight
103+
end
104+
end
105+
98106
for key in permutations(1:varible_count, 2)
99107
weight = m[key...]
100108
if !isapprox(weight, 0.0)
@@ -118,33 +126,64 @@ function Base.convert{T <: Real}(::Type{Array{T,2}}, gm::FactorGraph{T})
118126

119127
matrix = zeros(gm.varible_count, gm.varible_count)
120128
for (k,v) in gm
121-
matrix[k...] = v
122-
r = reverse(k)
123-
matrix[r...] = v
129+
if length(k) == 1
130+
matrix[k..., k...] = v
131+
else
132+
matrix[k...] = v
133+
r = reverse(k)
134+
matrix[r...] = v
135+
end
124136
end
125137

126138
return matrix
127139
end
128140

129141

142+
# Convenience overload: dispatches untyped Dict conversion to the typed method below.
Base.convert{T <: Real}(::Type{Dict}, m::Array{T,2}) = convert(Dict{Tuple,T}, m)

# Convert a square coupling matrix into a dictionary of interaction terms:
# diagonal entries m[i,i] become single-index tuples (i,), off-diagonal
# entries m[i,j] become ordered two-index tuples (i,j).  Entries that are
# approximately zero are omitted.
function Base.convert{T <: Real}(::Type{Dict{Tuple,T}}, m::Array{T,2})
    @assert size(m,1) == size(m,2) #check matrix is square

    varible_count = size(m,1)

    terms = Dict{Tuple,T}()

    # single-spin terms from the diagonal: key is (i,), so m[key..., key...] reads m[i,i]
    for key in permutations(1:varible_count, 1)
        weight = m[key..., key...]
        if !isapprox(weight, 0.0)
            terms[key] = weight
        end
    end

    # pairwise terms from off-diagonal entries; asymmetric permutations
    # include repeated indices, so (i,i) pairs are filtered out explicitly
    for key in permutations(1:varible_count, 2, asymmetric=true)
        if key[1] != key[2]
            weight = m[key...]
            if !isapprox(weight, 0.0)
                terms[key] = weight
            end
        end
    end

    return terms
end
168+
169+
130170

131171
permutations(items, order::Int; asymmetric::Bool = false) = sort(permutations([], items, order, asymmetric))
132172

133-
function permutations(partical_perm::Array{Any,1}, items, order::Int, asymmetric::Bool)
173+
function permutations(partial_perm::Array{Any,1}, items, order::Int, asymmetric::Bool)
134174
if order == 0
135-
return [tuple(partical_perm...)]
175+
return [tuple(partial_perm...)]
136176
else
137177
perms = []
138178
for item in items
139-
if !asymmetric && length(partical_perm) > 0
140-
if partical_perm[end] < item
179+
if !asymmetric && length(partial_perm) > 0
180+
if partial_perm[end] >= item
141181
continue
142182
end
143183
end
144-
perm = permutations(vcat([item], partical_perm), items, order-1, asymmetric)
184+
perm = permutations(vcat(partial_perm, item), items, order-1, asymmetric)
145185
append!(perms, perm)
146186
end
147187
return perms
148188
end
149189
end
150-

src/sampling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ bool_to_spin(bool::Int) = 2*bool-1
2626
function weigh_proba{T <: Real}(int_representation::Int, adj::Array{T,2}, prior::Array{T,1}, spins::Array{Int,1})
2727
digits!(spins, int_representation, 2)
2828
spins .= bool_to_spin.(spins)
29-
return exp(((0.5) * spins' * adj * spins + prior' * spins)[1])
29+
return exp((0.5 * spins' * adj * spins + prior' * spins)[1])
3030
end
3131

3232

test/runtests.jl

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@ include("common.jl")
1111
gm2 = FactorGraph(matrix)
1212
for key in keys(gm)
1313
@test isapprox(gm[key], gm2[key])
14-
@test isapprox(gm[key], matrix[key...])
14+
if length(key) == 1
15+
@test isapprox(gm[key], matrix[key..., key...])
16+
else
17+
@test isapprox(gm[key], matrix[key...])
18+
end
1519
end
1620
end
1721
end
@@ -22,6 +26,7 @@ end
2226
srand(0) # fix random number generator
2327
samples = sample(gm, gibbs_test_samples)
2428
base_samples = readcsv("data/$(name)_samples.csv")
29+
#println(name)
2530
#println(base_samples)
2631
#println(samples)
2732
#println(abs.(base_samples-samples))
@@ -110,6 +115,37 @@ srand(0) # fix random number generator
110115
end
111116

112117

118+
# multiRISE at interaction_order = 2 must reproduce the pairwise RISE
# reconstruction exactly on identical samples.
@testset "inverse multi-body formulations" begin

    for (name, gm) in gms
        samples = readcsv("data/$(name)_samples.csv")
        srand(0) # fix random number generator
        learned_ising = learn(samples, RISE(0.2, false))
        learned_two_body = learn(samples, multiRISE(0.2, false, 2))

        learned_ising_dict = convert(Dict, learned_ising)

        @test length(learned_ising_dict) == length(learned_two_body)
        for (key, value) in learned_ising_dict
            @test isapprox(learned_two_body[key], value)
        end
    end

    samples = readcsv("data/mvt_samples.csv")

    # was `rand(0)`, which draws an empty sample and does NOT seed the RNG
    srand(0) # fix random number generator
    learned_ising = learn(samples, RISE(0.2, false), NLP(IpoptSolver(print_level=0)))
    learned_two_body = learn(samples, multiRISE(0.2, false, 2), NLP(IpoptSolver(print_level=0)))

    learned_ising_dict = convert(Dict, learned_ising)
    @test length(learned_ising_dict) == length(learned_two_body)
    for (key, value) in learned_ising_dict
        @test isapprox(learned_two_body[key], value)
    end
end
148+
113149

114150
srand(0) # fix random number generator
115151
@testset "docs example" begin
@@ -120,3 +156,4 @@ srand(0) # fix random number generator
120156
err = abs.(convert(Array{Float64,2}, model) - learned)
121157
@test maximum(err) <= 0.01
122158
end
159+

0 commit comments

Comments
 (0)