Skip to content

Commit d27b840

Browse files
authored
BatchMvNormal variance shape (#32)
* let BatchMvNormal always return variance of shape (xlength,batchsize)
1 parent ace6c63 commit d27b840

File tree

3 files changed

+6
-15
lines changed

3 files changed

+6
-15
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ConditionalDists"
22
uuid = "c648c4dd-c1e0-49a6-84b9-144ae7fd2468"
33
authors = ["Niklas Heim <niklas.heim@aic.fel.cvut.cz>"]
4-
version = "0.4.1"
4+
version = "0.4.2"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/batch_mvnormal.jl

+3-12
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@ BatchMvNormal(μ::AbstractMatrix{T}, σ::AbstractMatrix{T}) where T<:Real = Batc
1717
Base.eltype(d::BMN) = eltype(d.μ)
1818
Distributions.params(d::BMN) = (d.μ, d.σ)
1919
Distributions.mean(d::BMN) = d.μ
20-
Distributions.var(d::BMN) = d.σ .^2
21-
22-
#Distributions.var(d::BatchScalMvNormal) = fill(similar(d.σ,size(d.μ,1)),1) .* reshape(d.σ .^2,1,:)
23-
#Distributions.var(d::BatchScalMvNormal) = ones(eltype(d), size(d.μ,1), 1) .* reshape(d.σ .^2, 1, :)
20+
Distributions.var(d::BatchDiagMvNormal) = d.σ .^2
21+
Distributions.var(d::BatchScalMvNormal) = fillsimilar(d.σ,size(d.μ,1),1) .* reshape(d.σ .^2,1,:)
2422

2523
function Distributions.rand(d::BatchDiagMvNormal)
2624
μ, σ = d.μ, d.σ
@@ -34,16 +32,9 @@ function Distributions.rand(d::BatchScalMvNormal)
3432
μ .+ σ .* r
3533
end
3634

37-
function Distributions.logpdf(d::BatchDiagMvNormal, x::AbstractMatrix{T}) where T<:Real
35+
function Distributions.logpdf(d::BMN, x::AbstractMatrix{T}) where T<:Real
3836
n = size(d.μ,1)
3937
μ = mean(d)
4038
σ2 = var(d)
4139
-(vec(sum(((x - μ).^2) ./ σ2 .+ log.(σ2), dims=1)) .+ n*log(T(2π))) / 2
4240
end
43-
44-
function Distributions.logpdf(d::BatchScalMvNormal, x::AbstractMatrix{T}) where T<:Real
45-
n = size(d.μ,1)
46-
μ = mean(d)
47-
σ2 = reshape(var(d), 1, :)
48-
-(vec(sum(((x - μ).^2) ./ σ2 .+ log.(σ2), dims=1)) .+ n*log(T(2π))) / 2
49-
end

test/cond_mvnormal.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
σ2 = var(res)
5555
@test res isa ConditionalDists.BatchScalMvNormal
5656
@test size(μ) == (xlength,batchsize)
57-
@test size(σ2) == (batchsize,)
57+
@test size(σ2) == (xlength,batchsize)
5858

5959
x = rand(Float32, xlength, batchsize) |> gpu
6060
z = rand(Float32, zlength, batchsize) |> gpu
@@ -78,7 +78,7 @@
7878
σ2 = var(res)
7979
@test res isa ConditionalDists.BatchScalMvNormal
8080
@test size(μ) == (xlength,batchsize)
81-
@test size(σ2) == (batchsize,)
81+
@test size(σ2) == (xlength,batchsize)
8282

8383
x = rand(Float32, xlength, batchsize) |> gpu
8484
z = rand(Float32, zlength, batchsize) |> gpu

0 commit comments

Comments
 (0)