diff --git a/Project.toml b/Project.toml index 5d59f942..d2bb053d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "OnlineStats" uuid = "a15396b6-48d5-5d58-9928-6d29437db91e" -version = "1.6.0" +version = "1.6.1" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/stats/distributions.jl b/src/stats/distributions.jl index de847fb5..11ecd927 100644 --- a/src/stats/distributions.jl +++ b/src/stats/distributions.jl @@ -198,10 +198,14 @@ nobs(o::FitMvNormal) = nobs(o.cov) _fit!(o::FitMvNormal, y) = _fit!(o.cov, y) function value(o::FitMvNormal) c = cov(o.cov) - if isposdef(c) + if isposdef(c) || (iszero(c) && nobs(o) > 1) return mean(o.cov), c else return zeros(nvars(o)), Matrix(1.0I, nvars(o), nvars(o)) end end _merge!(o::FitMvNormal, o2::FitMvNormal) = _merge!(o.cov, o2.cov) + +Statistics.mean(o::FitMvNormal) = mean(o.cov) +Statistics.var(o::FitMvNormal) = var(o.cov) +Statistics.cov(o::FitMvNormal) = cov(o.cov) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index fe67cbdc..cc2dfb6e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -127,9 +127,14 @@ end end @testset "FitMvNormal" begin @test value(FitMvNormal(2)) == (zeros(2), Matrix(I, 2, 2)) - a, b = mergevals(FitMvNormal(2), OnlineStatsBase.eachrow([y y2]), OnlineStatsBase.eachrow([y2 y])) - @test a[1] ≈ b[1] - @test a[2] ≈ b[2] + a, b = mergestats(FitMvNormal(2), OnlineStatsBase.eachrow([y y2]), OnlineStatsBase.eachrow([y2 y])) + @test value(a)[1] ≈ value(b)[1] + @test value(a)[2] ≈ value(b)[2] + + @test all(mean(a) .≈ mean(ys)) + @test all(var(a) .≈ var(ys)) + @test cov(a)[1] ≈ cov(a)[4] ≈ var(ys) + @test cov(a)[2] ≈ cov(a)[3] end end #-----------------------------------------------------------------------# FastNode