Skip to content

Commit 081705d

Browse files
authored
Adding Support for Factor Graphs (#13)
* adding basic factor graph data structure * adding tests for factor graphs
1 parent 4476ff1 commit 081705d

File tree

9 files changed

+209
-27
lines changed

9 files changed

+209
-27
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ Try the following commands in julia,
1313
```
1414
using GraphicalModelLearning
1515
16-
model = [0.0 0.1 0.2; 0.1 0.0 0.3; 0.2 0.3 0.0]
16+
model = FactorGraph([0.0 0.1 0.2; 0.1 0.0 0.3; 0.2 0.3 0.0])
1717
samples = sample(model, 100000)
1818
learned = learn(samples)
1919
20-
err = abs.(model - learned)
20+
err = abs.(convert(Array{Float64,2}, model) - learned)
2121
```
2222

2323
Note that the first invocation of `learn` will be slow as the dependent libraries are compiled. Subsequent calls will be fast.

src/GraphicalModelLearning.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ using Ipopt
1313

1414
using Compat # used for julia v0.5 abstract types
1515

16+
include("models.jl")
17+
1618
include("sampling.jl")
1719

1820
@compat abstract type GMLFormulation end

src/models.jl

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# data structures graphical models
2+
3+
export FactorGraph
4+
5+
alphabets = [:spin, :boolean, :integer, :integer_pos, :real, :real_pos]
6+
7+
type FactorGraph{T <: Real}
8+
order::Int
9+
varible_count::Int
10+
alphabet::Symbol
11+
terms::Dict{Tuple,T} # TODO, would be nice to have a stronger tuple type here
12+
variable_names::Nullable{Vector{String}}
13+
FactorGraph(a,b,c,d,e) = check_model_data(a,b,c,d,e) ? new(a,b,c,d,e) : error("generic init problem")
14+
end
15+
FactorGraph{T <: Real}(order::Int, varible_count::Int, alphabet::Symbol, terms::Dict{Tuple,T}) = FactorGraph{T}(order, varible_count, alphabet, terms, Nullable{Vector{String}}())
16+
FactorGraph{T <: Real}(matrix::Array{T,2}) = convert(FactorGraph{T}, matrix)
17+
18+
function check_model_data{T <: Real}(order::Int, varible_count::Int, alphabet::Symbol, terms::Dict{Tuple,T}, variable_names::Nullable{Vector{String}})
19+
if !in(alphabet, alphabets)
20+
error("alphabet $(alphabet) is not supported")
21+
return false
22+
end
23+
if !isnull(variable_names) && length(variable_names) != varible_count
24+
error("expected $(varible_count) but only given $(length(variable_names))")
25+
return false
26+
end
27+
for (k,v) in terms
28+
if length(k) != order
29+
error("a term has $(length(k)) indices but should have $(order) indices")
30+
return false
31+
end
32+
for (i,index) in enumerate(k)
33+
#println(i," ",index)
34+
if index < 1 || index > varible_count
35+
error("a term has an index of $(index) but it should be in the range of 1:$(varible_count)")
36+
return false
37+
end
38+
if i > 1
39+
if k[i-1] > index
40+
error("the term $(k) does not have ascending indices")
41+
end
42+
end
43+
end
44+
end
45+
return true
46+
end
47+
48+
function Base.show(io::IO, gm::FactorGraph)
49+
println(io, "alphabet: ", gm.alphabet)
50+
println(io, "vars: ", gm.varible_count)
51+
if !isnull(gm.variable_names)
52+
println(io, "variable names: ")
53+
println(io, " ", get(gm.variable_names))
54+
end
55+
56+
println(io, "terms: ")
57+
for k in sort(collect(keys(gm.terms)))
58+
println(" ", k, " => ", gm.terms[k])
59+
end
60+
end
61+
62+
Base.start(gm::FactorGraph) = start(gm.terms)
63+
Base.next(gm::FactorGraph, state) = next(gm.terms, state)
64+
Base.done(gm::FactorGraph, state) = done(gm.terms, state)
65+
66+
Base.length(gm::FactorGraph) = length(gm.terms)
67+
68+
Base.getindex(gm::FactorGraph, i) = gm.terms[i]
69+
Base.keys(gm::FactorGraph) = keys(gm.terms)
70+
71+
72+
function diag_keys(gm::FactorGraph)
73+
dkeys = Tuple[]
74+
for i in 1:gm.varible_count
75+
key = diag_key(gm, i)
76+
if key in keys(gm.terms)
77+
push!(dkeys, key)
78+
end
79+
end
80+
return sort(dkeys)
81+
end
82+
83+
diag_key(gm::FactorGraph, i::Int) = tuple(fill(i, gm.order)...)
84+
85+
#Base.diag{T <: Real}(gm::FactorGraph{T}) = [ get(gm.terms, diag_key(gm, i), zero(T)) for i in 1:gm.varible_count ]
86+
87+
Base.DataFmt.writecsv{T <: Real}(io, gm::FactorGraph{T}, args...; kwargs...) = writecsv(io, convert(Array{T,2}, gm), args...; kwargs...)
88+
89+
Base.convert{T <: Real}(::Type{FactorGraph}, m::Array{T,2}) = convert(FactorGraph{T}, m)
90+
function Base.convert{T <: Real}(::Type{FactorGraph{T}}, m::Array{T,2})
91+
@assert size(m,1) == size(m,2) #check matrix is square
92+
93+
info("assuming spin alphabet")
94+
alphabet = :spin
95+
varible_count = size(m,1)
96+
97+
terms = Dict{Tuple,T}()
98+
for key in permutations(1:varible_count, 2)
99+
weight = m[key...]
100+
if !isapprox(weight, 0.0)
101+
terms[key] = weight
102+
end
103+
104+
rev = reverse(key)
105+
if !isapprox(m[rev...], 0.0) && !isapprox(m[key...], m[rev...])
106+
delta = abs(m[key...] - m[rev...])
107+
warn("values at $(key) and $(rev) differ by $(delta), only $(key) will be used")
108+
end
109+
end
110+
111+
return FactorGraph(2, varible_count, alphabet, terms)
112+
end
113+
114+
function Base.convert{T <: Real}(::Type{Array{T,2}}, gm::FactorGraph{T})
115+
if gm.order != 2
116+
error("cannot convert a FactorGraph of order $(gm.order) to a matrix")
117+
end
118+
119+
matrix = zeros(gm.varible_count, gm.varible_count)
120+
for (k,v) in gm
121+
matrix[k...] = v
122+
r = reverse(k)
123+
matrix[r...] = v
124+
end
125+
126+
return matrix
127+
end
128+
129+
130+
131+
permutations(items, order::Int; asymmetric::Bool = false) = sort(permutations([], items, order, asymmetric))
132+
133+
function permutations(partical_perm::Array{Any,1}, items, order::Int, asymmetric::Bool)
134+
if order == 0
135+
return [tuple(partical_perm...)]
136+
else
137+
perms = []
138+
for item in items
139+
if !asymmetric && length(partical_perm) > 0
140+
if partical_perm[end] < item
141+
continue
142+
end
143+
end
144+
perm = permutations(vcat([item], partical_perm), items, order-1, asymmetric)
145+
append!(perms, perm)
146+
end
147+
return perms
148+
end
149+
end
150+

src/sampling.jl

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,35 @@ function int_to_spin(int_representation::Int, spin_number::Int)
1313
return spin
1414
end
1515

16+
1617
function weigh_proba{T <: Real}(int_representation::Int, adj::Array{T,2}, prior::Array{T,1})
17-
spin_number = size(adj,1)
18+
spin_number = size(adj,1)
1819
spins = int_to_spin(int_representation, spin_number)
1920
return exp(((0.5) * spins' * adj * spins + prior' * spins)[1])
2021
end
2122

2223

2324
bool_to_spin(bool::Int) = 2*bool-1
2425

25-
function weigh_proba{T <: Real}(int_representation::Int, adj::Array{T,2}, prior::Array{T,1}, assignment_tmp::Array{Int,1})
26-
digits!(assignment_tmp, int_representation, 2)
27-
assignment_tmp .= bool_to_spin.(assignment_tmp)
28-
return exp(((0.5) * assignment_tmp' * adj * assignment_tmp + prior' * assignment_tmp)[1])
26+
function weigh_proba{T <: Real}(int_representation::Int, adj::Array{T,2}, prior::Array{T,1}, spins::Array{Int,1})
27+
digits!(spins, int_representation, 2)
28+
spins .= bool_to_spin.(spins)
29+
return exp(((0.5) * spins' * adj * spins + prior' * spins)[1])
2930
end
3031

3132

32-
function sample_generation{T <: Real}(samples_per_bin::Integer, adj::Array{T,2}, prior::Array{T,1}, bins::Int)
33+
function sample_generation{T <: Real}(gm::FactorGraph{T}, samples_per_bin::Integer, bins::Int)
3334
@assert bins >= 1
3435

35-
spin_number = size(adj,1)
36+
spin_number = gm.varible_count
3637
config_number = 2^spin_number
3738

39+
adjacency_matrix = convert(Array{T,2}, gm)
40+
prior_vector = transpose(diag(adjacency_matrix))[1,:]
41+
42+
items = [i for i in 0:(config_number-1)]
3843
assignment_tmp = [0 for i in 1:spin_number] # pre allocate assignment memory
39-
weights = [weigh_proba(i, adj, prior, assignment_tmp) for i in (0:config_number-1)]
44+
weights = [weigh_proba(i, adjacency_matrix, prior_vector, assignment_tmp) for i in (0:config_number-1)]
4045

4146
items = [i for i in 0:(config_number-1)]
4247
raw_sample = StatsBase.sample(items, StatsBase.Weights(weights), samples_per_bin*bins, ordered=false)
@@ -51,14 +56,19 @@ function sample_generation{T <: Real}(samples_per_bin::Integer, adj::Array{T,2},
5156
return spin_samples
5257
end
5358

54-
sample{T <: Real}(adjacency_matrix::Array{T,2}, number_sample::Integer) = sample(adjacency_matrix, number_sample, 1, Gibbs())[1]
55-
sample{T <: Real}(adjacency_matrix::Array{T,2}, number_sample::Integer, replicates::Integer) = sample(adjacency_matrix, number_sample, replicates, Gibbs())
59+
sample{T <: Real}(gm::FactorGraph{T}, number_sample::Integer) = sample(gm, number_sample, 1, Gibbs())[1]
60+
sample{T <: Real}(gm::FactorGraph{T}, number_sample::Integer, replicates::Integer) = sample(gm, number_sample, replicates, Gibbs())
5661

57-
function sample{T <: Real}(adjacency_matrix::Array{T,2}, number_sample::Integer, replicates::Integer, sampler::Gibbs)
58-
prior_vector = transpose(diag(adjacency_matrix))[1,:] #priors, or magnetic fields part
5962

60-
# generation of samples
61-
samples = sample_generation(number_sample, adjacency_matrix, prior_vector, replicates)
63+
function sample{T <: Real}(gm::FactorGraph{T}, number_sample::Integer, replicates::Integer, sampler::Gibbs)
64+
if gm.order != 2
65+
error("sampling is only supported for FactorGraphs of order 2, given order $(gm.order)")
66+
end
67+
if gm.alphabet != :spin
68+
error("sampling is only supported for spin FactorGraphs, given alphabet $(gm.alphabet)")
69+
end
70+
71+
samples = sample_generation(gm, number_sample, replicates)
6272

6373
return samples
6474
end

test/common.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,20 @@ formulations = Dict(
1313
)
1414

1515
gms = Dict(
16-
"a" => [
16+
"a" => FactorGraph([
1717
0.0 0.1 0.2;
1818
0.1 0.0 0.3;
1919
0.2 0.3 0.0
20-
],
21-
"b" => [
20+
]),
21+
"b" => FactorGraph([
2222
0.3 0.1 0.2;
2323
0.1 0.2 0.3;
2424
0.2 0.3 0.1
25-
],
26-
"c" => [
25+
]),
26+
"c" => FactorGraph([
2727
0.0 0.1 0.2 0.3;
2828
0.1 0.0 0.2 0.3;
2929
0.2 0.2 0.0 0.3;
3030
0.3 0.3 0.3 0.0
31-
]
31+
])
3232
)

test/data/build_data.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
using GraphicalModelLearning
66

7-
# remove old files
8-
`rm -rf *.csv`
7+
## remove old files
8+
#`rm -rf *.csv`
99

1010
include("../common.jl")
1111

test/data/c_model_1.csv

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
0.0, 0.1, 0.2, 0.3
2+
0.1, 0.0, 0.2, 0.3
3+
0.2, 0.2, 0.0, 0.3
4+
0.3, 0.3, 0.3, 0.0

test/data/c_model_2.csv

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
0.1, 0.1, 0.2, 0.3
2+
0.0, -0.2, 0.2, 0.3
3+
0.0, 0.0, 0.3, 0.3
4+
0.0, 0.0, 0.0, -0.4

test/runtests.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,18 @@ using Base.Test
55
include("common.jl")
66

77

8+
@testset "factor graphs" begin
9+
for (name, gm) in gms
10+
matrix = convert(Array{Float64,2}, gm)
11+
gm2 = FactorGraph(matrix)
12+
for key in keys(gm)
13+
@test isapprox(gm[key], gm2[key])
14+
@test isapprox(gm[key], matrix[key...])
15+
end
16+
end
17+
end
18+
19+
820
@testset "gibbs sampler" begin
921
for (name, gm) in gms
1022
srand(0) # fix random number generator
@@ -90,7 +102,7 @@ srand(0) # fix random number generator
90102
sample_histo = sample(gm, act.samples)
91103
#learned_gm = inverse_ising(sample_histo, method=act.formulation)
92104
learned_gm = learn(sample_histo, act.formulation)
93-
max_error = maximum(abs.(gm - learned_gm))
105+
max_error = maximum(abs.(convert(Array{Float64,2}, gm) - learned_gm))
94106
@test max_error <= act.threshold
95107
end
96108
end
@@ -101,10 +113,10 @@ end
101113

102114
srand(0) # fix random number generator
103115
@testset "docs example" begin
104-
model = [0.0 0.1 0.2; 0.1 0.0 0.3; 0.2 0.3 0.0]
116+
model = FactorGraph([0.0 0.1 0.2; 0.1 0.0 0.3; 0.2 0.3 0.0])
105117
samples = sample(model, 100000)
106118
learned = learn(samples)
107119

108-
err = abs.(model - learned)
120+
err = abs.(convert(Array{Float64,2}, model) - learned)
109121
@test maximum(err) <= 0.01
110122
end

0 commit comments

Comments
 (0)