|
| 1 | +# data structures graphical models |
| 2 | + |
| 3 | +export FactorGraph |
| 4 | + |
| 5 | +alphabets = [:spin, :boolean, :integer, :integer_pos, :real, :real_pos] |
| 6 | + |
| 7 | +type FactorGraph{T <: Real} |
| 8 | + order::Int |
| 9 | + varible_count::Int |
| 10 | + alphabet::Symbol |
| 11 | + terms::Dict{Tuple,T} # TODO, would be nice to have a stronger tuple type here |
| 12 | + variable_names::Nullable{Vector{String}} |
| 13 | + FactorGraph(a,b,c,d,e) = check_model_data(a,b,c,d,e) ? new(a,b,c,d,e) : error("generic init problem") |
| 14 | +end |
| 15 | +FactorGraph{T <: Real}(order::Int, varible_count::Int, alphabet::Symbol, terms::Dict{Tuple,T}) = FactorGraph{T}(order, varible_count, alphabet, terms, Nullable{Vector{String}}()) |
| 16 | +FactorGraph{T <: Real}(matrix::Array{T,2}) = convert(FactorGraph{T}, matrix) |
| 17 | + |
| 18 | +function check_model_data{T <: Real}(order::Int, varible_count::Int, alphabet::Symbol, terms::Dict{Tuple,T}, variable_names::Nullable{Vector{String}}) |
| 19 | + if !in(alphabet, alphabets) |
| 20 | + error("alphabet $(alphabet) is not supported") |
| 21 | + return false |
| 22 | + end |
| 23 | + if !isnull(variable_names) && length(variable_names) != varible_count |
| 24 | + error("expected $(varible_count) but only given $(length(variable_names))") |
| 25 | + return false |
| 26 | + end |
| 27 | + for (k,v) in terms |
| 28 | + if length(k) != order |
| 29 | + error("a term has $(length(k)) indices but should have $(order) indices") |
| 30 | + return false |
| 31 | + end |
| 32 | + for (i,index) in enumerate(k) |
| 33 | + #println(i," ",index) |
| 34 | + if index < 1 || index > varible_count |
| 35 | + error("a term has an index of $(index) but it should be in the range of 1:$(varible_count)") |
| 36 | + return false |
| 37 | + end |
| 38 | + if i > 1 |
| 39 | + if k[i-1] > index |
| 40 | + error("the term $(k) does not have ascending indices") |
| 41 | + end |
| 42 | + end |
| 43 | + end |
| 44 | + end |
| 45 | + return true |
| 46 | +end |
| 47 | + |
| 48 | +function Base.show(io::IO, gm::FactorGraph) |
| 49 | + println(io, "alphabet: ", gm.alphabet) |
| 50 | + println(io, "vars: ", gm.varible_count) |
| 51 | + if !isnull(gm.variable_names) |
| 52 | + println(io, "variable names: ") |
| 53 | + println(io, " ", get(gm.variable_names)) |
| 54 | + end |
| 55 | + |
| 56 | + println(io, "terms: ") |
| 57 | + for k in sort(collect(keys(gm.terms))) |
| 58 | + println(" ", k, " => ", gm.terms[k]) |
| 59 | + end |
| 60 | +end |
| 61 | + |
| 62 | +Base.start(gm::FactorGraph) = start(gm.terms) |
| 63 | +Base.next(gm::FactorGraph, state) = next(gm.terms, state) |
| 64 | +Base.done(gm::FactorGraph, state) = done(gm.terms, state) |
| 65 | + |
| 66 | +Base.length(gm::FactorGraph) = length(gm.terms) |
| 67 | + |
| 68 | +Base.getindex(gm::FactorGraph, i) = gm.terms[i] |
| 69 | +Base.keys(gm::FactorGraph) = keys(gm.terms) |
| 70 | + |
| 71 | + |
| 72 | +function diag_keys(gm::FactorGraph) |
| 73 | + dkeys = Tuple[] |
| 74 | + for i in 1:gm.varible_count |
| 75 | + key = diag_key(gm, i) |
| 76 | + if key in keys(gm.terms) |
| 77 | + push!(dkeys, key) |
| 78 | + end |
| 79 | + end |
| 80 | + return sort(dkeys) |
| 81 | +end |
| 82 | + |
| 83 | +diag_key(gm::FactorGraph, i::Int) = tuple(fill(i, gm.order)...) |
| 84 | + |
| 85 | +#Base.diag{T <: Real}(gm::FactorGraph{T}) = [ get(gm.terms, diag_key(gm, i), zero(T)) for i in 1:gm.varible_count ] |
| 86 | + |
| 87 | +Base.DataFmt.writecsv{T <: Real}(io, gm::FactorGraph{T}, args...; kwargs...) = writecsv(io, convert(Array{T,2}, gm), args...; kwargs...) |
| 88 | + |
| 89 | +Base.convert{T <: Real}(::Type{FactorGraph}, m::Array{T,2}) = convert(FactorGraph{T}, m) |
| 90 | +function Base.convert{T <: Real}(::Type{FactorGraph{T}}, m::Array{T,2}) |
| 91 | + @assert size(m,1) == size(m,2) #check matrix is square |
| 92 | + |
| 93 | + info("assuming spin alphabet") |
| 94 | + alphabet = :spin |
| 95 | + varible_count = size(m,1) |
| 96 | + |
| 97 | + terms = Dict{Tuple,T}() |
| 98 | + for key in permutations(1:varible_count, 2) |
| 99 | + weight = m[key...] |
| 100 | + if !isapprox(weight, 0.0) |
| 101 | + terms[key] = weight |
| 102 | + end |
| 103 | + |
| 104 | + rev = reverse(key) |
| 105 | + if !isapprox(m[rev...], 0.0) && !isapprox(m[key...], m[rev...]) |
| 106 | + delta = abs(m[key...] - m[rev...]) |
| 107 | + warn("values at $(key) and $(rev) differ by $(delta), only $(key) will be used") |
| 108 | + end |
| 109 | + end |
| 110 | + |
| 111 | + return FactorGraph(2, varible_count, alphabet, terms) |
| 112 | +end |
| 113 | + |
| 114 | +function Base.convert{T <: Real}(::Type{Array{T,2}}, gm::FactorGraph{T}) |
| 115 | + if gm.order != 2 |
| 116 | + error("cannot convert a FactorGraph of order $(gm.order) to a matrix") |
| 117 | + end |
| 118 | + |
| 119 | + matrix = zeros(gm.varible_count, gm.varible_count) |
| 120 | + for (k,v) in gm |
| 121 | + matrix[k...] = v |
| 122 | + r = reverse(k) |
| 123 | + matrix[r...] = v |
| 124 | + end |
| 125 | + |
| 126 | + return matrix |
| 127 | +end |
| 128 | + |
| 129 | + |
| 130 | + |
| 131 | +permutations(items, order::Int; asymmetric::Bool = false) = sort(permutations([], items, order, asymmetric)) |
| 132 | + |
| 133 | +function permutations(partical_perm::Array{Any,1}, items, order::Int, asymmetric::Bool) |
| 134 | + if order == 0 |
| 135 | + return [tuple(partical_perm...)] |
| 136 | + else |
| 137 | + perms = [] |
| 138 | + for item in items |
| 139 | + if !asymmetric && length(partical_perm) > 0 |
| 140 | + if partical_perm[end] < item |
| 141 | + continue |
| 142 | + end |
| 143 | + end |
| 144 | + perm = permutations(vcat([item], partical_perm), items, order-1, asymmetric) |
| 145 | + append!(perms, perm) |
| 146 | + end |
| 147 | + return perms |
| 148 | + end |
| 149 | +end |
| 150 | + |
0 commit comments