|  | 
|  | 1 | +module DynamicPPLUtilsTests | 
|  | 2 | + | 
|  | 3 | +using Bijectors: Bijectors | 
|  | 4 | +using Distributions | 
|  | 5 | +using DynamicPPL | 
|  | 6 | +using LinearAlgebra: LinearAlgebra | 
|  | 7 | +using Test | 
|  | 8 | + | 
|  | 9 | +isapprox_nested(a::Number, b::Number; kwargs...) = isapprox(a, b; kwargs...) | 
|  | 10 | +isapprox_nested(a::AbstractArray, b::AbstractArray; kwargs...) = isapprox(a, b; kwargs...) | 
|  | 11 | +function isapprox_nested(a::LinearAlgebra.Cholesky, b::LinearAlgebra.Cholesky; kwargs...) | 
|  | 12 | +    return isapprox(a.U, b.U; kwargs...) && isapprox(a.L, b.L; kwargs...) | 
|  | 13 | +end | 
|  | 14 | +function isapprox_nested(a::NamedTuple, b::NamedTuple; kwargs...) | 
|  | 15 | +    keys(a) == keys(b) || return false | 
|  | 16 | +    return all(k -> isapprox_nested(a[k], b[k]; kwargs...), keys(a)) | 
|  | 17 | +end | 
|  | 18 | + | 
| 1 | 19 | @testset "utils.jl" begin | 
| 2 | 20 |     @testset "addlogprob!" begin | 
| 3 | 21 |         @model function testmodel() | 
|  | 
| 31 | 49 |         end | 
| 32 | 50 |     end | 
| 33 | 51 | 
 | 
|  | 52 | +    @testset "transformations" begin | 
|  | 53 | +        function test_transformation( | 
|  | 54 | +            dist::Distribution; test_bijector_type_stability::Bool=true | 
|  | 55 | +        ) | 
|  | 56 | +            unlinked = rand(dist) | 
|  | 57 | +            unlinked_vec = DynamicPPL.tovec(unlinked) | 
|  | 58 | +            @test unlinked_vec isa AbstractVector | 
|  | 59 | + | 
|  | 60 | +            from_vec_trfm = DynamicPPL.from_vec_transform(dist) | 
|  | 61 | +            unlinked_again, logjac = Bijectors.with_logabsdet_jacobian( | 
|  | 62 | +                from_vec_trfm, unlinked_vec | 
|  | 63 | +            ) | 
|  | 64 | +            @test isapprox_nested(unlinked, unlinked_again) | 
|  | 65 | +            @test iszero(logjac) | 
|  | 66 | +            # Type stability | 
|  | 67 | +            @inferred DynamicPPL.from_vec_transform(dist) | 
|  | 68 | +            @inferred Bijectors.with_logabsdet_jacobian(from_vec_trfm, unlinked_vec) | 
|  | 69 | + | 
|  | 70 | +            # Typically the same as `bijector(dist)`, but technically a different | 
|  | 71 | +            # function | 
|  | 72 | +            b = DynamicPPL.link_transform(dist) | 
|  | 73 | +            @test (b(unlinked); true) | 
|  | 74 | +            linked, logjac = Bijectors.with_logabsdet_jacobian(b, unlinked) | 
|  | 75 | +            @test logjac isa Real | 
|  | 76 | + | 
|  | 77 | +            binv = DynamicPPL.invlink_transform(dist) | 
|  | 78 | +            unlinked_again, logjac_inv = Bijectors.with_logabsdet_jacobian(binv, linked) | 
|  | 79 | +            @test isapprox_nested(unlinked, unlinked_again) | 
|  | 80 | +            @test isapprox(logjac, -logjac_inv) | 
|  | 81 | + | 
|  | 82 | +            linked_vec = DynamicPPL.tovec(linked) | 
|  | 83 | +            @test linked_vec isa AbstractVector | 
|  | 84 | +            from_linked_vec_trfm = DynamicPPL.from_linked_vec_transform(dist) | 
|  | 85 | +            unlinked_again_again = from_linked_vec_trfm(linked_vec) | 
|  | 86 | +            @test isapprox_nested(unlinked, unlinked_again_again) | 
|  | 87 | + | 
|  | 88 | +            # Sometimes the bijector itself is not type stable. In this case there is not | 
|  | 89 | +            # much we can do in DynamicPPL except skip these tests (it has to be fixed | 
|  | 90 | +            # upstream in Bijectors.) | 
|  | 91 | +            if test_bijector_type_stability | 
|  | 92 | +                @inferred DynamicPPL.from_linked_vec_transform(dist) | 
|  | 93 | +                @inferred Bijectors.with_logabsdet_jacobian( | 
|  | 94 | +                    from_linked_vec_trfm, linked_vec | 
|  | 95 | +                ) | 
|  | 96 | +            end | 
|  | 97 | + | 
|  | 98 | +            # Create a model and check that we can evaluate it with both unlinked and linked | 
|  | 99 | +            # VarInfo. This relies on the transformations working correctly so is more of an | 
|  | 100 | +            # 'end to end' test | 
|  | 101 | +            @model test() = x ~ dist | 
|  | 102 | +            model = test() | 
|  | 103 | +            vi_unlinked = VarInfo(model) | 
|  | 104 | +            vi_linked = DynamicPPL.link!!(VarInfo(model), model) | 
|  | 105 | +            @test (DynamicPPL.evaluate!!(model, vi_unlinked); true) | 
|  | 106 | +            @test (DynamicPPL.evaluate!!(model, vi_linked); true) | 
|  | 107 | +            model_init = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) | 
|  | 108 | +            @test (DynamicPPL.evaluate!!(model_init, vi_unlinked); true) | 
|  | 109 | +            @test (DynamicPPL.evaluate!!(model_init, vi_linked); true) | 
|  | 110 | +        end | 
|  | 111 | + | 
|  | 112 | +        # Unconstrained univariate | 
|  | 113 | +        test_transformation(Normal()) | 
|  | 114 | +        # Constrained univariate | 
|  | 115 | +        test_transformation(LogNormal()) | 
|  | 116 | +        test_transformation(truncated(Normal(); lower=0)) | 
|  | 117 | +        test_transformation(Exponential(1.0)) | 
|  | 118 | +        test_transformation(Uniform(-2, 2)) | 
|  | 119 | +        test_transformation(Beta(2, 2)) | 
|  | 120 | +        test_transformation(InverseGamma(2, 3)) | 
|  | 121 | +        # Discrete univariate | 
|  | 122 | +        test_transformation(Poisson(3)) | 
|  | 123 | +        test_transformation(Binomial(10, 0.5)) | 
|  | 124 | +        # Multivariate | 
|  | 125 | +        test_transformation(MvNormal(zeros(3), LinearAlgebra.I)) | 
|  | 126 | +        test_transformation( | 
|  | 127 | +            product_distribution([Normal(), LogNormal()]); | 
|  | 128 | +            test_bijector_type_stability=false, | 
|  | 129 | +        ) | 
|  | 130 | +        test_transformation(product_distribution([LogNormal(), LogNormal()])) | 
|  | 131 | +        # Matrixvariate | 
|  | 132 | +        test_transformation(LKJ(3, 0.5)) | 
|  | 133 | +        test_transformation(Wishart(7, [1.0 0.0; 0.0 1.0])) | 
|  | 134 | +        # This is a pathological example: the linked representation is a matrix | 
|  | 135 | +        test_transformation(product_distribution(fill(Dirichlet(ones(4)), 2, 3))) | 
|  | 136 | +        # Cholesky | 
|  | 137 | +        test_transformation(LKJCholesky(3, 0.5)) | 
|  | 138 | +        # ProductNamedTupleDistribution | 
|  | 139 | +        d = product_distribution((a=Normal(), b=LogNormal())) | 
|  | 140 | +        test_transformation(d) | 
|  | 141 | +        d_nested = product_distribution((x=LKJCholesky(2, 0.5), y=d)) | 
|  | 142 | +        test_transformation(d_nested) | 
|  | 143 | +    end | 
|  | 144 | + | 
| 34 | 145 |     @testset "getargs_dottilde" begin | 
| 35 | 146 |         # Some things that are not expressions. | 
| 36 |  | -        @test getargs_dottilde(:x) === nothing | 
| 37 |  | -        @test getargs_dottilde(1.0) === nothing | 
| 38 |  | -        @test getargs_dottilde([1.0, 2.0, 4.0]) === nothing | 
|  | 147 | +        @test DynamicPPL.getargs_dottilde(:x) === nothing | 
|  | 148 | +        @test DynamicPPL.getargs_dottilde(1.0) === nothing | 
|  | 149 | +        @test DynamicPPL.getargs_dottilde([1.0, 2.0, 4.0]) === nothing | 
| 39 | 150 | 
 | 
| 40 | 151 |         # Some expressions. | 
| 41 |  | -        @test getargs_dottilde(:(x ~ Normal(μ, σ))) === nothing | 
| 42 |  | -        @test getargs_dottilde(:((.~)(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) | 
| 43 |  | -        @test getargs_dottilde(:((~).(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) | 
| 44 |  | -        @test getargs_dottilde(:(x .~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) | 
| 45 |  | -        @test getargs_dottilde(:(@. x ~ Normal(μ, σ))) === nothing | 
| 46 |  | -        @test getargs_dottilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing | 
| 47 |  | -        @test getargs_dottilde(:(@~ Normal.(μ, σ))) === nothing | 
|  | 152 | +        @test DynamicPPL.getargs_dottilde(:(x ~ Normal(μ, σ))) === nothing | 
|  | 153 | +        @test DynamicPPL.getargs_dottilde(:((.~)(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) | 
|  | 154 | +        @test DynamicPPL.getargs_dottilde(:((~).(x, Normal(μ, σ)))) == (:x, :(Normal(μ, σ))) | 
|  | 155 | +        @test DynamicPPL.getargs_dottilde(:(x .~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) | 
|  | 156 | +        @test DynamicPPL.getargs_dottilde(:(@. x ~ Normal(μ, σ))) === nothing | 
|  | 157 | +        @test DynamicPPL.getargs_dottilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === | 
|  | 158 | +            nothing | 
|  | 159 | +        @test DynamicPPL.getargs_dottilde(:(@~ Normal.(μ, σ))) === nothing | 
| 48 | 160 |     end | 
| 49 | 161 | 
 | 
| 50 | 162 |     @testset "getargs_tilde" begin | 
| 51 | 163 |         # Some things that are not expressions. | 
| 52 |  | -        @test getargs_tilde(:x) === nothing | 
| 53 |  | -        @test getargs_tilde(1.0) === nothing | 
| 54 |  | -        @test getargs_tilde([1.0, 2.0, 4.0]) === nothing | 
|  | 164 | +        @test DynamicPPL.getargs_tilde(:x) === nothing | 
|  | 165 | +        @test DynamicPPL.getargs_tilde(1.0) === nothing | 
|  | 166 | +        @test DynamicPPL.getargs_tilde([1.0, 2.0, 4.0]) === nothing | 
| 55 | 167 | 
 | 
| 56 | 168 |         # Some expressions. | 
| 57 |  | -        @test getargs_tilde(:(x ~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) | 
| 58 |  | -        @test getargs_tilde(:((.~)(x, Normal(μ, σ)))) === nothing | 
| 59 |  | -        @test getargs_tilde(:((~).(x, Normal(μ, σ)))) === nothing | 
| 60 |  | -        @test getargs_tilde(:(@. x ~ Normal(μ, σ))) === nothing | 
| 61 |  | -        @test getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing | 
| 62 |  | -        @test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing | 
|  | 169 | +        @test DynamicPPL.getargs_tilde(:(x ~ Normal(μ, σ))) == (:x, :(Normal(μ, σ))) | 
|  | 170 | +        @test DynamicPPL.getargs_tilde(:((.~)(x, Normal(μ, σ)))) === nothing | 
|  | 171 | +        @test DynamicPPL.getargs_tilde(:((~).(x, Normal(μ, σ)))) === nothing | 
|  | 172 | +        @test DynamicPPL.getargs_tilde(:(@. x ~ Normal(μ, σ))) === nothing | 
|  | 173 | +        @test DynamicPPL.getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === | 
|  | 174 | +            nothing | 
|  | 175 | +        @test DynamicPPL.getargs_tilde(:(@~ Normal.(μ, σ))) === nothing | 
| 63 | 176 |     end | 
| 64 | 177 | 
 | 
| 65 | 178 |     @testset "tovec" begin | 
|  | 
| 97 | 210 |         @test DynamicPPL.group_varnames_by_symbol(vns_tuple) == vns_nt | 
| 98 | 211 |     end | 
| 99 | 212 | end | 
|  | 213 | + | 
|  | 214 | +end | 
0 commit comments