Merge pull request #1 from MartinKocour/master
Fix issues with Forward-Backward
iondel authored Jul 27, 2020
2 parents 023c751 + bf7f874 commit 5f4a5b2
Showing 7 changed files with 1,911 additions and 859 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -22,3 +22,6 @@ docs/site/
# committed for packages, but should be committed for applications that require a static
# environment.
Manifest.toml

# Files generated by JupyterNotebook
.ipynb_checkpoints
5 changes: 5 additions & 0 deletions Project.toml
@@ -4,4 +4,9 @@ authors = ["Lucas Ondel <iondel@fit.vutbr.cz>"]
version = "0.1.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
SpeechFeatures = "6f3487c4-5ca2-4050-bfeb-2cf56df92307"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
2,478 changes: 1,673 additions & 805 deletions examples/demo.ipynb

Large diffs are not rendered by default.

24 changes: 23 additions & 1 deletion src/HiddenMarkovModel.jl
@@ -232,7 +232,8 @@ function Base.show(io, ::MIME"image/svg+xml", g::AbstractGraph)

for state in states(g)
shape = isemitting(state) ? "circle" : "point"
write(dotfile, "$(id(state)) [ shape=\"$(shape)\", label=\"$(name(state))\" ];\n")
label = "$(id(state)):$(name(state))"
write(dotfile, "$(id(state)) [ shape=\"$(shape)\", label=\"$(label)\" ];\n")
end
for arc in arcs(g)
src, dest, weight = id(arc[1]), id(arc[2]), round(arc[3], digits = 3)
@@ -260,10 +261,31 @@ export nopruning

include("graph.jl")

#######################################################################
# Pretty display of the sparse matrix (i.e. from αβrecursion).

import Printf:@sprintf
function Base.show(io::IO, ::MIME"text/plain", a::Array{Dict{State, T},1}) where T <: AbstractFloat
for n in 1:length(a)
write(io, "[n = $n] \t")
max = foldl(((sa,wa), (s,w)) -> wa < w ? (s,w) : (sa,wa), a[n]; init=first(a[n]))
write(io, first(max) |> name)
for (s, w) in sort(a[n]; by=x->name(x))
write(io, "\t$(id(s)):$(name(s)) = $(@sprintf("%.3f", w)) ")
end
write(io, "\n")
end
end
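For orientation, a hypothetical rendering of this display (state names and weights invented for illustration): each row shows the frame index, the name of the highest-weight state, then every active state as id:name = log-weight.

[n = 1] 	a	1:a = -0.105 	2:b = -2.303
[n = 2] 	b	1:a = -1.897 	2:b = -0.162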

#######################################################################
# Major algorithms

include("algorithms.jl")


#######################################################################
# Other

include("../src/misc.jl")

end
243 changes: 191 additions & 52 deletions src/algorithms.jl
@@ -24,7 +24,7 @@ end

function (pruning::ThresholdPruning)(candidates::Dict{State, T}) where T <: AbstractFloat
maxval = maximum(p -> p.second, candidates)
return filter(p -> maxval - p.second ≤ pruning.Δ, candidates)
filter!(p -> maxval - p.second ≤ pruning.Δ, candidates)
end


@@ -43,30 +43,22 @@ Forward step of the Baum-Welch algorithm in the log-domain.
"""
function αrecursion(g::AbstractGraph, llh::Matrix{T};
pruning::Union{Real, NoPruning} = nopruning) where T <: AbstractFloat
pruning = pruning ≠ nopruning ? ThresholdPruning(pruning) : pruning
α = Matrix{T}(undef, size(llh))
fill!(α, T(-Inf))
pruning! = pruning ≠ nopruning ? ThresholdPruning(pruning) : pruning

activestates = Dict{State, T}(initstate(g) => T(0.0))
newstates = Dict{State, T}()
α = Vector{Dict{State, T}}()

for n in 1:size(llh, 2)
for state_weightpath in activestates
state, weightpath = state_weightpath
for nstate_linkweight in emittingstates(forward, state)
nstate, linkweight = nstate_linkweight
push!(α, Dict{State,T}())
for (state, weightpath) in activestates
for (nstate, linkweight) in emittingstates(forward, state)
nweightpath = weightpath + linkweight
newstates[nstate] = llh[nstate.pdfindex, n] + logaddexp(get(newstates, nstate, T(-Inf)), nweightpath)
α[n][nstate] = llh[pdfindex(nstate), n] + logaddexp(get(α[n], nstate, T(-Inf)), nweightpath)
end
end

for nstate_nweightpath in newstates
nstate, nweightpath = nstate_nweightpath
α[nstate.pdfindex, n] = logaddexp(α[nstate.pdfindex, n], nweightpath)
end

empty!(activestates)
merge!(activestates, pruning(newstates))
empty!(newstates)
merge!(activestates, pruning!(α[n]))
end
α
end
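For reference, the loop above is the textbook log-domain forward recursion. Writing a_{ij} for the log weight of the arc from emitting state i to j (notation introduced here, not in the code) and llh(j, n) for the frame log-likelihood, each step accumulates

\alpha_n(j) = \mathrm{llh}(j, n) + \log \sum_i \exp\bigl( \alpha_{n-1}(i) + a_{ij} \bigr)

seeded with weight 0.0 on the initial state; the pruning step then drops any state whose weight falls more than Δ below the frame maximum.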
@@ -78,33 +70,25 @@ Backward step of the Baum-Welch algorithm in the log domain.
"""
function βrecursion(g::AbstractGraph, llh::Matrix{T};
pruning::Union{Real, NoPruning} = nopruning) where T <: AbstractFloat
pruning = pruning ≠ nopruning ? ThresholdPruning(pruning) : pruning
β = Matrix{eltype(llh)}(undef, size(llh))
fill!(β, T(-Inf))

activestates = Dict{State, T}(finalstate(g) => T(0.0))
newstates = Dict{State, T}()

for n in size(llh, 2):-1:1
for state_weightpath in activestates
state, weightpath = state_weightpath
emitting = isemitting(state)
prev_llh = emitting ? llh[state.pdfindex, n+1] : T(0.0)
for nstate_linkweight in emittingstates(backward, state)
nstate, linkweight = nstate_linkweight
pruning! = pruning ≠ nopruning ? ThresholdPruning(pruning) : pruning

activestates = Dict{State, T}()
β = Vector{Dict{State, T}}()
push!(β, Dict(s => T(0.0) for (s, w) in emittingstates(backward, finalstate(g))))

for n in size(llh, 2)-1:-1:1
# Update the active tokens
empty!(activestates)
merge!(activestates, pruning!(β[1]))

pushfirst!(β, Dict{State,T}())
for (state, weightpath) in activestates
prev_llh = llh[pdfindex(state), n+1]
for (nstate, linkweight) in emittingstates(backward, state)
nweightpath = weightpath + linkweight + prev_llh
newstates[nstate] = logaddexp(get(newstates, nstate, T(-Inf)), nweightpath)
β[1][nstate] = logaddexp(get(β[1], nstate, T(-Inf)), nweightpath)
end
end

for nstate_nweightpath in newstates
nstate, nweightpath = nstate_nweightpath
β[nstate.pdfindex, n] = logaddexp(β[nstate.pdfindex, n], nweightpath)
end

empty!(activestates)
merge!(activestates, pruning(newstates))
empty!(newstates)
end
β
end
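Symmetrically, with the same notation as above, the backward pass accumulates

\beta_n(i) = \log \sum_j \exp\bigl( \beta_{n+1}(j) + a_{ij} + \mathrm{llh}(j, n+1) \bigr)

where the initial push! seeds the last frame with weight 0.0 on every emitting state that reaches the final state.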
@@ -118,37 +102,57 @@ function αβrecursion(g::AbstractGraph, llh::Matrix{T};
pruning::Union{Real, NoPruning} = nopruning) where T <: AbstractFloat
α = αrecursion(g, llh, pruning = pruning)
β = βrecursion(g, llh, pruning = pruning)
α + β .- logsumexp(α + β, dims = 1)

γ = Vector{Dict{State,T}}()

for n in 1:size(llh, 2)
push!(γ, Dict{State, T}())
for s in union(keys(α[n]), keys(β[n]))
a = get(α[n], s, T(-Inf))
b = get(β[n], s, T(-Inf))
γ[n][s] = a + b
end
filter!(p -> isfinite(p.second), γ[n])
sum = logsumexp(values(γ[n]))

for s in keys(γ[n])
γ[n][s] -= sum
end
end

# Total Log Likelihood
fs = foldl((acc, (s, w)) -> push!(acc, s), emittingstates(backward, finalstate(g)); init=[])
ttl = filter(s -> s[1] in fs, α[end]) |> values |> sum

γ, ttl
end
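The γ values are the usual per-frame state posteriors, normalized in the log domain,

\gamma_n(s) = \alpha_n(s) + \beta_n(s) - \log \sum_{s'} \exp\bigl( \alpha_n(s') + \beta_n(s') \bigr)

A minimal usage sketch (g and llh assumed to be an existing Graph and a pdf-by-frame log-likelihood matrix; the threshold value is illustrative):

γ, ttl = αβrecursion(g, llh, pruning = 50.0)   # state posteriors and total log-likelihood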

# function total_llh()

#######################################################################
# Viterbi algorithm (find the best path)

export viterbi
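For reference, the textbook Viterbi recursion replaces the forward pass's log-sum-exp with a max,

\delta_n(j) = \mathrm{llh}(j, n) + \max_i \bigl( \delta_{n-1}(i) + a_{ij} \bigr)

whereas the maxβrecursion helper below combines the forward α weights with backward path weights and keeps the best-scoring state per frame.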

function maxβrecursion(g::AbstractGraph, llh::Matrix{T}, α::Matrix{T}) where T <: AbstractFloat
function maxβrecursion(g::AbstractGraph, llh::Matrix{T}, α::Vector{Dict{State,T}}) where T <: AbstractFloat
bestseq = Vector{State}()
activestates = Dict{State, T}(finalstate(g) => T(0.0))
newstates = Dict{State, T}()

for n in size(llh, 2):-1:1
for state_weightpath in activestates
state, weightpath = state_weightpath
for (state, weightpath) in activestates
emitting = isemitting(state)
prev_llh = emitting ? llh[state.pdfindex, n+1] : T(0.0)
for nstate_linkweight in emittingstates(backward, state)
nstate, linkweight = nstate_linkweight
for (nstate, linkweight) in emittingstates(backward, state)
nweightpath = weightpath + linkweight + prev_llh
newstates[nstate] = logaddexp(get(newstates, nstate, T(-Inf)), nweightpath)
end
end


hypscores = Vector{T}(undef, length(newstates))
hypstates = Vector{State}(undef, length(newstates))
for (i, nstate_nweightpath) in enumerate(newstates)
nstate, nweightpath = nstate_nweightpath
hypscores[i] = α[nstate.pdfindex, n] + nweightpath
for (i, (nstate, nweightpath)) in enumerate(newstates)
hypscores[i] = get(α[n], nstate, T(-Inf)) + nweightpath
hypstates[i] = nstate
end
println(hypstates)
@@ -229,7 +233,7 @@ end
export weightnormalize

"""
normalize(graph)
weightnormalize(graph)
Update the weights of the graph such that the exponentiated weights of all the
outgoing arcs from a state sum to one.
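In log-domain terms this is a per-state shift of the outgoing weights (a sketch consistent with the docstring, using the a_{ij} notation from above):

a'_{ij} = a_{ij} - \log \sum_k \exp( a_{ik} )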
@@ -287,3 +291,138 @@ function addselfloop(graph::Graph; loopprob = 0.5)
g
end

#######################################################################
# Union two graphs into one

import Base: union

"""
union(g1::AbstractGraph, g2::AbstractGraph)
"""
function Base.union(g1::AbstractGraph, g2::AbstractGraph)
g = Graph()
statecount = 0
old2new = Dict{AbstractState, AbstractState}(
initstate(g1) => initstate(g),
finalstate(g1) => finalstate(g),
)
for (i, state) in enumerate(states(g1))
if id(state) ≠ finalstateid && id(state) ≠ initstateid
statecount += 1
old2new[state] = addstate!(g, State(statecount, pdfindex(state), name(state)))
end
end

for state in states(g1)
src = old2new[state]
for link in children(state)
link!(src, old2new[link.dest], link.weight)
end
end

old2new = Dict{AbstractState, AbstractState}(
initstate(g2) => initstate(g),
finalstate(g2) => finalstate(g),
)
for (i, state) in enumerate(states(g2))
if id(state) ≠ finalstateid && id(state) ≠ initstateid
statecount += 1
old2new[state] = addstate!(g, State(statecount, pdfindex(state), name(state)))
end
end

for state in states(g2)
src = old2new[state]
for link in children(state)
link!(src, old2new[link.dest], link.weight)
end
end

g |> determinize |> weightnormalize
end
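A minimal usage sketch (g1 and g2 assumed to be existing Graph values, e.g. built with LinearGraph): the states of both graphs are renumbered into a single graph sharing one initial and one final state, and the result is determinized and weight-normalized:

g = union(g1, g2)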

#######################################################################
# FSM minimization

export minimize

function leftminimize!(g::Graph, state::AbstractState)
leaves = Dict()
for link in children(state)
leaf, weight = get(leaves, pdfindex(link.dest), ([], -Inf))
push!(leaf, link.dest)
leaves[pdfindex(link.dest)] = (leaf, logaddexp(weight, link.weight))
end

empty!(state.outgoing)
for (nextstates, weight) in values(leaves)
mergedstate = nextstates[1]
filter!(l -> l.dest ≠ state, mergedstate.incoming)

# Now remove the redundant states, rerouting their outgoing links through the merged state.
links = vcat([next.outgoing for next in nextstates[2:end]]...)
for link in links
link!(mergedstate, link.dest, link.weight)
end

for old in nextstates[2:end]
for link in children(old)
filter!(l -> l.dest ≠ old, link.dest.incoming)
end
delete!(g.states, id(old))
end

# Reconnect the previous state with the merged state
link!(state, mergedstate, weight)

# Minimize the subgraph.
leftminimize!(g, mergedstate)
end
g
end

function rightminimize!(g::Graph, state::AbstractState)
leaves = Dict()
for link in parents(state)
leaf, weight = get(leaves, pdfindex(link.dest), ([], -Inf))
push!(leaf, link.dest)
leaves[pdfindex(link.dest)] = (leaf, logaddexp(weight, link.weight))
end

empty!(state.incoming)
for (nextstates, weight) in values(leaves)
mergedstate = nextstates[1]
filter!(l -> l.dest ≠ state, mergedstate.outgoing)

# Now remove the redundant states, rerouting their incoming links through the merged state.
links = vcat([next.incoming for next in nextstates[2:end]]...)
for link in links
#link!(mergedstate, link.dest, link.weight)
link!(link.dest, mergedstate, link.weight)
end

for old in nextstates[2:end]
for link in parents(old)
filter!(l -> l.dest ≠ old, link.dest.outgoing)
end
delete!(g.states, id(old))
end

# Reconnect the previous state with the merged state
link!(mergedstate, state, weight)

# Minimize the subgraph.
rightminimize!(g, mergedstate)
end
g
end

"""
minimize(g::Graph)
"""
minimize(g::Graph) = begin
newg = deepcopy(g)
newg = leftminimize!(newg, initstate(newg))
rightminimize!(newg, finalstate(newg))
end
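A brief usage note: minimize operates on a deep copy, merging equivalent successor states forward from the initial state and equivalent predecessor states backward from the final state, so the input graph is untouched:

ming = minimize(g)   # g itself is not modified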

1 change: 0 additions & 1 deletion src/graph.jl
@@ -138,4 +138,3 @@ function LinearGraph(sequence::AbstractArray{String},
link!(prevstate, finalstate(g), 0.)
g
end

16 changes: 16 additions & 0 deletions src/misc.jl
@@ -0,0 +1,16 @@
import StatsFuns: logsumexp

function logsumexp(X::AbstractArray{T}; dims=:) where {T<:Real}
u = reduce(max, X, dims=dims, init=oftype(log(zero(T)), -Inf))
u isa AbstractArray || isfinite(u) || return float(u)
let u=u
if u isa AbstractArray
v = u .+ log.(sum(exp.(X .- u); dims=dims))
i = .! isfinite.(v)
v[i] .= u[i]
v
else
u + log(sum(x -> exp(x-u), X))
end
end
end
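This override guards the corner case where every element along a reduction is -Inf (e.g. a state that is never active): the naive u + log(sum(exp.(X .- u))) evaluation would give NaN, since -Inf - -Inf is NaN. A quick illustration (values invented):

logsumexp([-Inf, -Inf])                    # -Inf rather than NaN
logsumexp([-Inf -Inf; 0.0 0.0]; dims=2)    # [-Inf; log(2)]; the -Inf row is repaired by the isfinite mask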
