diff --git a/.gitignore b/.gitignore index 29126e4..2975a3f 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ docs/site/ # committed for packages, but should be committed for applications that require a static # environment. Manifest.toml + +# Files generated by JupyterNotebook +.ipynb_checkpoints diff --git a/Project.toml b/Project.toml index 6104760..cb493a7 100644 --- a/Project.toml +++ b/Project.toml @@ -4,4 +4,9 @@ authors = ["Lucas Ondel "] version = "0.1.0" [deps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" +SpeechFeatures = "6f3487c4-5ca2-4050-bfeb-2cf56df92307" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" diff --git a/examples/demo.ipynb b/examples/demo.ipynb index bc332ee..4514cd2 100644 --- a/examples/demo.ipynb +++ b/examples/demo.ipynb @@ -2,14 +2,30 @@ "cells": [ { "cell_type": "code", - "execution_count": 290, + "execution_count": 1, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[32m\u001b[1mActivating\u001b[22m\u001b[39m environment at `~/Skola/VTI/HiddenMarkovModel/Project.toml`\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m environment at `~/GithubRepositories/HiddenMarkovModel/Project.toml`\n" + "┌ Info: Precompiling HiddenMarkovModel [7212e43c-49b6-4278-a2e6-573e0ac5bb2d]\n", + "└ @ Base loading.jl:1273\n", + "┌ Warning: Replacing docs for `HiddenMarkovModel.pdfindex :: Union{}` in module `HiddenMarkovModel`\n", + "└ @ Base.Docs docs/Docs.jl:223\n", + "WARNING: Method definition logsumexp(AbstractArray{T<:Real, N} where N) where {T<:Real} in module StatsFuns at /Users/praca/.julia/packages/StatsFuns/CXyCV/src/basicfuns.jl:237 overwritten in module HiddenMarkovModel at /Users/praca/Skola/VTI/HiddenMarkovModel/src/misc.jl:4.\n", + " ** incremental compilation may be fatally broken for this module **\n", + "\n", + "WARNING: Method definition #logsumexp(Any, typeof(StatsFuns.logsumexp), AbstractArray{T<:Real, N} where N) where {T<:Real} in module StatsFuns overwritten in module HiddenMarkovModel.\n", + " ** incremental compilation may be fatally broken for this module **\n", + "\n" ] } ], @@ -17,14 +33,13 @@ "using Pkg\n", "Pkg.activate(\"../\")\n", "\n", - "using StatsFuns\n", "using Revise\n", "using HiddenMarkovModel" ] }, { "cell_type": "code", - "execution_count": 291, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -33,91 +48,90 @@ "\n", "\n", - "\n", - "\n", - "\n", - "\n", - "%3\n", - "\n", - "\n", - "\n", - "initstateid\n", - "\n", - "\n", - "\n", - "\n", - "1\n", - "\n", - "a\n", - "\n", - "\n", - "\n", - "initstateid->1\n", - "\n", - "\n", - "0.0\n", - "\n", + "\n", + "\n", + "\n", + "\n", "\n", - "\n", + "\n", "4\n", - "\n", - "a\n", + "\n", + "4:a\n", "\n", "\n", "\n", "finalstateid\n", - "\n", + "\n", "\n", "\n", - "\n", + "\n", "4->finalstateid\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", - "\n", + "\n", "2\n", - "\n", - "c\n", + "\n", + "2:c\n", "\n", "\n", - "\n", + "\n", "3\n", - "\n", - "b\n", + "\n", + "3:b\n", "\n", "\n", - "\n", + "\n", "2->3\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", - "\n", + "\n", "3->4\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "initstateid\n", + "\n", + "\n", + "\n", + "\n", + "1\n", + "\n", + "1:a\n", + "\n", + "\n", + "\n", + "initstateid->1\n", + "\n", + "\n", + "0.0\n", "\n", "\n", "\n", "1->2\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", "\n" ], "text/plain": [ - "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(initstateid => State(id = initstateid, pdfindex = nothing, initstateid),4 => State(id = 4, pdfindex = 1, a),2 => State(id = 2, pdfindex = 3, c),3 => State(id = 3, pdfindex = 2, b),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),1 => State(id = 1, pdfindex = 1, a)))" + "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(4 => State(id = 4, pdfindex = 1, a),2 => State(id = 2, pdfindex = 3, c),3 => State(id = 3, pdfindex = 2, b),initstateid => State(id = initstateid, pdfindex = nothing, initstateid),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),1 => State(id = 1, pdfindex = 1, a)))" ] }, - "execution_count": 291, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -135,7 +149,7 @@ }, { "cell_type": "code", - "execution_count": 292, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -144,91 +158,90 @@ "\n", "\n", - "\n", - "\n", - "\n", - "\n", - "%3\n", - "\n", - "\n", - "\n", - "initstateid\n", - "\n", - "\n", - "\n", - "\n", - "1\n", - "\n", - "a\n", - "\n", - "\n", - "\n", - "initstateid->1\n", - "\n", - "\n", - "0.0\n", - "\n", + "\n", + "\n", + "\n", + "\n", "\n", - "\n", + "\n", "4\n", - "\n", - "a\n", + "\n", + "4:a\n", "\n", "\n", "\n", "finalstateid\n", - "\n", + "\n", "\n", "\n", - "\n", + "\n", "4->finalstateid\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", - "\n", + "\n", "2\n", - "\n", - "d\n", + "\n", + "2:d\n", "\n", "\n", - "\n", + "\n", "3\n", - "\n", - "b\n", + "\n", + "3:b\n", "\n", "\n", - "\n", + "\n", "2->3\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", - "\n", + "\n", "3->4\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "initstateid\n", + "\n", + "\n", + "\n", + "\n", + "1\n", + "\n", + "1:a\n", + "\n", + "\n", + "\n", + "initstateid->1\n", + "\n", + "\n", + "0.0\n", "\n", "\n", "\n", "1->2\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", "\n" ], "text/plain": [ - "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(initstateid => State(id = initstateid, pdfindex = nothing, initstateid),4 => State(id = 4, pdfindex = 1, a),2 => State(id = 2, pdfindex = 4, d),3 => State(id = 3, pdfindex = 2, b),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),1 => State(id = 1, pdfindex = 1, a)))" + "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(4 => State(id = 4, pdfindex = 1, a),2 => State(id = 2, pdfindex = 4, d),3 => State(id = 3, pdfindex = 2, b),initstateid => State(id = initstateid, pdfindex = nothing, initstateid),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),1 => State(id = 1, pdfindex = 1, a)))" ] }, - "execution_count": 292, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -242,7 +255,7 @@ }, { "cell_type": "code", - "execution_count": 352, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -251,297 +264,66 @@ "\n", "\n", - "\n", - "\n", - "\n", - "\n", - "%3\n", - "\n", + "\n", + "\n", + "\n", + "\n", "\n", "\n", "initstateid\n", - "\n", - "\n", - "\n", - "\n", - "4\n", - "\n", - "a\n", - "\n", - "\n", - "\n", - "initstateid->4\n", - "\n", - "\n", - "-0.693\n", - "\n", - "\n", - "\n", - "8\n", - "\n", - "a\n", - "\n", - "\n", - "\n", - "initstateid->8\n", - "\n", - "\n", - "-0.693\n", - "\n", - "\n", - "\n", - "2\n", - "\n", - "c\n", - "\n", - "\n", - "\n", - "4->2\n", - "\n", - "\n", - "0.0\n", - "\n", - "\n", - "\n", - "7\n", - "\n", - "b\n", - "\n", - "\n", - "\n", - "5\n", - "\n", - "a\n", - "\n", - "\n", - "\n", - "7->5\n", - "\n", - "\n", - "0.0\n", - "\n", - "\n", - "\n", - "3\n", - "\n", - "b\n", - "\n", - "\n", - "\n", - "2->3\n", - "\n", - "\n", - "0.0\n", + "\n", "\n", "\n", - "\n", + "\n", "1\n", - "\n", - "a\n", + "\n", + "1:hello\n", "\n", - "\n", - "\n", - "3->1\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "initstateid->1\n", + "\n", + "\n", + "0.0\n", "\n", "\n", - "\n", + "\n", "finalstateid\n", - "\n", - "\n", - "\n", - "\n", - "5->finalstateid\n", - "\n", - "\n", - "0.0\n", - "\n", - "\n", - "\n", - "6\n", - "\n", - "d\n", - "\n", - "\n", - "\n", - "8->6\n", - "\n", - "\n", - "0.0\n", - "\n", - "\n", - "\n", - "6->7\n", - "\n", - "\n", - "0.0\n", + "\n", "\n", "\n", - "\n", + "\n", "1->finalstateid\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", "\n" ], "text/plain": [ - "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(initstateid => State(id = initstateid, pdfindex = nothing, initstateid),4 => State(id = 4, pdfindex = 1, a),7 => State(id = 7, pdfindex = 2, b),2 => State(id = 2, pdfindex = 3, c),3 => State(id = 3, pdfindex = 2, b),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),5 => State(id = 5, pdfindex = 1, a),8 => State(id = 8, pdfindex = 1, a),6 => State(id = 6, pdfindex = 4, d),1 => State(id = 1, pdfindex = 1, a)…))" - ] - }, - "execution_count": 352, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "function union(g1::AbstractGraph, g2::AbstractGraph)\n", - " g = Graph()\n", - " statecount = 0\n", - " old2new = Dict{AbstractState, AbstractState}(\n", - " initstate(g1) => initstate(g),\n", - " finalstate(g1) => finalstate(g),\n", - " )\n", - " for (i, state) in enumerate(states(g1))\n", - " if id(state) ≠ finalstateid && id(state) ≠ initstateid\n", - " statecount += 1\n", - " old2new[state] = addstate!(g, State(statecount, pdfindex(state), name(state)))\n", - " end\n", - " end\n", - " \n", - " for state in states(g1)\n", - " src = old2new[state]\n", - " for link in children(state)\n", - " link!(src, old2new[link.dest], link.weight)\n", - " end\n", - " end\n", - " \n", - " old2new = Dict{AbstractState, AbstractState}(\n", - " initstate(g2) => initstate(g),\n", - " finalstate(g2) => finalstate(g),\n", - " )\n", - " for (i, state) in enumerate(states(g2))\n", - " if id(state) ≠ finalstateid && id(state) ≠ initstateid\n", - " statecount += 1\n", - " old2new[state] = addstate!(g, State(statecount, pdfindex(state), name(state)))\n", - " end\n", - " end\n", - " \n", - " for state in states(g2)\n", - " src = old2new[state]\n", - " for link in children(state)\n", - " link!(src, old2new[link.dest], link.weight)\n", - " end\n", - " end\n", - " \n", - " g |> determinize |> weightnormalize\n", - "end\n", - "g = union(g1, g2)" - ] - }, - { - "cell_type": "code", - "execution_count": 353, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "minimize (generic function with 1 method)" + "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(initstateid => State(id = initstateid, pdfindex = nothing, initstateid),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),1 => State(id = 1, pdfindex = 1, hello)))" ] }, - "execution_count": 353, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "function leftminimize!(g::Graph, state::AbstractState)\n", - " leaves = Dict() \n", - " for link in children(state)\n", - " leaf, weight = get(leaves, pdfindex(link.dest), ([], -Inf))\n", - " push!(leaf, link.dest)\n", - " leaves[pdfindex(link.dest)] = (leaf, logaddexp(weight, link.weight))\n", - " end\n", - " \n", - " empty!(state.outgoing)\n", - " for (nextstates, weight) in values(leaves)\n", - " mergedstate = nextstates[1]\n", - " filter!(l -> l.dest ≠ state, mergedstate.incoming)\n", - "\n", - " # Now we removed all the extra states of the graph.\n", - " links = vcat([next.outgoing for next in nextstates[2:end]]...)\n", - " for link in links\n", - " link!(mergedstate, link.dest, link.weight)\n", - " end\n", - " \n", - " for old in nextstates[2:end]\n", - " for link in children(old)\n", - " filter!(l -> l.dest ≠ old, link.dest.incoming)\n", - " end\n", - " delete!(g.states, id(old))\n", - " end\n", - " \n", - " # Reconnect the previous state with the merged state\n", - " link!(state, mergedstate, weight)\n", - " \n", - " # Minimize the subgraph.\n", - " leftminimize!(g, mergedstate)\n", - " end\n", - " g \n", - "end\n", - "\n", - "\n", - "function rightminimize!(g::Graph, state::AbstractState)\n", - " leaves = Dict() \n", - " for link in parents(state)\n", - " leaf, weight = get(leaves, pdfindex(link.dest), ([], -Inf))\n", - " push!(leaf, link.dest)\n", - " leaves[pdfindex(link.dest)] = (leaf, logaddexp(weight, link.weight))\n", - " end\n", - " \n", - " empty!(state.incoming)\n", - " for (nextstates, weight) in values(leaves)\n", - " mergedstate = nextstates[1]\n", - " filter!(l -> l.dest ≠ state, mergedstate.outgoing)\n", - "\n", - " # Now we removed all the extra states of the graph.\n", - " links = vcat([next.incoming for next in nextstates[2:end]]...)\n", - " for link in links\n", - " #link!(mergedstate, link.dest, link.weight)\n", - " link!(link.dest, mergedstate, link.weight)\n", - " end\n", - " \n", - " for old in nextstates[2:end]\n", - " for link in parents(old)\n", - " filter!(l -> l.dest ≠ old, link.dest.outgoing)\n", - " end\n", - " delete!(g.states, id(old))\n", - " end\n", - " \n", - " # Reconnect the previous state with the merged state\n", - " link!(mergedstate, state, weight)\n", - " \n", - " # Minimize the subgraph.\n", - " rightminimize!(g, mergedstate)\n", - " end\n", - " g \n", - "end\n", - "minimize(g::Graph) = begin\n", - " newg = deepcopy(g)\n", - " newg = leftminimize!(newg, initstate(newg)) \n", - " rightminimize!(newg, finalstate(newg)) \n", - "end" + "g3 = Graph()\n", + "s1 = addstate!(g3, State(1, 1, \"hello\"))\n", + "link!(initstate(g3), s1)\n", + "link!(s1, finalstate(g3))\n", + "g3" ] }, { "cell_type": "code", - "execution_count": 354, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -550,161 +332,160 @@ "\n", "\n", - "\n", - "\n", - "\n", - "\n", - "%3\n", - "\n", - "\n", + "\n", + "\n", + "\n", + "\n", + "\n", "\n", - "initstateid\n", - "\n", + "7\n", + "\n", + "7:b\n", + "\n", + "\n", + "\n", + "5\n", + "\n", + "5:a\n", + "\n", + "\n", + "\n", + "7->5\n", + "\n", + "\n", + "0.0\n", "\n", "\n", "\n", "4\n", - "\n", - "a\n", - "\n", - "\n", - "\n", - "initstateid->4\n", - "\n", - "\n", - "-0.693\n", - "\n", - "\n", - "\n", - "8\n", - "\n", - "a\n", - "\n", - "\n", - "\n", - "initstateid->8\n", - "\n", - "\n", - "-0.693\n", + "\n", + "4:a\n", "\n", "\n", "\n", "2\n", - "\n", - "c\n", + "\n", + "2:c\n", "\n", "\n", - "\n", + "\n", "4->2\n", - "\n", - "\n", - "0.0\n", - "\n", - "\n", - "\n", - "7\n", - "\n", - "b\n", + "\n", + "\n", + "0.0\n", "\n", - "\n", + "\n", "\n", - "5\n", - "\n", - "a\n", + "finalstateid\n", + "\n", "\n", - "\n", - "\n", - "7->5\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "5->finalstateid\n", + "\n", + "\n", + "0.0\n", "\n", "\n", "\n", "3\n", - "\n", - "b\n", + "\n", + "3:b\n", "\n", "\n", - "\n", + "\n", "2->3\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", "\n", "1\n", - "\n", - "a\n", + "\n", + "1:a\n", "\n", "\n", - "\n", + "\n", "3->1\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", - "\n", + "\n", "\n", - "finalstateid\n", - "\n", + "initstateid\n", + "\n", "\n", - "\n", + "\n", + "\n", + "initstateid->4\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "8\n", + "\n", + "8:a\n", + "\n", + "\n", "\n", - "5->finalstateid\n", - "\n", - "\n", - "0.0\n", + "initstateid->8\n", + "\n", + "\n", + "-0.693\n", "\n", "\n", "\n", "6\n", - "\n", - "d\n", + "\n", + "6:d\n", "\n", "\n", "\n", "8->6\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", "\n", "6->7\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", "\n", "1->finalstateid\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", "\n" ], "text/plain": [ - "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(initstateid => State(id = initstateid, pdfindex = nothing, initstateid),4 => State(id = 4, pdfindex = 1, a),7 => State(id = 7, pdfindex = 2, b),2 => State(id = 2, pdfindex = 3, c),3 => State(id = 3, pdfindex = 2, b),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),5 => State(id = 5, pdfindex = 1, a),8 => State(id = 8, pdfindex = 1, a),6 => State(id = 6, pdfindex = 4, d),1 => State(id = 1, pdfindex = 1, a)…))" + "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(7 => State(id = 7, pdfindex = 2, b),4 => State(id = 4, pdfindex = 1, a),5 => State(id = 5, pdfindex = 1, a),2 => State(id = 2, pdfindex = 3, c),3 => State(id = 3, pdfindex = 2, b),initstateid => State(id = initstateid, pdfindex = nothing, initstateid),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),8 => State(id = 8, pdfindex = 1, a),6 => State(id = 6, pdfindex = 4, d),1 => State(id = 1, pdfindex = 1, a)…))" ] }, - "execution_count": 354, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "g" + "g = union(g1, g2)" ] }, { "cell_type": "code", - "execution_count": 355, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -713,122 +494,1196 @@ "\n", "\n", - "\n", - "\n", - "\n", - "\n", - "%3\n", - "\n", - "\n", - "\n", - "initstateid\n", - "\n", - "\n", - "\n", - "\n", - "8\n", - "\n", - "a\n", - "\n", - "\n", - "\n", - "initstateid->8\n", - "\n", - "\n", - "0.0\n", - "\n", + "\n", + "\n", + "\n", + "\n", "\n", - "\n", + "\n", "7\n", - "\n", - "b\n", + "\n", + "7:b\n", "\n", "\n", - "\n", + "\n", "5\n", - "\n", - "a\n", + "\n", + "5:b\n", "\n", "\n", - "\n", + "\n", "7->5\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "4\n", + "\n", + "4:a\n", + "\n", + "\n", + "\n", + "2\n", + "\n", + "2:c\n", + "\n", + "\n", + "\n", + "4->2\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "finalstateid\n", + "\n", + "\n", + "\n", + "\n", + "5->finalstateid\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "3\n", + "\n", + "3:b\n", + "\n", + "\n", + "\n", + "2->3\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "3->5\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "1\n", + "\n", + "1:a\n", + "\n", + "\n", + "\n", + "3->1\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "initstateid\n", + "\n", + "\n", + "\n", + "\n", + "initstateid->4\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "8\n", + "\n", + "8:a\n", + "\n", + "\n", + "\n", + "initstateid->8\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "6\n", + "\n", + "6:d\n", + "\n", + "\n", + "\n", + "8->6\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "6->7\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "1->finalstateid\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(7 => State(id = 7, pdfindex = 2, b),4 => State(id = 4, pdfindex = 1, a),5 => State(id = 5, pdfindex = 2, b),2 => State(id = 2, pdfindex = 3, c),3 => State(id = 3, pdfindex = 2, b),initstateid => State(id = initstateid, pdfindex = nothing, initstateid),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),8 => State(id = 8, pdfindex = 1, a),6 => State(id = 6, pdfindex = 4, d),1 => State(id = 1, pdfindex = 1, a)…))" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "s = addstate!(g, State(5, emissionsmap[\"b\"], \"b\"))\n", + "fs = finalstate(g)\n", + "sb = g.states[3]\n", + "g.states\n", + "link!(s, fs)\n", + "link!(sb, s)\n", + "g" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "7\n", + "\n", + "7:b\n", + "\n", + "\n", + "\n", + "5\n", + "\n", + "5:b\n", + "\n", + "\n", + "\n", + "7->5\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "4\n", + "\n", + "4:a\n", + "\n", + "\n", + "\n", + "2\n", + "\n", + "2:c\n", + "\n", + "\n", + "\n", + "4->2\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "finalstateid\n", + "\n", + "\n", + "\n", + "\n", + "5->finalstateid\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "3\n", + "\n", + "3:b\n", + "\n", + "\n", + "\n", + "2->3\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "3->5\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "1\n", + "\n", + "1:a\n", + "\n", + "\n", + "\n", + "3->1\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "initstateid\n", + "\n", + "\n", + "\n", + "\n", + "initstateid->4\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "8\n", + "\n", + "8:a\n", + "\n", + "\n", + "\n", + "initstateid->8\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "6\n", + "\n", + "6:d\n", + "\n", + "\n", + "\n", + "8->6\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "6->7\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "1->finalstateid\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(7 => State(id = 7, pdfindex = 2, b),4 => State(id = 4, pdfindex = 1, a),5 => State(id = 5, pdfindex = 2, b),2 => State(id = 2, pdfindex = 3, c),3 => State(id = 3, pdfindex = 2, b),initstateid => State(id = initstateid, pdfindex = nothing, initstateid),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),8 => State(id = 8, pdfindex = 1, a),6 => State(id = 6, pdfindex = 4, d),1 => State(id = 1, pdfindex = 1, a)…))" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "7\n", + "\n", + "7:b\n", + "\n", + "\n", + "\n", + "5\n", + "\n", + "5:b\n", + "\n", + "\n", + "\n", + "7->5\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "4\n", + "\n", + "4:a\n", + "\n", + "\n", + "\n", + "2\n", + "\n", + "2:c\n", + "\n", + "\n", + "\n", + "4->2\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "finalstateid\n", + "\n", + "\n", + "\n", + "\n", + "5->finalstateid\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "3\n", + "\n", + "3:b\n", + "\n", + "\n", + "\n", + "2->3\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "3->5\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "1\n", + "\n", + "1:a\n", + "\n", + "\n", + "\n", + "3->1\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "initstateid\n", + "\n", + "\n", + "\n", + "\n", + "initstateid->4\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "8\n", + "\n", + "8:a\n", + "\n", + "\n", + "\n", + "initstateid->8\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "6\n", + "\n", + "6:d\n", + "\n", + "\n", + "\n", + "8->6\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "6->7\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "1->finalstateid\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(7 => State(id = 7, pdfindex = 2, b),4 => State(id = 4, pdfindex = 1, a),5 => State(id = 5, pdfindex = 2, b),2 => State(id = 2, pdfindex = 3, c),3 => State(id = 3, pdfindex = 2, b),initstateid => State(id = initstateid, pdfindex = nothing, initstateid),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),8 => State(id = 8, pdfindex = 1, a),6 => State(id = 6, pdfindex = 4, d),1 => State(id = 1, pdfindex = 1, a)…))" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "7\n", + "\n", + "7:b\n", + "\n", + "\n", + "\n", + "5\n", + "\n", + "5:b\n", + "\n", + "\n", + "\n", + "7->5\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "4\n", + "\n", + "4:a\n", + "\n", + "\n", + "\n", + "2\n", + "\n", + "2:c\n", + "\n", + "\n", + "\n", + "4->2\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "6\n", + "\n", + "6:d\n", + "\n", + "\n", + "\n", + "4->6\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "1\n", + "\n", + "1:a\n", + "\n", + "\n", + "\n", + "finalstateid\n", + "\n", + "\n", + "\n", + "\n", + "1->finalstateid\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "3\n", + "\n", + "3:b\n", + "\n", + "\n", + "\n", + "2->3\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "3->1\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "3->5\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "initstateid\n", + "\n", + "\n", + "\n", + "\n", + "initstateid->4\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "5->finalstateid\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "6->7\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(7 => State(id = 7, pdfindex = 2, b),4 => State(id = 4, pdfindex = 1, a),1 => State(id = 1, pdfindex = 1, a),2 => State(id = 2, pdfindex = 3, c),3 => State(id = 3, pdfindex = 2, b),initstateid => State(id = initstateid, pdfindex = nothing, initstateid),5 => State(id = 5, pdfindex = 2, b),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),6 => State(id = 6, pdfindex = 4, d)))" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g |> minimize" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "7\n", + "\n", + "7:b\n", + "\n", + "\n", + "\n", + "5\n", + "\n", + "5:b\n", + "\n", + "\n", + "\n", + "7->5\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "4\n", + "\n", + "4:a\n", "\n", "\n", "\n", "2\n", - "\n", - "c\n", + "\n", + "2:c\n", "\n", - "\n", + "\n", "\n", - "2->7\n", - "\n", - "\n", - "0.0\n", + "4->2\n", + "\n", + "\n", + "-0.693\n", "\n", - "\n", + "\n", + "\n", + "6\n", + "\n", + "6:d\n", + "\n", + "\n", + "\n", + "4->6\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", "\n", + "3\n", + "\n", + "3:b\n", + "\n", + "\n", + "\n", + "2->3\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "3->5\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "1\n", + "\n", + "1:a\n", + "\n", + "\n", + "\n", + "3->1\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "initstateid\n", + "\n", + "\n", + "\n", + "\n", + "initstateid->4\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", "finalstateid\n", - "\n", + "\n", "\n", "\n", - "\n", + "\n", "5->finalstateid\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "6->7\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "1->finalstateid\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(7 => State(id = 7, pdfindex = 2, b),4 => State(id = 4, pdfindex = 1, a),2 => State(id = 2, pdfindex = 3, c),3 => State(id = 3, pdfindex = 2, b),initstateid => State(id = initstateid, pdfindex = nothing, initstateid),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),5 => State(id = 5, pdfindex = 2, b),6 => State(id = 6, pdfindex = 4, d),1 => State(id = 1, pdfindex = 1, a)))" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g = g |> minimize |> determinize |> weightnormalize" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "7\n", + "\n", + "7:b\n", + "\n", + "\n", + "\n", + "7->7\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "5\n", + "\n", + "5:b\n", + "\n", + "\n", + "\n", + "7->5\n", + "\n", + "\n", + "-0.693\n", "\n", - "\n", + "\n", + "\n", + "4\n", + "\n", + "4:a\n", + "\n", + "\n", + "\n", + "4->4\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "2\n", + "\n", + "2:c\n", + "\n", + "\n", "\n", - "8->2\n", - "\n", - "\n", - "-0.693\n", + "4->2\n", + "\n", + "\n", + "-1.386\n", "\n", "\n", - "\n", + "\n", "6\n", - "\n", - "d\n", + "\n", + "6:d\n", "\n", - "\n", + "\n", + "\n", + "4->6\n", + "\n", + "\n", + "-1.386\n", + "\n", + "\n", "\n", - "8->6\n", - "\n", - "\n", - "-0.693\n", + "2->2\n", + "\n", + "\n", + "-0.693\n", "\n", - "\n", + "\n", + "\n", + "3\n", + "\n", + "3:b\n", + "\n", + "\n", "\n", + "2->3\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "3->3\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "3->5\n", + "\n", + "\n", + "-1.386\n", + "\n", + "\n", + "\n", + "1\n", + "\n", + "1:a\n", + "\n", + "\n", + "\n", + "3->1\n", + "\n", + "\n", + "-1.386\n", + "\n", + "\n", + "\n", + "initstateid\n", + "\n", + "\n", + "\n", + "\n", + "initstateid->4\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "finalstateid\n", + "\n", + "\n", + "\n", + "\n", + "5->finalstateid\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "5->5\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", "6->7\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "6->6\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "1->finalstateid\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "1->1\n", + "\n", + "\n", + "-0.693\n", "\n", "\n", "\n" ], "text/plain": [ - "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(initstateid => State(id = initstateid, pdfindex = nothing, initstateid),7 => State(id = 7, pdfindex = 2, b),2 => State(id = 2, pdfindex = 3, c),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),5 => State(id = 5, pdfindex = 1, a),8 => State(id = 8, pdfindex = 1, a),6 => State(id = 6, pdfindex = 4, d)))" + "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(7 => State(id = 7, pdfindex = 2, b),4 => State(id = 4, pdfindex = 1, a),2 => State(id = 2, pdfindex = 3, c),3 => State(id = 3, pdfindex = 2, b),initstateid => State(id = initstateid, pdfindex = nothing, initstateid),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),5 => State(id = 5, pdfindex = 2, b),6 => State(id = 6, pdfindex = 4, d),1 => State(id = 1, pdfindex = 1, a)))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "g = addselfloop(g) |> weightnormalize" + ] + }, + { + "cell_type": "code", + "execution_count": 217, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "state State(id = initstateid, pdfindex = nothing, initstateid) : Link[]\n", + "---\n", + "state State(id = 4, pdfindex = 1, a) : Link[Link{Float64}(State(id = initstateid, pdfindex = nothing, initstateid), 0.0), Link{Float64}(State(id = 4, pdfindex = 1, a), -0.6931471805599454)]\n", + "---\n", + "state State(id = 7, pdfindex = 2, b) : Link[Link{Float64}(State(id = 7, pdfindex = 2, b), -0.6931471805599453), Link{Float64}(State(id = 2, pdfindex = 3, c), -0.6931471805599453), Link{Float64}(State(id = 6, pdfindex = 4, d), -0.6931471805599453)]\n", + "---\n", + "state State(id = 2, pdfindex = 3, c) : Link[Link{Float64}(State(id = 4, pdfindex = 1, a), -1.3862943611198906), Link{Float64}(State(id = 2, pdfindex = 3, c), -0.6931471805599453)]\n", + "---\n", + "state State(id = finalstateid, pdfindex = nothing, finalstateid) : Link[Link{Float64}(State(id = 5, pdfindex = 1, a), -0.6931471805599453)]\n", + "---\n", + "state State(id = 5, pdfindex = 1, a) : Link[Link{Float64}(State(id = 7, pdfindex = 2, b), -0.6931471805599453), Link{Float64}(State(id = 5, pdfindex = 1, a), -0.6931471805599453)]\n", + "---\n", + "state State(id = 6, pdfindex = 4, d) : Link[Link{Float64}(State(id = 4, pdfindex = 1, a), -1.3862943611198906), Link{Float64}(State(id = 6, pdfindex = 4, d), -0.6931471805599453)]\n", + "---\n" + ] + } + ], + "source": [ + "for (sid, state) in g.states\n", + " println(\"state $(state) : $(state.incoming)\")\n", + " println(\"---\")\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4×400 Array{Float64,2}:\n", + " -7.23752 -5.44423 -5.76237 -4.26329 … -6.37517 -6.36586 -4.52336\n", + " -7.23752 -5.44423 -5.76237 -4.26329 -6.37517 -6.36586 -4.52336\n", + " -7.23752 -5.44423 -5.76237 -4.26329 -6.37517 -6.36586 -4.52336\n", + " -7.23752 -5.44423 -5.76237 -4.26329 -6.37517 -6.36586 -4.52336" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "D,N = 4,400 # number of distributions times number of frames\n", + "v = randn(N, 1) .- 6 \n", + "llh = repeat(v', D)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4×5 Array{Float64,2}:\n", + " 1.0 1.0 1.0 1.0 1.0\n", + " 1.0 1.0 1.0 1.0 1.0\n", + " 1.0 1.0 1.0 1.0 1.0\n", + " 1.0 1.0 1.0 1.0 1.0" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llh = ones(4,5)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4×5 Array{Float64,2}:\n", + " 1.0 1.0 0.5 0.5 0.0\n", + " 0.5 0.5 1.0 1.0 0.5\n", + " 0.0 0.0 0.5 0.5 1.0\n", + " 0.0 0.0 0.0 0.0 0.5" ] }, - "execution_count": 355, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "g = g |> minimize |> determinize |> weightnormalize" + "llh[:,1:2] .= [1, 0.5, 0, 0]\n", + "llh[:,3:4] .= [0.5, 1, 0.5, 0]\n", + "llh[:,5:5] .= [0, 0.5, 1, 0.5]\n", + "llh" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dict{String,Int64} with 4 entries:\n", + " \"c\" => 3\n", + " \"b\" => 2\n", + " \"a\" => 1\n", + " \"d\" => 4" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "emissionsmap" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2-element Array{Link,1}:\n", + " Link{Float64}(State(id = 2, pdfindex = 3, c), -0.6931471805599453)\n", + " Link{Float64}(State(id = 1, pdfindex = 1, a), -0.6931471805599453)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "2-element Array{Link,1}:\n", + " Link{Float64}(State(id = 2, pdfindex = 3, c), -0.6931471805599453)\n", + " Link{Float64}(State(id = 3, pdfindex = 2, b), -0.6931471805599453)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "s = g.states[2]\n", + "display(s.incoming)\n", + "display(s.outgoing)" ] }, { "cell_type": "code", - "execution_count": 356, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -837,263 +1692,295 @@ "\n", "\n", - "\n", - "\n", - "\n", - "\n", - "%3\n", - "\n", - "\n", - "\n", - "initstateid\n", - "\n", - "\n", - "\n", - "\n", - "8\n", - "\n", - "a\n", - "\n", - "\n", - "\n", - "initstateid->8\n", - "\n", - "\n", - "0.0\n", - "\n", + "\n", + "\n", + "\n", + "\n", "\n", - "\n", + "\n", "7\n", - "\n", - "b\n", + "\n", + "7:b\n", "\n", "\n", - "\n", + "\n", "7->7\n", - "\n", - "\n", - "-0.693\n", + "\n", + "\n", + "-0.693\n", "\n", "\n", - "\n", + "\n", "5\n", - "\n", - "a\n", + "\n", + "5:b\n", "\n", "\n", - "\n", + "\n", "7->5\n", - "\n", - "\n", - "-0.693\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "4\n", + "\n", + "4:a\n", + "\n", + "\n", + "\n", + "4->4\n", + "\n", + "\n", + "-0.693\n", "\n", "\n", "\n", "2\n", - "\n", - "c\n", + "\n", + "2:c\n", "\n", - "\n", + "\n", "\n", - "2->7\n", - "\n", - "\n", - "-0.693\n", + "4->2\n", + "\n", + "\n", + "-1.386\n", "\n", - "\n", + "\n", + "\n", + "6\n", + "\n", + "6:d\n", + "\n", + "\n", "\n", + "4->6\n", + "\n", + "\n", + "-1.386\n", + "\n", + "\n", + "\n", "2->2\n", - "\n", - "\n", - "-0.693\n", + "\n", + "\n", + "-0.693\n", "\n", - "\n", + "\n", "\n", - "finalstateid\n", - "\n", + "3\n", + "\n", + "3:b\n", "\n", - "\n", + "\n", "\n", - "5->finalstateid\n", - "\n", - "\n", - "-0.693\n", + "2->3\n", + "\n", + "\n", + "-0.693\n", "\n", - "\n", - "\n", - "5->5\n", - "\n", - "\n", - "-0.693\n", + "\n", + "\n", + "3->3\n", + "\n", + "\n", + "-0.693\n", "\n", - "\n", + "\n", + "\n", + "3->5\n", + "\n", + "\n", + "-1.386\n", + "\n", + "\n", + "\n", + "1\n", + "\n", + "1:a\n", + "\n", + "\n", "\n", - "8->2\n", - "\n", - "\n", - "-1.386\n", + "3->1\n", + "\n", + "\n", + "-1.386\n", "\n", - "\n", - "\n", - "8->8\n", - "\n", - "\n", - "-0.693\n", + "\n", + "\n", + "initstateid\n", + "\n", "\n", - "\n", - "\n", - "6\n", - "\n", - "d\n", + "\n", + "\n", + "initstateid->4\n", + "\n", + "\n", + "0.0\n", "\n", - "\n", - "\n", - "8->6\n", - "\n", - "\n", - "-1.386\n", + "\n", + "\n", + "finalstateid\n", + "\n", "\n", - "\n", + "\n", + "\n", + "5->finalstateid\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", "\n", + "5->5\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", "6->7\n", - "\n", - "\n", - "-0.693\n", + "\n", + "\n", + "-0.693\n", "\n", "\n", - "\n", + "\n", "6->6\n", - "\n", - "\n", - "-0.693\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "1->finalstateid\n", + "\n", + "\n", + "-0.693\n", + "\n", + "\n", + "\n", + "1->1\n", + "\n", + "\n", + "-0.693\n", "\n", "\n", "\n" ], "text/plain": [ - "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(initstateid => State(id = initstateid, pdfindex = nothing, initstateid),7 => State(id = 7, pdfindex = 2, b),2 => State(id = 2, pdfindex = 3, c),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),5 => State(id = 5, pdfindex = 1, a),8 => State(id = 8, pdfindex = 1, a),6 => State(id = 6, pdfindex = 4, d)))" + "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(7 => State(id = 7, pdfindex = 2, b),4 => State(id = 4, pdfindex = 1, a),2 => State(id = 2, pdfindex = 3, c),3 => State(id = 3, pdfindex = 2, b),initstateid => State(id = initstateid, pdfindex = nothing, initstateid),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),5 => State(id = 5, pdfindex = 2, b),6 => State(id = 6, pdfindex = 4, d),1 => State(id = 1, pdfindex = 1, a)))" ] }, - "execution_count": 356, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "g = addselfloop(g) |> weightnormalize" - ] - }, - { - "cell_type": "code", - "execution_count": 403, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "state State(id = initstateid, pdfindex = nothing, initstateid) : Link[]\n", - "---\n", - "state State(id = 7, pdfindex = 2, b) : Link[Link{Float64}(State(id = 7, pdfindex = 2, b), -0.6931471805599453), Link{Float64}(State(id = 2, pdfindex = 3, c), -0.6931471805599453), Link{Float64}(State(id = 6, pdfindex = 4, d), -0.6931471805599453)]\n", - "---\n", - "state State(id = 2, pdfindex = 3, c) : Link[Link{Float64}(State(id = 2, pdfindex = 3, c), -0.6931471805599453), Link{Float64}(State(id = 8, pdfindex = 1, a), -1.3862943611198906)]\n", - "---\n", - "state State(id = finalstateid, pdfindex = nothing, finalstateid) : Link[Link{Float64}(State(id = 5, pdfindex = 1, a), -0.6931471805599453)]\n", - "---\n", - "state State(id = 5, pdfindex = 1, a) : Link[Link{Float64}(State(id = 7, pdfindex = 2, b), -0.6931471805599453), Link{Float64}(State(id = 5, pdfindex = 1, a), -0.6931471805599453)]\n", - "---\n", - "state State(id = 8, pdfindex = 1, a) : Link[Link{Float64}(State(id = initstateid, pdfindex = nothing, initstateid), 0.0), Link{Float64}(State(id = 8, pdfindex = 1, a), -0.6931471805599454)]\n", - "---\n", - "state State(id = 6, pdfindex = 4, d) : Link[Link{Float64}(State(id = 8, pdfindex = 1, a), -1.3862943611198906), Link{Float64}(State(id = 6, pdfindex = 4, d), -0.6931471805599453)]\n", - "---\n" - ] - } - ], - "source": [ - "for (sid, state) in g.states\n", - " println(\"state $(state) : $(state.incoming)\")\n", - " println(\"---\")\n", - "end" + "g" ] }, { "cell_type": "code", - "execution_count": 391, + "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "4×4 Array{Float64,2}:\n", - " 0.393626 -0.831897 0.130839 -1.60193\n", - " 1.01238 -0.16196 -0.0855048 1.03511\n", - " -0.580188 -0.621978 0.0923155 1.18269\n", - " 0.397424 -0.966615 0.464552 0.375744" + "[n = 1] \ta\t4:a = 1.000 \n", + "[n = 2] \ta\t4:a = 1.307 \t2:c = 0.614 \t6:d = 0.614 \n", + "[n = 3] \tc\t4:a = 1.614 \t3:b = 0.921 \t7:b = 0.921 \t2:c = 2.234 \t6:d = 2.234 \n", + "[n = 4] \tc\t1:a = 0.534 \t4:a = 1.921 \t3:b = 3.089 \t7:b = 3.089 \t5:b = 2.086 \t2:c = 3.635 \t6:d = 3.635 \n", + "[n = 5] \tb\t1:a = 3.759 \t4:a = 2.227 \t3:b = 4.888 \t7:b = 4.888 \t5:b = 5.310 \t2:c = 4.974 \t6:d = 4.974 \n" ] }, - "execution_count": 391, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "llh = randn(4, 4)" + "α = αrecursion(g, llh; pruning=nopruning) # the lower the more pruning \n", + "α" ] }, { "cell_type": "code", - "execution_count": 395, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "4×4 Array{Float64,2}:\n", - " 0.393626 -1.13142 -1.69373 -3.39\n", - " -Inf -Inf -1.89227 1.1712\n", - " -Inf -1.61465 -1.5794 -0.191604\n", - " -Inf -1.95928 -1.15086 -1.11546" + "[n = 1] \tc\t1:a = 0.228 \t4:a = -1.466 \t3:b = 0.558 \t5:b = -0.466 \t2:c = 0.727 \n", + "[n = 2] \tb\t1:a = -0.992 \t4:a = -1.773 \t3:b = 0.469 \t5:b = -0.273 \t2:c = 0.408 \n", + "[n = 3] \tb\t4:a = -1.579 \t3:b = -0.183 \t5:b = -0.579 \t2:c = -0.799 \n", + "[n = 4] \tb\t4:a = -1.386 \t3:b = -1.105 \t5:b = -0.886 \n", + "[n = 5] \tb\t4:a = -0.693 \t5:b = -0.693 \n" ] }, - "execution_count": 395, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "αrecursion(g, llh) # the lower the more pruning " + "β = βrecursion(g, llh; pruning=nopruning) # the lower the more pruning \n", + "β" ] }, { "cell_type": "code", - "execution_count": 396, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "4×4 Array{Float64,2}:\n", - " 1.0 0.540901 0.549474 1.0\n", - " 0.0 0.0 0.450526 0.0\n", - " 0.0 0.268718 0.0 0.0\n", - " 0.0 0.190381 0.0 0.0" + "4.798578778287657" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Same" + ] + }, + { + "data": { + "text/plain": [ + "[n = 1] \ta\t1:a = 1.228 \n", + "[n = 2] \tc\t1:a = 0.315 \t2:c = 0.715 \n", + "[n = 3] \tc\t3:b = 0.431 \t2:c = 1.017 \n", + "[n = 4] \tb\t4:a = -1.659 \t3:b = 2.122 \t5:b = -0.659 \n", + "[n = 5] \tb\t4:a = 1.206 \t5:b = 2.206 \n" ] }, - "execution_count": 396, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "lnαβ = αβrecursion(g, llh) # the lower the more pruning \n", - "exp.(lnαβ)" + "lnαβ, tll = αβrecursion(g, llh) # the lower the more pruning \n", + "display(tll)\n", + "lnαβ" ] }, { "cell_type": "code", - "execution_count": 404, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -1101,16 +1988,25 @@ "output_type": "stream", "text": [ "State[State(id = 5, pdfindex = 1, a)]\n", - "[-4.083151668811988]\n", + "[2.4496921953675406]\n", "-----\n", "State[State(id = 5, pdfindex = 1, a), State(id = 7, pdfindex = 2, b)]\n", - "[-4.681946269295651, -4.880490209299517]\n", + "[0.6804271107052051, 2.2143455711413766]\n", "-----\n", - "State[State(id = 5, pdfindex = 1, a), State(id = 7, pdfindex = 2, b)]\n", - "[-4.681946269295651, -Inf]\n", + "State[State(id = 6, pdfindex = 4, d), State(id = 2, pdfindex = 3, c), State(id = 7, pdfindex = 2, b)]\n", + "[1.065389135159319, -2.4302709607893522, -1.709925938893071]\n", "-----\n", - "State[State(id = 5, pdfindex = 1, a), State(id = 7, pdfindex = 2, b)]\n", - "[-4.68194626929565, -Inf]\n", + "State[State(id = 6, pdfindex = 4, d), State(id = 8, pdfindex = 1, a)]\n", + "[1.1020871927699054, -4.924617799721492]\n", + "-----\n", + "State[State(id = 6, pdfindex = 4, d), State(id = 8, pdfindex = 1, a)]\n", + "[0.08948705438494042, -4.1541803992854796]\n", + "-----\n", + "State[State(id = 6, pdfindex = 4, d), State(id = 8, pdfindex = 1, a)]\n", + "[-2.103903668577119, -0.5717629519612493]\n", + "-----\n", + "State[State(id = 8, pdfindex = 1, a)]\n", + "[-0.5717629519612493]\n", "-----\n" ] }, @@ -1120,91 +2016,129 @@ "\n", "\n", - "\n", - "\n", - "\n", + "\n", + "\n", "\n", - "%3\n", - "\n", + "\n", "\n", "\n", "initstateid\n", "\n", "\n", "\n", - "\n", + "\n", "1\n", - "\n", - "a\n", + "\n", + "a\n", "\n", "\n", "\n", "initstateid->1\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", "\n", "4\n", - "\n", - "a\n", + "\n", + "d\n", + "\n", + "\n", + "\n", + "5\n", + "\n", + "d\n", + "\n", + "\n", + "\n", + "4->5\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "7\n", + "\n", + "a\n", "\n", "\n", - "\n", + "\n", "finalstateid\n", - "\n", + "\n", "\n", - "\n", - "\n", - "4->finalstateid\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "7->finalstateid\n", + "\n", + "\n", + "0.0\n", "\n", "\n", - "\n", + "\n", "2\n", - "\n", - "a\n", + "\n", + "a\n", "\n", "\n", - "\n", + "\n", "3\n", - "\n", - "a\n", + "\n", + "d\n", "\n", "\n", - "\n", + "\n", "2->3\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", - "\n", + "\n", "3->4\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "6\n", + "\n", + "b\n", + "\n", + "\n", + "\n", + "5->6\n", + "\n", + "\n", + "0.0\n", + "\n", + "\n", + "\n", + "6->7\n", + "\n", + "\n", + "0.0\n", "\n", "\n", - "\n", + "\n", "1->2\n", - "\n", - "\n", - "0.0\n", + "\n", + "\n", + "0.0\n", "\n", "\n", "\n" ], "text/plain": [ - "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(initstateid => State(id = initstateid, pdfindex = nothing, initstateid),4 => State(id = 4, pdfindex = 1, a),2 => State(id = 2, pdfindex = 1, a),3 => State(id = 3, pdfindex = 1, a),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),1 => State(id = 1, pdfindex = 1, a)))" + "Graph(Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State}(initstateid => State(id = initstateid, pdfindex = nothing, initstateid),4 => State(id = 4, pdfindex = 4, d),7 => State(id = 7, pdfindex = 1, a),2 => State(id = 2, pdfindex = 1, a),3 => State(id = 3, pdfindex = 4, d),finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid),5 => State(id = 5, pdfindex = 4, d),6 => State(id = 6, pdfindex = 2, b),1 => State(id = 1, pdfindex = 1, a)))" ] }, - "execution_count": 404, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -1215,89 +2149,23 @@ }, { "cell_type": "code", - "execution_count": 375, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Dict{Union{Missing, HiddenMarkovModel.FinalStateID, HiddenMarkovModel.InitStateID, Int64},State} with 7 entries:\n", - " initstateid => State(id = initstateid, pdfindex = nothing, initstateid)\n", - " 7 => State(id = 7, pdfindex = 2, b)\n", - " 2 => State(id = 2, pdfindex = 3, c)\n", - " finalstateid => State(id = finalstateid, pdfindex = nothing, finalstateid)\n", - " 5 => State(id = 5, pdfindex = 1, a)\n", - " 8 => State(id = 8, pdfindex = 1, a)\n", - " 6 => State(id = 6, pdfindex = 4, d)" - ] - }, - "execution_count": 375, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "g.states" - ] - }, - { - "cell_type": "code", - "execution_count": 108, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4×3 Array{Float64,2}:\n", - " 1.0 0.0 0.0\n", - " 0.0 1.0 0.0\n", - " 0.0 0.0 0.466906\n", - " 0.0 0.0 0.533094" - ] - }, - "execution_count": 108, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "lnαβ = αβrecursion(g, llh, pruning = 2.1) # the lower the more pruning \n", - "exp.(lnαβ)" - ] - }, - { - "cell_type": "code", - "execution_count": 38, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "((3, 1), 3)" - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "v = [(1, 3), (2, 2), (3, 1)]\n", - "findmax(v)" - ] + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Julia 1.4.1", + "display_name": "Julia 1.3.1", "language": "julia", - "name": "julia-1.4" + "name": "julia-1.3" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", - "version": "1.4.1" + "version": "1.3.1" } }, "nbformat": 4, diff --git a/src/HiddenMarkovModel.jl b/src/HiddenMarkovModel.jl index c43c635..e6276d7 100644 --- a/src/HiddenMarkovModel.jl +++ b/src/HiddenMarkovModel.jl @@ -232,7 +232,8 @@ function Base.show(io, ::MIME"image/svg+xml", g::AbstractGraph) for state in states(g) shape = isemitting(state) ? "circle" : "point" - write(dotfile, "$(id(state)) [ shape=\"$(shape)\", label=\"$(name(state))\" ];\n") + label = "$(id(state)):$(name(state))" + write(dotfile, "$(id(state)) [ shape=\"$(shape)\", label=\"$(label)\" ];\n") end for arc in arcs(g) src, dest, weight = id(arc[1]), id(arc[2]), round(arc[3], digits = 3) @@ -260,10 +261,31 @@ export nopruning include("graph.jl") +####################################################################### +# Pretty display the sparse matrix (i.e. from αβrecursion). + +import Printf:@sprintf +function Base.show(io::IO, ::MIME"text/plain", a::Array{Dict{State, T},1}) where T <: AbstractFloat + for n in 1:length(a) + write(io, "[n = $n] \t") + max = foldl(((sa,wa), (s,w)) -> wa < w ? (s,w) : (sa,wa), a[n]; init=first(a[n])) + write(io, first(max) |> name) + for (s, w) in sort(a[n]; by=x->name(x)) + write(io, "\t$(id(s)):$(name(s)) = $(@sprintf("%.3f", w)) ") + end + write(io, "\n") + end +end + ####################################################################### # Major algorithms include("algorithms.jl") +####################################################################### +# Other + +include("../src/misc.jl") + end diff --git a/src/algorithms.jl b/src/algorithms.jl index 128b25f..a08fdee 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -24,7 +24,7 @@ end function (pruning::ThresholdPruning)(candidates::Dict{State, T}) where T <: AbstractFloat maxval = maximum(p -> p.second, candidates) - return filter(p -> maxval - p.second ≤ pruning.Δ, candidates) + filter!(p -> maxval - p.second ≤ pruning.Δ, candidates) end @@ -43,30 +43,22 @@ Forward step of the Baum-Welch algorithm in the log-domain. """ function αrecursion(g::AbstractGraph, llh::Matrix{T}; pruning::Union{Real, NoPruning} = nopruning) where T <: AbstractFloat - pruning = pruning ≠ nopruning ? ThresholdPruning(pruning) : pruning - α = Matrix{T}(undef, size(llh)) - fill!(α, T(-Inf)) + pruning! = pruning ≠ nopruning ? ThresholdPruning(pruning) : pruning activestates = Dict{State, T}(initstate(g) => T(0.0)) - newstates = Dict{State, T}() + α = Vector{Dict{State, T}}() + for n in 1:size(llh, 2) - for state_weightpath in activestates - state, weightpath = state_weightpath - for nstate_linkweight in emittingstates(forward, state) - nstate, linkweight = nstate_linkweight + push!(α, Dict{State,T}()) + for (state, weightpath) in activestates + for (nstate, linkweight) in emittingstates(forward, state) nweightpath = weightpath + linkweight - newstates[nstate] = llh[nstate.pdfindex, n] + logaddexp(get(newstates, nstate, T(-Inf)), nweightpath) + α[n][nstate] = llh[pdfindex(nstate), n] + logaddexp(get(α[n], nstate, T(-Inf)), nweightpath) end end - for nstate_nweightpath in newstates - nstate, nweightpath = nstate_nweightpath - α[nstate.pdfindex, n] = logaddexp(α[nstate.pdfindex, n], nweightpath) - end - empty!(activestates) - merge!(activestates, pruning(newstates)) - empty!(newstates) + merge!(activestates, pruning!(α[n])) end α end @@ -78,33 +70,25 @@ Backward step of the Baum-Welch algorithm in the log domain. """ function βrecursion(g::AbstractGraph, llh::Matrix{T}; pruning::Union{Real, NoPruning} = nopruning) where T <: AbstractFloat - pruning = pruning ≠ nopruning ? ThresholdPruning(pruning) : pruning - β = Matrix{eltype(llh)}(undef, size(llh)) - fill!(β, T(-Inf)) - - activestates = Dict{State, T}(finalstate(g) => T(0.0)) - newstates = Dict{State, T}() - - for n in size(llh, 2):-1:1 - for state_weightpath in activestates - state, weightpath = state_weightpath - emitting = isemitting(state) - prev_llh = emitting ? llh[state.pdfindex, n+1] : T(0.0) - for nstate_linkweight in emittingstates(backward, state) - nstate, linkweight = nstate_linkweight + pruning! = pruning ≠ nopruning ? ThresholdPruning(pruning) : pruning + + activestates = Dict{State, T}() + β = Vector{Dict{State, T}}() + push!(β, Dict(s => T(0.0) for (s, w) in emittingstates(backward, finalstate(g)))) + + for n in size(llh, 2)-1:-1:1 + # Update the active tokens + empty!(activestates) + merge!(activestates, pruning!(β[1])) + + pushfirst!(β, Dict{State,T}()) + for (state, weightpath) in activestates + prev_llh = llh[pdfindex(state), n+1] + for (nstate, linkweight) in emittingstates(backward, state) nweightpath = weightpath + linkweight + prev_llh - newstates[nstate] = logaddexp(get(newstates, nstate, T(-Inf)), nweightpath) + β[1][nstate] = logaddexp(get(β[1], nstate, T(-Inf)), nweightpath) end end - - for nstate_nweightpath in newstates - nstate, nweightpath = nstate_nweightpath - β[nstate.pdfindex, n] = logaddexp(β[nstate.pdfindex, n], nweightpath) - end - - empty!(activestates) - merge!(activestates, pruning(newstates)) - empty!(newstates) end β end @@ -118,37 +102,57 @@ function αβrecursion(g::AbstractGraph, llh::Matrix{T}; pruning::Union{Real, NoPruning} = nopruning) where T <: AbstractFloat α = αrecursion(g, llh, pruning = pruning) β = βrecursion(g, llh, pruning = pruning) - α + β .- logsumexp(α + β, dims = 1) + + γ = Vector{Dict{State,T}}() + + for n in 1:size(llh, 2) + push!(γ, Dict{State, T}()) + for s in union(keys(α[n]), keys(β[n])) + a = get(α[n], s, T(-Inf)) + b = get(β[n], s, T(-Inf)) + γ[n][s] = a + b + end + filter!(p -> isfinite(p.second), γ[n]) + sum = logsumexp(values(γ[n])) + + for s in keys(γ[n]) + γ[n][s] -= sum + end + end + + # Total Log Likelihood + fs = foldl((acc, (s, w)) -> push!(acc, s), emittingstates(backward, finalstate(g)); init=[]) + ttl = filter(s -> s[1] in fs, α[end]) |> values |> sum + + γ, ttl end +# function total_llh() + ####################################################################### # Viterbi algorithm (find the best path) export viterbi -function maxβrecursion(g::AbstractGraph, llh::Matrix{T}, α::Matrix{T}) where T <: AbstractFloat +function maxβrecursion(g::AbstractGraph, llh::Matrix{T}, α::Vector{Dict{State,T}}) where T <: AbstractFloat bestseq = Vector{State}() activestates = Dict{State, T}(finalstate(g) => T(0.0)) newstates = Dict{State, T}() for n in size(llh, 2):-1:1 - for state_weightpath in activestates - state, weightpath = state_weightpath + for (state, weightpath) in activestates emitting = isemitting(state) prev_llh = emitting ? llh[state.pdfindex, n+1] : T(0.0) - for nstate_linkweight in emittingstates(backward, state) - nstate, linkweight = nstate_linkweight + for (nstate, linkweight) in emittingstates(backward, state) nweightpath = weightpath + linkweight + prev_llh newstates[nstate] = logaddexp(get(newstates, nstate, T(-Inf)), nweightpath) end end - hypscores = Vector{T}(undef, length(newstates)) hypstates = Vector{State}(undef, length(newstates)) - for (i, nstate_nweightpath) in enumerate(newstates) - nstate, nweightpath = nstate_nweightpath - hypscores[i] = α[nstate.pdfindex, n] + nweightpath + for (i, (nstate, nweightpath)) in enumerate(newstates) + hypscores[i] = get(α[n], nstate, T(-Inf)) + nweightpath hypstates[i] = nstate end println(hypstates) @@ -229,7 +233,7 @@ end export weightnormalize """ - normalize(graph) + weightnormalize(graph) Update the weights of the graph such that the exponentiation of the weight of all the outoing arc from a state sum up to one. @@ -287,3 +291,138 @@ function addselfloop(graph::Graph; loopprob = 0.5) g end +####################################################################### +# Union two graphs into one + +import Base: union + +""" + union(g1::AbstractGraph, g2::AbstractGraph) +""" +function Base.union(g1::AbstractGraph, g2::AbstractGraph) + g = Graph() + statecount = 0 + old2new = Dict{AbstractState, AbstractState}( + initstate(g1) => initstate(g), + finalstate(g1) => finalstate(g), + ) + for (i, state) in enumerate(states(g1)) + if id(state) ≠ finalstateid && id(state) ≠ initstateid + statecount += 1 + old2new[state] = addstate!(g, State(statecount, pdfindex(state), name(state))) + end + end + + for state in states(g1) + src = old2new[state] + for link in children(state) + link!(src, old2new[link.dest], link.weight) + end + end + + old2new = Dict{AbstractState, AbstractState}( + initstate(g2) => initstate(g), + finalstate(g2) => finalstate(g), + ) + for (i, state) in enumerate(states(g2)) + if id(state) ≠ finalstateid && id(state) ≠ initstateid + statecount += 1 + old2new[state] = addstate!(g, State(statecount, pdfindex(state), name(state))) + end + end + + for state in states(g2) + src = old2new[state] + for link in children(state) + link!(src, old2new[link.dest], link.weight) + end + end + + g |> determinize |> weightnormalize +end + +####################################################################### +# FSM minimization + +export minimize + +function leftminimize!(g::Graph, state::AbstractState) + leaves = Dict() + for link in children(state) + leaf, weight = get(leaves, pdfindex(link.dest), ([], -Inf)) + push!(leaf, link.dest) + leaves[pdfindex(link.dest)] = (leaf, logaddexp(weight, link.weight)) + end + + empty!(state.outgoing) + for (nextstates, weight) in values(leaves) + mergedstate = nextstates[1] + filter!(l -> l.dest ≠ state, mergedstate.incoming) + + # Now we removed all the extra states of the graph. + links = vcat([next.outgoing for next in nextstates[2:end]]...) + for link in links + link!(mergedstate, link.dest, link.weight) + end + + for old in nextstates[2:end] + for link in children(old) + filter!(l -> l.dest ≠ old, link.dest.incoming) + end + delete!(g.states, id(old)) + end + + # Reconnect the previous state with the merged state + link!(state, mergedstate, weight) + + # Minimize the subgraph. + leftminimize!(g, mergedstate) + end + g +end + +function rightminimize!(g::Graph, state::AbstractState) + leaves = Dict() + for link in parents(state) + leaf, weight = get(leaves, pdfindex(link.dest), ([], -Inf)) + push!(leaf, link.dest) + leaves[pdfindex(link.dest)] = (leaf, logaddexp(weight, link.weight)) + end + + empty!(state.incoming) + for (nextstates, weight) in values(leaves) + mergedstate = nextstates[1] + filter!(l -> l.dest ≠ state, mergedstate.outgoing) + + # Now we removed all the extra states of the graph. + links = vcat([next.incoming for next in nextstates[2:end]]...) + for link in links + #link!(mergedstate, link.dest, link.weight) + link!(link.dest, mergedstate, link.weight) + end + + for old in nextstates[2:end] + for link in parents(old) + filter!(l -> l.dest ≠ old, link.dest.outgoing) + end + delete!(g.states, id(old)) + end + + # Reconnect the previous state with the merged state + link!(mergedstate, state, weight) + + # Minimize the subgraph. + rightminimize!(g, mergedstate) + end + g +end + +""" + minimize(g::Graph) +""" +minimize(g::Graph) = begin + newg = deepcopy(g) + newg = leftminimize!(newg, initstate(newg)) + rightminimize!(newg, finalstate(newg)) +end + diff --git a/src/graph.jl b/src/graph.jl index e0d674e..d195b90 100644 --- a/src/graph.jl +++ b/src/graph.jl @@ -138,4 +138,3 @@ function LinearGraph(sequence::AbstractArray{String}, link!(prevstate, finalstate(g), 0.) g end - diff --git a/src/misc.jl b/src/misc.jl new file mode 100644 index 0000000..1368596 --- /dev/null +++ b/src/misc.jl @@ -0,0 +1,16 @@ +import StatsFuns: logsumexp + +function logsumexp(X::AbstractArray{T}; dims=:) where {T<:Real} + u = reduce(max, X, dims=dims, init=oftype(log(zero(T)), -Inf)) + u isa AbstractArray || isfinite(u) || return float(u) + let u=u + if u isa AbstractArray + v = u .+ log.(sum(exp.(X .- u); dims=dims)) + i = .! isfinite.(v) + v[i] .= u[i] + v + else + u + log(sum(x -> exp(x-u), X)) + end + end +end \ No newline at end of file