Skip to content

Commit e0cd6bc

Browse files
authored
Constant (scalar) variance in ConditionalMvNormal mapping (#42)
1 parent 0e64e9c commit e0cd6bc

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

src/cond_mvnormal.jl

+9
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,18 @@ function condition(p::ConditionalMvNormal, z::AbstractMatrix)
4747
BatchMvNormal(μ,σ)
4848
end
4949

50+
# dispatches for different outputs from mappings
51+
# general case
5052
mean_var(x::Tuple) = x
53+
# single output assumes σ=1
5154
mean_var(x::AbstractVector) = (x, 1)
5255
mean_var(x::AbstractMatrix) = (x, fillsimilar(x,size(x,2),1))
56+
# fixed scalar variance
57+
# mean_var(x::Tuple{<:AbstractVector,<:Real}) = x; is already coverged
58+
function mean_var(x::Tuple{<:AbstractMatrix,<:Real})
59+
(μ,σ) = x
60+
(μ,fillsimilar(μ,size(μ,2),σ))
61+
end
5362

5463
# TODO: this should be moved to DistributionsAD
5564
Distributions.mean(p::TuringDiagMvNormal) = p.m

test/cond_mvnormal.jl

+26-2
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646

4747
# BatchScalMvNormal
4848
m = SplitLayer(zlength, [xlength,1])
49-
d = MvNormal(zeros(Float32,xlength), 1f0)
5049
p = ConditionalMvNormal(m) |> gpu
5150

5251
res = condition(p, rand(zlength,batchsize)|>gpu)
@@ -70,7 +69,6 @@
7069

7170
# Unit variance
7271
m = Dense(zlength,xlength)
73-
d = MvNormal(zeros(Float32,xlength), 1f0)
7472
p = ConditionalMvNormal(m) |> gpu
7573

7674
res = condition(p, rand(zlength,batchsize)|>gpu)
@@ -90,4 +88,30 @@
9088

9189
f() = sum(rand(p,z))
9290
@test_nowarn Flux.gradient(f, ps)
91+
92+
93+
# Fixed scalar variance
94+
m = Dense(zlength,xlength)
95+
σ(x::AbstractVector) = 2
96+
σ(x::AbstractMatrix) = ones(Float32,size(x,2)) .* 2
97+
p = ConditionalMvNormal(SplitLayer(m,σ)) |> gpu
98+
99+
res = condition(p, rand(zlength,batchsize)|>gpu)
100+
μ = mean(res)
101+
σ2 = var(res)
102+
@test res isa ConditionalDists.BatchScalMvNormal
103+
@test size(μ) == (xlength,batchsize)
104+
@test size(σ2) == (xlength,batchsize)
105+
106+
x = rand(Float32, xlength, batchsize) |> gpu
107+
z = rand(Float32, zlength, batchsize) |> gpu
108+
loss() = sum(logpdf(p,x,z))
109+
ps = Flux.params(p)
110+
@test length(ps) == 2
111+
@test loss() isa Float32
112+
@test_nowarn gs = Flux.gradient(loss, ps)
113+
114+
f() = sum(rand(p,z))
115+
@test_nowarn Flux.gradient(f, ps)
116+
93117
end

0 commit comments

Comments (0)