-
Notifications
You must be signed in to change notification settings - Fork 1
/
ExpectationMaximization.jl
59 lines (47 loc) · 1.68 KB
/
ExpectationMaximization.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
module ExpectationMaximization
using ArgCheck
using Distributions
using Distributions: ArrayOfUnivariateDistribution, VectorOfUnivariateDistribution # for product distributions
using LogExpFunctions: logsumexp!, logsumexp
using StatsBase: weights
using Random # to add @kwdef
# Extended functions
import Distributions: fit_mle, params
export fit_mle, fit_mle!
# Abstract supertype declared for the EM algorithm variants of this package.
# NOTE(review): presumably the parent of the exported `ClassicEM` and `StochasticEM`,
# which are defined in the included files — confirm there.
abstract type AbstractEM end
# Utilities
# Number of observations held by a sample `y`.
# For a matrix, the size along the second dimension (number of columns) is returned.
function size_sample(y::AbstractMatrix)
    return size(y, 2)
end

# For a vector, every entry counts as one observation.
function size_sample(y::AbstractVector)
    return length(y)
end
# Column index of the largest entry of each row of `M`, collected row by row.
function argmaxrow(M)
    return collect(argmax(row) for row in eachrow(M))
end
"""
predict(mix::MixtureModel, y::AbstractVector; robust=false)
Evaluate the most likely category for each observations given a `MixtureModel`.
- `robust = true` will prevent the (log)likelihood to overflow to `-∞` or `∞`.
"""
function predict(mix::MixtureModel, y::AbstractVecOrMat; robust=false)
return argmaxrow(predict_proba(mix, y; robust=robust))
end
"""
predict_proba(mix::MixtureModel, y::AbstractVecOrMat; robust=false)
Evaluate the probability for each observations to belong to a category given a `MixtureModel`..
- `robust = true` will prevent the (log)likelihood to under(overflow)flow to `-∞` (or `∞`).
"""
function predict_proba(mix::MixtureModel, y::AbstractVecOrMat; robust=false)
# evaluate likelihood for each components k
dists = mix.components
α = probs(mix)
K = length(dists)
N = size_sample(y)
LL = zeros(N, K)
γ = similar(LL)
c = zeros(N)
E_step!(LL, c, γ, dists, α, y; robust=robust)
return γ
end
# Remaining source files of the package, evaluated inside this module in the order listed.
include("that_should_be_in_Distributions.jl")
include("fit_em.jl")
include("classic_em.jl")
include("stochastic_em.jl")
# Additional exported names. `predict_proba` and `predict` are defined in this file;
# `ClassicEM` and `StochasticEM` are presumably defined in the included files — confirm.
export ClassicEM, StochasticEM
export predict_proba, predict
end