Skip to content

Commit f699092

Browse files
authored
Merge pull request #18 from lanl-ansi/fg-sampler
Generic Factor Graph Sampler
2 parents 8311f51 + 3b9a515 commit f699092

File tree

4 files changed

+109
-11
lines changed

4 files changed

+109
-11
lines changed

src/GraphicalModelLearning.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ function learn{T <: Real}(samples::Array{T,2}, formulation::multiRISE, method::N
146146
end
147147
end
148148

149-
return reconstruction
149+
return FactorGraph(inter_order, num_spins, :spin, reconstruction)
150150
end
151151

152152
function learn{T <: Real}(samples::Array{T,2}, formulation::RISE, method::NLP)

src/models.jl

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# data structures graphical models
22

3-
export FactorGraph
3+
export FactorGraph, jsondata
44

55
alphabets = [:spin, :boolean, :integer, :integer_pos, :real, :real_pos]
66

@@ -14,6 +14,9 @@ type FactorGraph{T <: Real}
1414
end
1515
FactorGraph{T <: Real}(order::Int, varible_count::Int, alphabet::Symbol, terms::Dict{Tuple,T}) = FactorGraph{T}(order, varible_count, alphabet, terms, Nullable{Vector{String}}())
1616
FactorGraph{T <: Real}(matrix::Array{T,2}) = convert(FactorGraph{T}, matrix)
17+
FactorGraph{T <: Real}(dict::Dict{Tuple,T}) = convert(FactorGraph{T}, dict)
18+
FactorGraph(list::Array{Any,1}) = convert(FactorGraph, list)
19+
1720

1821
function check_model_data{T <: Real}(order::Int, varible_count::Int, alphabet::Symbol, terms::Dict{Tuple,T}, variable_names::Nullable{Vector{String}})
1922
if !in(alphabet, alphabets)
@@ -35,11 +38,14 @@ function check_model_data{T <: Real}(order::Int, varible_count::Int, alphabet::S
3538
error("a term has an index of $(index) but it should be in the range of 1:$(varible_count)")
3639
return false
3740
end
41+
#=
42+
# TODO see when this should be enforced
3843
if i > 1
3944
if k[i-1] > index
4045
error("the term $(k) does not have ascending indices")
4146
end
4247
end
48+
=#
4349
end
4450
end
4551
return true
@@ -53,12 +59,20 @@ function Base.show(io::IO, gm::FactorGraph)
5359
println(io, " ", get(gm.variable_names))
5460
end
5561

56-
println(io, "terms: ")
57-
for k in sort(collect(keys(gm.terms)))
62+
println(io, "terms: $(length(gm.terms))")
63+
for k in sort(collect(keys(gm.terms)), by=(x)->(length(x),x))
5864
println(" ", k, " => ", gm.terms[k])
5965
end
6066
end
6167

68+
function jsondata{T <: Real}(gm::FactorGraph{T})
69+
data = []
70+
for k in sort(collect(keys(gm.terms)), by=(x)->(length(x),x))
71+
push!(data, Dict("term" => k, "weight" => gm.terms[k]))
72+
end
73+
return data
74+
end
75+
6276
Base.start(gm::FactorGraph) = start(gm.terms)
6377
Base.next(gm::FactorGraph, state) = next(gm.terms, state)
6478
Base.done(gm::FactorGraph, state) = done(gm.terms, state)
@@ -68,7 +82,6 @@ Base.length(gm::FactorGraph) = length(gm.terms)
6882
Base.getindex(gm::FactorGraph, i) = gm.terms[i]
6983
Base.keys(gm::FactorGraph) = keys(gm.terms)
7084

71-
7285
function diag_keys(gm::FactorGraph)
7386
dkeys = Tuple[]
7487
for i in 1:gm.varible_count
@@ -167,6 +180,48 @@ function Base.convert{T <: Real}(::Type{Dict{Tuple,T}}, m::Array{T,2})
167180
end
168181

169182

183+
function Base.convert(::Type{FactorGraph}, list::Array{Any,1})
184+
info("assuming spin alphabet")
185+
alphabet = :spin
186+
187+
max_variable = 0
188+
max_order = 0
189+
terms = Dict{Tuple,Float64}()
190+
191+
for item in list
192+
term = item["term"]
193+
weight = item["weight"]
194+
terms[tuple(term...)] = weight
195+
196+
@assert minimum(term) > 0
197+
max_order = max(max_order, length(term))
198+
max_variable = max(max_variable, maximum(term))
199+
end
200+
201+
info("dectected $(max_variable) variables with order $(max_order)")
202+
203+
return FactorGraph(max_order, max_variable, alphabet, terms)
204+
end
205+
206+
207+
Base.convert{T <: Real}(::Type{FactorGraph}, dict::Dict{Tuple,T}) = convert(FactorGraph{T}, dict)
208+
function Base.convert{T <: Real}(::Type{FactorGraph{T}}, dict::Dict{Tuple,T})
209+
info("assuming spin alphabet")
210+
alphabet = :spin
211+
212+
max_variable = 0
213+
max_order = 0
214+
for (term,weight) in dict
215+
@assert minimum(term) > 0
216+
max_order = max(max_order, length(term))
217+
max_variable = max(max_variable, maximum(term))
218+
end
219+
220+
info("dectected $(max_variable) variables with order $(max_order)")
221+
222+
return FactorGraph(max_order, max_variable, alphabet, dict)
223+
end
224+
170225

171226
permutations(items, order::Int; asymmetric::Bool = false) = sort(permutations([], items, order, asymmetric))
172227

src/sampling.jl

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ function weigh_proba{T <: Real}(int_representation::Int, adj::Array{T,2}, prior:
3030
end
3131

3232

33-
function sample_generation{T <: Real}(gm::FactorGraph{T}, samples_per_bin::Integer, bins::Int)
33+
# assumes second order
34+
function sample_generation_ising{T <: Real}(gm::FactorGraph{T}, samples_per_bin::Integer, bins::Int)
3435
@assert bins >= 1
3536

3637
spin_number = gm.varible_count
@@ -43,7 +44,37 @@ function sample_generation{T <: Real}(gm::FactorGraph{T}, samples_per_bin::Integ
4344
assignment_tmp = [0 for i in 1:spin_number] # pre allocate assignment memory
4445
weights = [weigh_proba(i, adjacency_matrix, prior_vector, assignment_tmp) for i in (0:config_number-1)]
4546

46-
items = [i for i in 0:(config_number-1)]
47+
raw_sample = StatsBase.sample(items, StatsBase.Weights(weights), samples_per_bin*bins, ordered=false)
48+
raw_sample_bins = reshape(raw_sample, bins, samples_per_bin)
49+
50+
spin_samples = []
51+
for b in 1:bins
52+
raw_binning = countmap(raw_sample_bins[b,:])
53+
spin_sample = [ vcat(raw_binning[i], int_to_spin(i, spin_number)) for i in keys(raw_binning)]
54+
push!(spin_samples, hcat(spin_sample...)')
55+
end
56+
return spin_samples
57+
end
58+
59+
60+
function weigh_proba{T <: Real}(int_representation::Int, gm::FactorGraph{T}, spins::Array{Int,1})
61+
digits!(spins, int_representation, 2)
62+
spins .= bool_to_spin.(spins)
63+
evaluation = sum( weight*prod(spins[i] for i in term) for (term, weight) in gm)
64+
return exp(evaluation)
65+
end
66+
67+
function sample_generation{T <: Real}(gm::FactorGraph{T}, samples_per_bin::Integer, bins::Int)
68+
@assert bins >= 1
69+
#info("use general sample model")
70+
71+
spin_number = gm.varible_count
72+
config_number = 2^spin_number
73+
74+
items = [i for i in 0:(config_number-1)]
75+
assignment_tmp = [0 for i in 1:spin_number] # pre allocate assignment memory
76+
weights = [weigh_proba(i, gm, assignment_tmp) for i in (0:config_number-1)]
77+
4778
raw_sample = StatsBase.sample(items, StatsBase.Weights(weights), samples_per_bin*bins, ordered=false)
4879
raw_sample_bins = reshape(raw_sample, bins, samples_per_bin)
4980

@@ -61,14 +92,15 @@ sample{T <: Real}(gm::FactorGraph{T}, number_sample::Integer, replicates::Intege
6192

6293

6394
function sample{T <: Real}(gm::FactorGraph{T}, number_sample::Integer, replicates::Integer, sampler::Gibbs)
64-
if gm.order != 2
65-
error("sampling is only supported for FactorGraphs of order 2, given order $(gm.order)")
66-
end
6795
if gm.alphabet != :spin
6896
error("sampling is only supported for spin FactorGraphs, given alphabet $(gm.alphabet)")
6997
end
7098

71-
samples = sample_generation(gm, number_sample, replicates)
99+
if gm.order <= 2
100+
samples = sample_generation_ising(gm, number_sample, replicates)
101+
else
102+
samples = sample_generation(gm, number_sample, replicates)
103+
end
72104

73105
return samples
74106
end

test/runtests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@ end
2222

2323

2424
@testset "gibbs sampler" begin
25+
for (name, gm) in gms
26+
gm_tmp = deepcopy(gm)
27+
gm_tmp.order = 3
28+
srand(0) # fix random number generator
29+
samples = sample(gm_tmp, gibbs_test_samples)
30+
base_samples = readcsv("data/$(name)_samples.csv")
31+
@test isapprox(samples, base_samples)
32+
end
33+
end
34+
35+
@testset "gibbs sampler, 2nd order" begin
2536
for (name, gm) in gms
2637
srand(0) # fix random number generator
2738
samples = sample(gm, gibbs_test_samples)

0 commit comments

Comments
 (0)