@@ -46,7 +46,6 @@
 
 # BatchScalMvNormal
 m = SplitLayer(zlength, [xlength,1])
-d = MvNormal(zeros(Float32,xlength), 1f0)
 p = ConditionalMvNormal(m) |> gpu
 
 res = condition(p, rand(zlength,batchsize)|>gpu)
@@ -70,7 +69,6 @@
 
 # Unit variance
 m = Dense(zlength,xlength)
-d = MvNormal(zeros(Float32,xlength), 1f0)
 p = ConditionalMvNormal(m) |> gpu
 
 res = condition(p, rand(zlength,batchsize)|>gpu)
@@ -90,4 +88,30 @@
 
 f() = sum(rand(p,z))
 @test_nowarn Flux.gradient(f, ps)
+
+
+# Fixed scalar variance
+m = Dense(zlength,xlength)
+σ(x::AbstractVector) = 2
+σ(x::AbstractMatrix) = ones(Float32,size(x,2)) .* 2
+p = ConditionalMvNormal(SplitLayer(m,σ)) |> gpu
+
+res = condition(p, rand(zlength,batchsize)|>gpu)
+μ = mean(res)
+σ2 = var(res)
+@test res isa ConditionalDists.BatchScalMvNormal
+@test size(μ) == (xlength,batchsize)
+@test size(σ2) == (xlength,batchsize)
+
+x = rand(Float32, xlength, batchsize) |> gpu
+z = rand(Float32, zlength, batchsize) |> gpu
+loss() = sum(logpdf(p,x,z))
+ps = Flux.params(p)
+@test length(ps) == 2
+@test loss() isa Float32
+@test_nowarn gs = Flux.gradient(loss, ps)
+
+f() = sum(rand(p,z))
+@test_nowarn Flux.gradient(f, ps)
+
 end
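For orientation, here is a minimal CPU-only sketch of the fixed-scalar-variance construction the new test exercises, assuming only the ConditionalDists.jl API already used above (SplitLayer, ConditionalMvNormal, condition, logpdf, mean, var). The concrete dimensions and the fixed_sigma name are illustrative, not taken from the test file or the package:

using Flux, ConditionalDists, Distributions

xlength, zlength, batchsize = 4, 2, 8   # illustrative sizes, not the test's values

# Mean network plus a constant standard deviation, combined via SplitLayer.
m = Dense(zlength, xlength)
fixed_sigma(x::AbstractVector) = 2                               # single observation
fixed_sigma(x::AbstractMatrix) = ones(Float32, size(x, 2)) .* 2  # one value per batch column
p = ConditionalMvNormal(SplitLayer(m, fixed_sigma))

z = rand(Float32, zlength, batchsize)
d = condition(p, z)           # expected to be a ConditionalDists.BatchScalMvNormal
μ = mean(d)                   # (xlength, batchsize) means produced by m
v = var(d)                    # should be σ² = 4 in every entry

x = rand(Float32, xlength, batchsize)
loss = sum(logpdf(p, x, z))   # scalar loss, as in the test's loss()
ps = Flux.params(p)           # two trainable arrays (Dense weight and bias); fixed_sigma adds none
samples = rand(p, z)          # one sampled column per condition in z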