Skip to content
Closed
3 changes: 2 additions & 1 deletion src/StructuralEquationModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ include("frontend/fit/summary.jl")
include("frontend/pretty_printing.jl")
# observed
include("observed/abstract.jl")
include("observed/covariance.jl")
include("observed/data.jl")
include("observed/covariance.jl")
include("observed/missing_pattern.jl")
include("observed/missing.jl")
include("observed/EM.jl")
# constructor
Expand Down
5 changes: 0 additions & 5 deletions src/additional_functions/helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,6 @@ function sparse_outer_mul!(C, A, B::Vector, ind) #computes A*S*B -> C, where ind
end
end

function cov_and_mean(rows; corrected = false)
obs_mean, obs_cov = StatsBase.mean_and_cov(reduce(hcat, rows), 2, corrected = corrected)
return obs_cov, vec(obs_mean)
end

# n²×(n(n+1)/2) matrix to transform a vector of lower
# triangular entries into a vectorized form of a n×n symmetric matrix,
# opposite of elimination_matrix()
Expand Down
71 changes: 15 additions & 56 deletions src/frontend/fit/fitmeasures/minus2ll.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,74 +31,33 @@ minus2ll(minimum::Number, obs, imp::Union{RAM, RAMSymbolic}, loss_ml::SemWLS) =
# compute likelihood for missing data - H0 -------------------------------------------------
# -2ll = (∑ log(2π)*(nᵢ + mᵢ)) + F*n
function minus2ll(minimum::Number, observed, imp::Union{RAM, RAMSymbolic}, loss_ml::SemFIML)
F = minimum
F *= nsamples(observed)
F += sum(log(2π) * observed.pattern_nsamples .* observed.pattern_nobs_vars)
F = minimum * nsamples(observed)
F += log(2π) * sum(pat -> nsamples(pat) * nmeasured_vars(pat), observed.patterns)
return F
end

# compute likelihood for missing data - H1 -------------------------------------------------
# -2ll = ∑ log(2π)*(nᵢ + mᵢ) + ln(Σᵢ) + (mᵢ - μᵢ)ᵀ Σᵢ⁻¹ (mᵢ - μᵢ)) + tr(SᵢΣᵢ)
function minus2ll(observed::SemObservedMissing)
if observed.em_model.fitted
minus2ll(
observed.em_model.μ,
observed.em_model.Σ,
nsamples(observed),
pattern_rows(observed),
observed.patterns,
observed.obs_mean,
observed.obs_cov,
observed.pattern_nsamples,
observed.pattern_nobs_vars,
)
else
em_mvn(observed)
minus2ll(
observed.em_model.μ,
observed.em_model.Σ,
nsamples(observed),
pattern_rows(observed),
observed.patterns,
observed.obs_mean,
observed.obs_cov,
observed.pattern_nsamples,
observed.pattern_nobs_vars,
)
end
end

function minus2ll(
μ,
Σ,
N,
rows,
patterns,
obs_mean,
obs_cov,
pattern_nsamples,
pattern_nobs_vars,
)
F = 0.0
# fit EM-based mean and cov if not yet fitted
# FIXME EM could be very computationally expensive
observed.em_model.fitted || em_mvn(observed)

for i in 1:length(rows)
nᵢ = pattern_nsamples[i]
# missing pattern
pattern = patterns[i]
# observed data
Sᵢ = obs_cov[i]
Σ = observed.em_model.Σ
μ = observed.em_model.μ

F = sum(observed.patterns) do pat
# implied covariance/mean
Σᵢ = Σ[pattern, pattern]
ld = logdet(Σᵢ)
Σᵢ⁻¹ = inv(cholesky(Σᵢ))
meandiffᵢ = obs_mean[i] - μ[pattern]
Σᵢ = Σ[pat.measured_mask, pat.measured_mask]
Σᵢ_chol = cholesky!(Σᵢ)
ld = logdet(Σᵢ_chol)
Σᵢ⁻¹ = LinearAlgebra.inv!(Σᵢ_chol)
meandiffᵢ = pat.measured_mean - μ[pat.measured_mask]

F += F_one_pattern(meandiffᵢ, Σᵢ⁻¹, Sᵢ, ld, nᵢ)
F_one_pattern(meandiffᵢ, Σᵢ⁻¹, pat.measured_cov, ld, nsamples(pat))
end

F += sum(log(2π) * pattern_nsamples .* pattern_nobs_vars)
#F *= N
F += log(2π) * sum(pat -> nsamples(pat) * nmeasured_vars(pat), observed.patterns)

return F
end
Expand Down
18 changes: 12 additions & 6 deletions src/frontend/specification/Sem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
############################################################################################

function Sem(;
specification = RAMMatrices,
observed::O = SemObservedData,
imply::I = RAM,
loss::L = SemML,
Expand All @@ -12,7 +13,7 @@ function Sem(;

set_field_type_kwargs!(kwdict, observed, imply, loss, O, I)

observed, imply, loss = get_fields!(kwdict, observed, imply, loss)
observed, imply, loss = get_fields!(kwdict, specification, observed, imply, loss)

sem = Sem(observed, imply, loss)

Expand Down Expand Up @@ -59,6 +60,7 @@ Returns the loss part of a model.
loss(model::AbstractSemSingle) = model.loss

function SemFiniteDiff(;
specification = RAMMatrices,
observed::O = SemObservedData,
imply::I = RAM,
loss::L = SemML,
Expand All @@ -68,7 +70,7 @@ function SemFiniteDiff(;

set_field_type_kwargs!(kwdict, observed, imply, loss, O, I)

observed, imply, loss = get_fields!(kwdict, observed, imply, loss)
observed, imply, loss = get_fields!(kwdict, specification, observed, imply, loss)

sem = SemFiniteDiff(observed, imply, loss)

Expand Down Expand Up @@ -96,23 +98,27 @@ function set_field_type_kwargs!(kwargs, observed, imply, loss, O, I)
end

# construct Sem fields
function get_fields!(kwargs, observed, imply, loss)
function get_fields!(kwargs, specification, observed, imply, loss)
if !isa(specification, SemSpecification)
specification = specification(; kwargs...)
end

# observed
if !isa(observed, SemObserved)
observed = observed(; kwargs...)
observed = observed(; specification, kwargs...)
end
kwargs[:observed] = observed

# imply
if !isa(imply, SemImply)
imply = imply(; kwargs...)
imply = imply(; specification, kwargs...)
end

kwargs[:imply] = imply
kwargs[:nparams] = nparams(imply)

# loss
loss = get_SemLoss(loss; kwargs...)
loss = get_SemLoss(loss; specification, kwargs...)
kwargs[:loss] = loss

return observed, imply, loss
Expand Down
76 changes: 36 additions & 40 deletions src/loss/ML/FIML.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,27 @@ end
### Constructors
############################################################################################

function SemFIML(; observed, specification, kwargs...)
inverses = broadcast(x -> zeros(x, x), pattern_nobs_vars(observed))
function SemFIML(; observed::SemObservedMissing, specification, kwargs...)
inverses =
[zeros(nmeasured_vars(pat), nmeasured_vars(pat)) for pat in observed.patterns]
choleskys = Array{Cholesky{Float64, Array{Float64, 2}}, 1}(undef, length(inverses))

n_patterns = size(pattern_rows(observed), 1)
n_patterns = length(observed.patterns)
logdets = zeros(n_patterns)

imp_mean = zeros.(pattern_nobs_vars(observed))
meandiff = zeros.(pattern_nobs_vars(observed))
imp_mean = [zeros(nmeasured_vars(pat)) for pat in observed.patterns]
meandiff = [zeros(nmeasured_vars(pat)) for pat in observed.patterns]

nobs_vars = nobserved_vars(observed)
imp_inv = zeros(nobs_vars, nobs_vars)
mult = similar.(inverses)

∇ind = vec(CartesianIndices(Array{Float64}(undef, nobs_vars, nobs_vars)))
∇ind =
[findall(x -> !(x[1] ∈ ind || x[2] ∈ ind), ∇ind) for ind in patterns_not(observed)]
# generate linear indicies of co-observed variable pairs for each pattern
Σ_linind = LinearIndices((nobs_vars, nobs_vars))
∇ind = map(observed.patterns) do pat
pat_vars = findall(pat.measured_mask)
vec(Σ_linind[pat_vars, pat_vars])
end

return SemFIML(
ExactHessian(),
Expand Down Expand Up @@ -104,10 +108,10 @@ function evaluate!(
prepare_SemFIML!(semfiml, model)

scale = inv(nsamples(observed(model)))
obs_rows = pattern_rows(observed(model))
isnothing(objective) || (objective = scale * F_FIML(obs_rows, semfiml, model, params))
isnothing(objective) ||
(objective = scale * F_FIML(observed(model), semfiml, model, params))
isnothing(gradient) ||
(∇F_FIML!(gradient, obs_rows, semfiml, model); gradient .*= scale)
(∇F_FIML!(gradient, observed(model), semfiml, model); gradient .*= scale)

return objective
end
Expand All @@ -131,16 +135,16 @@ function F_one_pattern(meandiff, inverse, obs_cov, logdet, N)
return F * N
end

function ∇F_one_pattern(μ_diff, Σ⁻¹, S, pattern, ∇ind, N, Jμ, JΣ, model)
function ∇F_one_pattern(μ_diff, Σ⁻¹, S, obs_mask, ∇ind, N, Jμ, JΣ, model)
diff⨉inv = μ_diff' * Σ⁻¹

if N > one(N)
JΣ[∇ind] .+= N * vec(Σ⁻¹ * (I - S * Σ⁻¹ - μ_diff * diff⨉inv))
@. Jμ[pattern] += (N * 2 * diff⨉inv)'
@. Jμ[obs_mask] += (N * 2 * diff⨉inv)'

else
JΣ[∇ind] .+= vec(Σ⁻¹ * (I - μ_diff * diff⨉inv))
@. Jμ[pattern] += (2 * diff⨉inv)'
@. Jμ[obs_mask] += (2 * diff⨉inv)'
end
end

Expand All @@ -163,32 +167,32 @@ function ∇F_fiml_outer!(G, JΣ, Jμ, imply, model, semfiml)
mul!(G, ∇μ', Jμ, -1, 1)
end

function F_FIML(rows, semfiml, model, params)
function F_FIML(observed::SemObservedMissing, semfiml, model, params)
F = zero(eltype(params))
for i in 1:size(rows, 1)
for (i, pat) in enumerate(observed.patterns)
F += F_one_pattern(
semfiml.meandiff[i],
semfiml.inverses[i],
obs_cov(observed(model))[i],
pat.measured_cov,
semfiml.logdets[i],
pattern_nsamples(observed(model))[i],
nsamples(pat),
)
end
return F
end

function ∇F_FIML!(G, rows, semfiml, model)
function ∇F_FIML!(G, observed::SemObservedMissing, semfiml, model)
Jμ = zeros(nobserved_vars(model))
JΣ = zeros(nobserved_vars(model)^2)

for i in 1:size(rows, 1)
for (i, pat) in enumerate(observed.patterns)
∇F_one_pattern(
semfiml.meandiff[i],
semfiml.inverses[i],
obs_cov(observed(model))[i],
patterns(observed(model))[i],
pat.measured_cov,
pat.measured_mask,
semfiml.∇ind[i],
pattern_nsamples(observed(model))[i],
nsamples(pat),
Jμ,
JΣ,
model,
Expand All @@ -202,29 +206,21 @@ function prepare_SemFIML!(semfiml, model)
batch_cholesky!(semfiml, model)
#batch_sym_inv_update!(semfiml, model)
batch_inv!(semfiml, model)
for i in 1:size(pattern_nsamples(observed(model)), 1)
semfiml.meandiff[i] .= obs_mean(observed(model))[i] - semfiml.imp_mean[i]
for (i, pat) in enumerate(observed(model).patterns)
semfiml.meandiff[i] .= pat.measured_mean .- semfiml.imp_mean[i]
end
end

function copy_per_pattern!(inverses, source_inverses, means, source_means, patterns)
@views for i in 1:size(patterns, 1)
inverses[i] .= source_inverses[patterns[i], patterns[i]]
end

@views for i in 1:size(patterns, 1)
means[i] .= source_means[patterns[i]]
function copy_per_pattern!(fiml::SemFIML, model::AbstractSem)
Σ = imply(model).Σ
μ = imply(model).μ
data = observed(model)
@inbounds @views for (i, pat) in enumerate(data.patterns)
fiml.inverses[i] .= Σ[pat.measured_mask, pat.measured_mask]
fiml.imp_mean[i] .= μ[pat.measured_mask]
end
end

copy_per_pattern!(semfiml, model::M where {M <: AbstractSem}) = copy_per_pattern!(
semfiml.inverses,
imply(model).Σ,
semfiml.imp_mean,
imply(model).μ,
patterns(observed(model)),
)

function batch_cholesky!(semfiml, model)
for i in 1:size(semfiml.inverses, 1)
semfiml.choleskys[i] = cholesky!(Symmetric(semfiml.inverses[i]))
Expand Down
34 changes: 20 additions & 14 deletions src/observed/EM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ function em_mvn(
𝔼xxᵀ_pre = zeros(nvars, nvars)

### precompute for full cases
if length(observed.patterns[1]) == nvars
for row in pattern_rows(observed)[1]
row = observed.data_rowwise[row]
fullpat = observed.patterns[1]
if nmissed_vars(fullpat) == 0
for row in eachrow(fullpat.data)
𝔼x_pre += row
𝔼xxᵀ_pre += row * row'
end
Expand Down Expand Up @@ -97,21 +97,27 @@ function em_mvn_Estep!(𝔼x, 𝔼xxᵀ, em_model, observed, 𝔼x_pre, 𝔼xx
Σ = em_model.Σ

# Compute the expected sufficient statistics
for i in 2:length(observed.pattern_nsamples)
for pat in observed.patterns
(nmissed_vars(pat) == 0) && continue # skip full cases

# observed and unobserved vars
u = observed.patterns_not[i]
o = observed.patterns[i]
u = pat.miss_mask
o = pat.measured_mask

# precompute for pattern
V = Σ[u, u] - Σ[u, o] * (Σ[o, o] \ Σ[o, u])
Σoo = Σ[o, o]
Σuo = Σ[u, o]
μu = μ[u]
μo = μ[o]

V = Σ[u, u] - Σuo * (Σoo \ Σ[o, u])

# loop trough data
for row in pattern_rows(observed)[i]
m = μ[u] + Σ[u, o] * (Σ[o, o] \ (observed.data_rowwise[row] - μ[o]))
for rowdata in eachrow(pat.data)
m = μu + Σuo * (Σoo \ (rowdata - μo))

𝔼xᵢ[u] = m
𝔼xᵢ[o] = observed.data_rowwise[row]
𝔼xᵢ[o] = rowdata
𝔼xxᵀᵢ[u, u] = 𝔼xᵢ[u] * 𝔼xᵢ[u]' + V
𝔼xxᵀᵢ[o, o] = 𝔼xᵢ[o] * 𝔼xᵢ[o]'
𝔼xxᵀᵢ[o, u] = 𝔼xᵢ[o] * 𝔼xᵢ[u]'
Expand Down Expand Up @@ -153,10 +159,10 @@ end

# use μ and Σ of full cases
function start_em_observed(observed::SemObservedMissing; kwargs...)
if (length(observed.patterns[1]) == nobserved_vars(observed)) &
(observed.pattern_nsamples[1] > 1)
μ = copy(observed.obs_mean[1])
Σ = copy(Symmetric(observed.obs_cov[1]))
fullpat = observed.patterns[1]
if (nmissed_vars(fullpat) == 0) && (nobserved_vars(fullpat) > 1)
μ = copy(fullpat.measured_mean)
Σ = copy(Symmetric(fullpat.measured_cov))
if !isposdef(Σ)
Σ = Matrix(Diagonal(Σ))
end
Expand Down
Loading
Loading