@@ -17,10 +17,8 @@ BatchMvNormal(μ::AbstractMatrix{T}, σ::AbstractMatrix{T}) where T<:Real = Batc
17
17
Base. eltype (d:: BMN ) = eltype (d. μ)
18
18
Distributions. params (d:: BMN ) = (d. μ, d. σ)
19
19
Distributions. mean (d:: BMN ) = d. μ
20
- Distributions. var (d:: BMN ) = d. σ .^ 2
21
-
22
- # Distributions.var(d::BatchScalMvNormal) = fill(similar(d.σ,size(d.μ,1)),1) .* reshape(d.σ .^2,1,:)
23
- # Distributions.var(d::BatchScalMvNormal) = ones(eltype(d), size(d.μ,1), 1) .* reshape(d.σ .^2, 1, :)
20
+ Distributions. var (d:: BatchDiagMvNormal ) = d. σ .^ 2
21
+ Distributions. var (d:: BatchScalMvNormal ) = fillsimilar (d. σ,size (d. μ,1 ),1 ) .* reshape (d. σ .^ 2 ,1 ,:)
24
22
25
23
function Distributions. rand (d:: BatchDiagMvNormal )
26
24
μ, σ = d. μ, d. σ
@@ -34,16 +32,9 @@ function Distributions.rand(d::BatchScalMvNormal)
34
32
μ .+ σ .* r
35
33
end
36
34
37
- function Distributions. logpdf (d:: BatchDiagMvNormal , x:: AbstractMatrix{T} ) where T<: Real
35
+ function Distributions. logpdf (d:: BMN , x:: AbstractMatrix{T} ) where T<: Real
38
36
n = size (d. μ,1 )
39
37
μ = mean (d)
40
38
σ2 = var (d)
41
39
- (vec (sum (((x - μ). ^ 2 ) ./ σ2 .+ log .(σ2), dims= 1 )) .+ n* log (T (2 π))) / 2
42
40
end
43
-
44
- function Distributions. logpdf (d:: BatchScalMvNormal , x:: AbstractMatrix{T} ) where T<: Real
45
- n = size (d. μ,1 )
46
- μ = mean (d)
47
- σ2 = reshape (var (d), 1 , :)
48
- - (vec (sum (((x - μ). ^ 2 ) ./ σ2 .+ log .(σ2), dims= 1 )) .+ n* log (T (2 π))) / 2
49
- end
0 commit comments