Skip to content

Commit 4cbeba8

Browse files
authored
Merge pull request #1097 from johnczito/add_singular_wishart
Add singular branch of the Wishart
2 parents 0c6cd50 + 636c695 commit 4cbeba8

File tree

4 files changed

+197
-42
lines changed

4 files changed

+197
-42
lines changed

src/matrix/wishart.jl

Lines changed: 93 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,70 @@
1-
# Wishart distribution
2-
#
3-
# following the Wikipedia parameterization
4-
#
51
"""
62
Wishart(ν, S)
73
```julia
8-
ν::Real degrees of freedom (greater than p - 1)
4+
ν::Real degrees of freedom (whole number or a real number greater than p - 1)
95
S::AbstractPDMat p x p scale matrix
106
```
117
The [Wishart distribution](http://en.wikipedia.org/wiki/Wishart_distribution)
12-
generalizes the gamma distribution to ``p\\times p`` real, positive definite
13-
matrices ``\\mathbf{H}``. If ``\\mathbf{H}\\sim \\textrm{W}_p(\\nu,\\mathbf{S})``,
14-
then its probability density function is
8+
generalizes the gamma distribution to ``p\\times p`` real, positive semidefinite
9+
matrices ``\\mathbf{H}``.
10+
11+
If ``\\nu>p-1``, then ``\\mathbf{H}\\sim \\textrm{W}_p(\\nu, \\mathbf{S})``
12+
has rank ``p`` and its probability density function is
1513
1614
```math
1715
f(\\mathbf{H};\\nu,\\mathbf{S}) = \\frac{1}{2^{\\nu p/2} \\left|\\mathbf{S}\\right|^{\\nu/2} \\Gamma_p\\left(\\frac {\\nu}{2}\\right ) }{\\left|\\mathbf{H}\\right|}^{(\\nu-p-1)/2} e^{-(1/2)\\operatorname{tr}(\\mathbf{S}^{-1}\\mathbf{H})}.
1816
```
1917
20-
If ``\\nu`` is an integer, then a random matrix ``\\mathbf{H}`` given by
18+
If ``\\nu\\leq p-1``, then ``\\mathbf{H}`` is rank ``\\nu`` and it has
19+
a density with respect to a suitably chosen volume element on the space of
20+
positive semidefinite matrices. See [Uhlig (1994)](https://doi.org/10.1214/aos/1176325375).
21+
22+
For integer ``\\nu``, a random matrix given by
2123
2224
```math
23-
\\mathbf{H} = \\mathbf{X}\\mathbf{X}^{\\rm{T}}, \\quad\\mathbf{X} \\sim \\textrm{MN}_{p,\\nu}(\\mathbf{0}, \\mathbf{S}, \\mathbf{I}_{\\nu})
25+
\\mathbf{H} = \\mathbf{X}\\mathbf{X}^{\\rm{T}},
26+
\\quad\\mathbf{X} \\sim \\textrm{MN}_{p,\\nu}(\\mathbf{0}, \\mathbf{S}, \\mathbf{I}_{\\nu})
2427
```
2528
26-
has ``\\mathbf{H}\\sim \\textrm{W}_p(\\nu, \\mathbf{S})``. For non-integer degrees of freedom,
27-
Wishart matrices can be generated via the [Bartlett decomposition](https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition).
29+
has ``\\mathbf{H}\\sim \\textrm{W}_p(\\nu, \\mathbf{S})``.
30+
For non-integer ``\\nu``, Wishart matrices can be generated via the
31+
[Bartlett decomposition](https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition).
2832
"""
29-
struct Wishart{T<:Real, ST<:AbstractPDMat} <: ContinuousMatrixDistribution
30-
df::T # degree of freedom
31-
S::ST # the scale matrix
32-
logc0::T # the logarithm of normalizing constant in pdf
33+
struct Wishart{T<:Real, ST<:AbstractPDMat, R<:Integer} <: ContinuousMatrixDistribution
34+
df::T # degree of freedom
35+
S::ST # the scale matrix
36+
logc0::T # the logarithm of normalizing constant in pdf
37+
rank::R # rank of a sample
38+
singular::Bool # singular or nonsingular Wishart?
3339
end
3440

3541
# -----------------------------------------------------------------------------
3642
# Constructors
3743
# -----------------------------------------------------------------------------
3844

39-
function Wishart(df::T, S::AbstractPDMat{T}) where T<:Real
45+
function Wishart(df::T, S::AbstractPDMat{T}, warn::Bool = true) where T<:Real
46+
df > 0 || throw(ArgumentError("df must be positive. got $(df)."))
4047
p = dim(S)
41-
df > p - 1 || throw(ArgumentError("df should be greater than dim - 1."))
42-
logc0 = wishart_logc0(df, S)
48+
rnk = p
49+
singular = df <= p - 1
50+
if singular
51+
isinteger(df) || throw(ArgumentError("singular df must be an integer. got $(df)."))
52+
rnk = convert(Integer, df)
53+
warn && @warn("got df <= dim - 1; returning a singular Wishart")
54+
end
55+
logc0 = wishart_logc0(df, S, rnk)
4356
R = Base.promote_eltype(T, logc0)
4457
prom_S = convert(AbstractArray{T}, S)
45-
Wishart{R, typeof(prom_S)}(R(df), prom_S, R(logc0))
58+
Wishart{R, typeof(prom_S), typeof(rnk)}(R(df), prom_S, R(logc0), rnk, singular)
4659
end
4760

48-
function Wishart(df::Real, S::AbstractPDMat)
61+
function Wishart(df::Real, S::AbstractPDMat, warn::Bool = true)
4962
T = Base.promote_eltype(df, S)
50-
Wishart(T(df), convert(AbstractArray{T}, S))
63+
Wishart(T(df), convert(AbstractArray{T}, S), warn)
5164
end
5265

53-
Wishart(df::Real, S::Matrix) = Wishart(df, PDMat(S))
54-
55-
Wishart(df::Real, S::Cholesky) = Wishart(df, PDMat(S))
66+
Wishart(df::Real, S::Matrix, warn::Bool = true) = Wishart(df, PDMat(S), warn)
67+
Wishart(df::Real, S::Cholesky, warn::Bool = true) = Wishart(df, PDMat(S), warn)
5668

5769
# -----------------------------------------------------------------------------
5870
# REPL display
@@ -66,23 +78,30 @@ show(io::IO, d::Wishart) = show_multline(io, d, [(:df, d.df), (:S, Matrix(d.S))]
6678

6779
function convert(::Type{Wishart{T}}, d::Wishart) where T<:Real
6880
P = convert(AbstractArray{T}, d.S)
69-
Wishart{T, typeof(P)}(T(d.df), P, T(d.logc0))
81+
Wishart{T, typeof(P), typeof(d.rank)}(T(d.df), P, T(d.logc0), d.rank, d.singular)
7082
end
71-
function convert(::Type{Wishart{T}}, df, S::AbstractPDMat, logc0) where T<:Real
83+
function convert(::Type{Wishart{T}}, df, S::AbstractPDMat, logc0, rnk, singular) where T<:Real
7284
P = convert(AbstractArray{T}, S)
73-
Wishart{T, typeof(P)}(T(df), P, T(logc0))
85+
Wishart{T, typeof(P), typeof(rnk)}(T(df), P, T(logc0), rnk, singular)
7486
end
7587

7688
# -----------------------------------------------------------------------------
7789
# Properties
7890
# -----------------------------------------------------------------------------
7991

80-
insupport(::Type{Wishart}, X::Matrix) = isposdef(X)
81-
insupport(d::Wishart, X::Matrix) = size(X) == size(d) && isposdef(X)
92+
insupport(::Type{Wishart}, X::AbstractMatrix) = ispossemdef(X)
93+
function insupport(d::Wishart, X::AbstractMatrix)
94+
size(X) == size(d) || return false
95+
if d.singular
96+
return ispossemdef(X, rank(d))
97+
else
98+
return isposdef(X)
99+
end
100+
end
82101

83102
dim(d::Wishart) = dim(d.S)
84103
size(d::Wishart) = (p = dim(d); (p, p))
85-
rank(d::Wishart) = dim(d)
104+
rank(d::Wishart) = d.rank
86105
params(d::Wishart) = (d.df, d.S)
87106
@inline partype(d::Wishart{T}) where {T<:Real} = T
88107

@@ -95,6 +114,7 @@ function mode(d::Wishart)
95114
end
96115

97116
function meanlogdet(d::Wishart)
117+
d.singular && return -Inf
98118
p = dim(d)
99119
df = d.df
100120
v = logdet(d.S) + p * logtwo
@@ -105,6 +125,7 @@ function meanlogdet(d::Wishart)
105125
end
106126

107127
function entropy(d::Wishart)
128+
d.singular && throw(ArgumentError("entropy not defined for singular Wishart."))
108129
p = dim(d)
109130
df = d.df
110131
-d.logc0 - 0.5 * (df - p - 1) * meanlogdet(d) + 0.5 * df * p
@@ -125,13 +146,44 @@ end
125146
# Evaluation
126147
# -----------------------------------------------------------------------------
127148

128-
function wishart_logc0(df::Real, S::AbstractPDMat)
129-
h_df = df / 2
149+
function wishart_logc0(df::Real, S::AbstractPDMat, rnk::Integer)
130150
p = dim(S)
131-
-h_df * (logdet(S) + p * typeof(df)(logtwo)) - logmvgamma(p, h_df)
151+
if df <= p - 1
152+
return singular_wishart_logc0(p, df, S, rnk)
153+
else
154+
return nonsingular_wishart_logc0(p, df, S)
155+
end
132156
end
133157

134158
function logkernel(d::Wishart, X::AbstractMatrix)
159+
if d.singular
160+
return singular_wishart_logkernel(d, X)
161+
else
162+
return nonsingular_wishart_logkernel(d, X)
163+
end
164+
end
165+
166+
# Singular Wishart pdf: Theorem 6 in Uhlig (1994 AoS)
167+
function singular_wishart_logc0(p::Integer, df::Real, S::AbstractPDMat, rnk::Integer)
168+
h_df = df / 2
169+
-h_df * (logdet(S) + p * typeof(df)(logtwo)) - logmvgamma(rnk, h_df) + (rnk*(rnk - p) / 2)*typeof(df)(logπ)
170+
end
171+
172+
function singular_wishart_logkernel(d::Wishart, X::AbstractMatrix)
173+
df = d.df
174+
p = dim(d)
175+
r = rank(d)
176+
L = eigvals(Hermitian(X), (p - r + 1):p)
177+
0.5 * ((df - (p + 1)) * sum(log.(L)) - tr(d.S \ X))
178+
end
179+
180+
# Nonsingular Wishart pdf
181+
function nonsingular_wishart_logc0(p::Integer, df::Real, S::AbstractPDMat)
182+
h_df = df / 2
183+
-h_df * (logdet(S) + p * typeof(df)(logtwo)) - logmvgamma(p, h_df)
184+
end
185+
186+
function nonsingular_wishart_logkernel(d::Wishart, X::AbstractMatrix)
135187
df = d.df
136188
p = dim(d)
137189
Xcf = cholesky(X)
@@ -143,7 +195,12 @@ end
143195
# -----------------------------------------------------------------------------
144196

145197
function _rand!(rng::AbstractRNG, d::Wishart, A::AbstractMatrix)
146-
_wishart_genA!(rng, dim(d), d.df, A)
198+
if d.singular
199+
A .= zero(eltype(A))
200+
A[:, 1:rank(d)] = randn(rng, dim(d), rank(d))
201+
else
202+
_wishart_genA!(rng, dim(d), d.df, A)
203+
end
147204
unwhiten!(d.S, A)
148205
A .= A * A'
149206
end
@@ -179,7 +236,7 @@ end
179236

180237
function _rand_params(::Type{Wishart}, elty, n::Int, p::Int)
181238
n == p || throw(ArgumentError("dims must be equal for Wishart"))
182-
ν = elty( n + 1 + abs(10randn()) )
239+
ν = elty( n - 1 + abs(10randn()) )
183240
S = (X = 2rand(elty, n, n) .- 1; X * X')
184241
return ν, S
185242
end

src/utils.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,58 @@ function trycholesky(a::Matrix{Float64})
6464
return e
6565
end
6666
end
67+
68+
"""
69+
ispossemdef(A, k) -> Bool
70+
Test whether a matrix is positive semi-definite with specified rank `k` by
71+
checking that `k` of its eigenvalues are positive and the rest are zero.
72+
# Examples
73+
```jldoctest
74+
julia> A = [1 0; 0 0]
75+
2×2 Array{Int64,2}:
76+
1 0
77+
0 0
78+
julia> ispossemdef(A, 1)
79+
true
80+
julia> ispossemdef(A, 2)
81+
false
82+
```
83+
"""
84+
function ispossemdef(X::AbstractMatrix, k::Int;
85+
atol::Real=0.0,
86+
rtol::Real=(minimum(size(X))*eps(real(float(one(eltype(X))))))*iszero(atol))
87+
_check_rank_range(k, minimum(size(X)))
88+
ishermitian(X) || return false
89+
dp, dz, dn = eigsigns(Hermitian(X), atol, rtol)
90+
return dn == 0 && dp == k
91+
end
92+
function ispossemdef(X::AbstractMatrix;
93+
atol::Real=0.0,
94+
rtol::Real=(minimum(size(X))*eps(real(float(one(eltype(X))))))*iszero(atol))
95+
ishermitian(X) || return false
96+
dp, dz, dn = eigsigns(Hermitian(X), atol, rtol)
97+
return dn == 0
98+
end
99+
100+
function _check_rank_range(k::Int, n::Int)
101+
0 <= k <= n || throw(ArgumentError("rank must be between 0 and $(n) (inclusive)"))
102+
nothing
103+
end
104+
105+
# return counts of the number of positive, zero, and negative eigenvalues
106+
function eigsigns(X::AbstractMatrix,
107+
atol::Real=0.0,
108+
rtol::Real=(minimum(size(X))*eps(real(float(one(eltype(X))))))*iszero(atol))
109+
eigs = eigvals(X)
110+
eigsigns(eigs, atol, rtol)
111+
end
112+
function eigsigns(eigs::Vector{<: Real}, atol::Real, rtol::Real)
113+
tol = max(atol, rtol * eigs[end])
114+
eigsigns(eigs, tol)
115+
end
116+
function eigsigns(eigs::Vector{<: Real}, tol::Real)
117+
dp = count(x -> tol < x, eigs) # number of positive eigenvalues
118+
dz = count(x -> -tol < x < tol, eigs) # number of numerically zero eigenvalues
119+
dn = count(x -> x < -tol, eigs) # number of negative eigenvalues
120+
return dp, dz, dn
121+
end

test/matrixvariates.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import Distributions: _univariate, _multivariate, _rand_params
2828

2929
function test_draw(d::MatrixDistribution, X::AbstractMatrix)
3030
@test size(d) == size(X)
31-
@test size(d) == size(mean(d))
3231
@test size(d, 1) == size(X, 1)
3332
@test size(d, 2) == size(X, 2)
3433
@test length(d) == length(X)
@@ -331,18 +330,32 @@ function test_special(dist::Type{Wishart})
331330
@test pvalue(kstest) >= α
332331
end
333332
@testset "H ~ W(ν, I) ⟹ H[i, i] ~ χ²(ν)" begin
334-
ν = n + 1
335-
ρ = Chisq(ν)
336-
d = Wishart(ν, ScalMat(n, 1))
333+
κ = n + 1
334+
ρ = Chisq(κ)
335+
g = Wishart(κ, ScalMat(n, 1))
337336
mymats = zeros(n, n, M)
338337
for m in 1:M
339-
mymats[:, :, m] = rand(d)
338+
mymats[:, :, m] = rand(g)
340339
end
341340
for i in 1:n
342341
kstest = ExactOneSampleKSTest(mymats[i, i, :], ρ)
343342
@test pvalue(kstest) >= α / n
344343
end
345344
end
345+
@testset "Check Singular Branch" begin
346+
X = H[1]
347+
rank1 = Wishart(n - 2, Σ, false)
348+
rank2 = Wishart(n - 1, Σ, false)
349+
test_draw(rank1)
350+
test_draw(rank2)
351+
test_draws(rank1, rand(rank1, 10^6))
352+
test_draws(rank2, rand(rank2, 10^6))
353+
test_cov(rank1)
354+
test_cov(rank2)
355+
@test Distributions.singular_wishart_logkernel(d, X) ≈ Distributions.nonsingular_wishart_logkernel(d, X)
356+
@test Distributions.singular_wishart_logc0(n, ν, d.S, rank(d)) ≈ Distributions.nonsingular_wishart_logc0(n, ν, d.S)
357+
@test logpdf(d, X) ≈ Distributions.singular_wishart_logkernel(d, X) + Distributions.singular_wishart_logc0(n, ν, d.S, rank(d))
358+
end
346359
nothing
347360
end
348361

test/utils.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Distributions, PDMats
22
using Test, LinearAlgebra
3-
3+
import Distributions: ispossemdef
44

55
# RealInterval
66
r = RealInterval(1.5, 4.0)
@@ -36,3 +36,33 @@ N = GenericArray([1.0 0.0; 1.0 0.0])
3636

3737
@test Distributions.isApproxSymmmetric(N) == false
3838
@test Distributions.isApproxSymmmetric(M)
39+
40+
41+
n = 10
42+
areal = randn(n,n)/2
43+
aimg = randn(n,n)/2
44+
@testset "For A containing $eltya" for eltya in (Float32, Float64, ComplexF32, ComplexF64, Int)
45+
ainit = eltya == Int ? rand(1:7, n, n) : convert(Matrix{eltya}, eltya <: Complex ? complex.(areal, aimg) : areal)
46+
@testset "Positive semi-definiteness" begin
47+
notsymmetric = ainit
48+
notsquare = [ainit ainit]
49+
@test !ispossemdef(notsymmetric)
50+
@test !ispossemdef(notsquare)
51+
for truerank in 0:n
52+
X = ainit[:, 1:truerank]
53+
A = truerank == 0 ? zeros(eltya, n, n) : X * X'
54+
@test ispossemdef(A)
55+
for testrank in 0:n
56+
if testrank == truerank
57+
@test ispossemdef(A, testrank)
58+
else
59+
@test !ispossemdef(A, testrank)
60+
end
61+
end
62+
@test !ispossemdef(notsymmetric, truerank)
63+
@test !ispossemdef(notsquare, truerank)
64+
@test_throws ArgumentError ispossemdef(A, -1)
65+
@test_throws ArgumentError ispossemdef(A, n + 1)
66+
end
67+
end
68+
end

0 commit comments

Comments
 (0)