From 12b9d34972108c35dffdec23f6c2c0240ea5b7c8 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 24 Oct 2021 02:42:04 +0200 Subject: [PATCH] Simplify Zygote tests and use CR --- test/ad/utils.jl | 149 ++++++++++++++++++++++++++++------------------- 1 file changed, 88 insertions(+), 61 deletions(-) diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 31d294a..a1e62c9 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -10,6 +10,20 @@ if GROUP == "All" || GROUP == "ForwardDiff" end if GROUP == "All" || GROUP == "Zygote" @eval using Zygote + + # Workaround for nested `nothing` + Zygote.z2d(::NTuple{<:Any,Nothing}, ::Tuple) = NoTangent() + function Zygote.z2d(t::NamedTuple, primal::T) where T + fnames = fieldnames(T) + complete_t = map(n -> get(t, n, nothing), fnames) + primals = map(n -> getfield(primal, n), fnames) + tp = map(Zygote.z2d, complete_t, primals) + return if tp isa NTuple{<:Any,NoTangent} + NoTangent() + else + canonicalize(Tangent{T, typeof(tp)}(tp)) + end + end end if GROUP == "All" || GROUP == "ReverseDiff" @eval using ReverseDiff @@ -225,62 +239,93 @@ function test_ad(dist::DistSpec; kwargs...) g = dist.xtrans broken = dist.broken - # Create functions with all possible arguments - f_loglik_allargs = let f=f, g=g - function (x, θ...) - dist = f(θ...) - xtilde = g === nothing ? x : g(x) - return loglikelihood(dist, xtilde) - end + # combine all arguments + # point `x` is not differentiable if the distribution is discrete + args = if Distributions.value_support(typeof(dist)) === Continuous + (x, θ...) + else + θ end - f_logpdf_allargs = let f=f, g=g - function (x, θ...) - dist = f(θ...) - xtilde = g === nothing ? x : g(x) - if dist isa UnivariateDistribution && xtilde isa AbstractArray - return sum(logpdf.(dist, xtilde)) - else - return sum(logpdf(dist, xtilde)) + + # Create functions with all arguments + if Distributions.value_support(typeof(dist)) === Continuous + f_loglik_allargs = let f=f, g=g + function (x, θ...) + dist = f(θ...) + xtilde = g === nothing ? x : g(x) + return loglikelihood(dist, xtilde) end end - end - - # For all combinations of distribution parameters `θ` - for inds in powerset(2:(length(θ) + 1)) - # Test only distribution parameters - if !isempty(inds) - xtest = mapreduce(vcat, inds) do i - vectorize(θ[i - 1]) + f_logpdf_allargs = let f=f, g=g + function (x, θ...) + dist = f(θ...) + xtilde = g === nothing ? x : g(x) + if dist isa UnivariateDistribution && xtilde isa AbstractArray + return sum(logpdf.(dist, xtilde)) + else + return sum(logpdf(dist, xtilde)) + end end - f_loglik_test = let xorig=x, θorig=θ, inds=inds - x -> f_loglik_allargs(unpack(x, inds, xorig, θorig...)...) + end + else + gx = g === nothing ? x : g(x) + f_loglik_allargs = let f=f, gx=gx + function (θ...) + dist = f(θ...) + return loglikelihood(dist, gx) end - f_logpdf_test = let xorig=x, θorig=θ, inds=inds - x -> f_logpdf_allargs(unpack(x, inds, xorig, θorig...)...) + end + f_logpdf_allargs = let f=f, gx=gx + function (θ...) + dist = f(θ...) + return if dist isa UnivariateDistribution && gx isa AbstractArray + sum(logpdf.(dist, gx)) + else + sum(logpdf(dist, gx)) + end end + end + end - @test f_loglik_test(xtest) ≈ f_logpdf_test(xtest) - - test_ad(f_loglik_test, xtest, broken; kwargs...) - test_ad(f_logpdf_test, xtest, broken; kwargs...) + # short cut: since Zygote does not use special number types with + # different dispatches etc., it is suffiient to just test derivatives of + # all differentiable arguments at once + if GROUP === "All" || GROUP === "Zygote" + @test f_loglik_allargs(args...) ≈ f_logpdf_allargs(args...) + + # Zygote has type inference problems so we don't check it + try + for f in (f_loglik_allargs, f_logpdf_allargs) + test_rrule( + Zygote.ZygoteRuleConfig(), f ⊢ NoTangent(), args...; + rrule_f=rrule_via_ad, check_inferred=false, kwargs... + ) + end + catch + :Zygote in test_broken || rethrow() end + end + + # early exit + GROUP !== "Zygote" || return - # Test derivative with respect to location `x` as well - # if the distribution is continuous - if Distributions.value_support(typeof(dist)) === Continuous - xtest = isempty(inds) ? vectorize(x) : vcat(vectorize(x), xtest) - push!(inds, 1) - f_loglik_test = let xorig=x, θorig=θ, inds=inds - x -> f_loglik_allargs(unpack(x, inds, xorig, θorig...)...) + # For all combinations of arguments + for inds in powerset(1:length(args)) + if !isempty(inds) + argstest = mapreduce(vcat, inds) do i + vectorize(args[i]) + end + f_loglik_test = let args=args, inds=inds + x -> f_loglik_allargs(unpack(x, inds, args...)...) end - f_logpdf_test = let xorig=x, θorig=θ, inds=inds - x -> f_logpdf_allargs(unpack(x, inds, xorig, θorig...)...) + f_logpdf_test = let args=args, inds=inds + x -> f_logpdf_allargs(unpack(x, inds, args...)...) end - @test f_loglik_test(xtest) ≈ f_logpdf_test(xtest) + @test f_loglik_test(argstest) ≈ f_logpdf_test(argstest) - test_ad(f_loglik_test, xtest, broken; kwargs...) - test_ad(f_logpdf_test, xtest, broken; kwargs...) + test_ad(f_loglik_test, argstest, broken; kwargs...) + test_ad(f_logpdf_test, argstest, broken; kwargs...) end end end @@ -304,18 +349,6 @@ function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) end end - if GROUP == "All" || GROUP == "Zygote" - if :Zygote in broken - @test_broken zygote_isapprox( - Zygote.gradient(f, x)[1], finitediff; rtol=rtol, atol=atol, - ) - else - @test zygote_isapprox( - Zygote.gradient(f, x)[1], finitediff; rtol=rtol, atol=atol, - ) - end - end - if GROUP == "All" || GROUP == "ReverseDiff" if :ReverseDiff in broken @test_broken ReverseDiff.gradient(f, x) ≈ finitediff rtol=rtol atol=atol @@ -326,9 +359,3 @@ function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) return end - -# Handle Zygote's `nothing` -zygote_isapprox(x, expected; kwargs...) = isapprox(x, expected; kwargs...) -function zygote_isapprox(::Nothing, expected; kwargs...) - return isapprox(zero(expected), expected; kwargs...) -end