Skip to content

Commit

Permalink
Simplify Zygote tests and use CR
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Oct 24, 2021
1 parent 751754b commit 12b9d34
Showing 1 changed file with 88 additions and 61 deletions.
149 changes: 88 additions & 61 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@ if GROUP == "All" || GROUP == "ForwardDiff"
end
if GROUP == "All" || GROUP == "Zygote"
@eval using Zygote

# Workaround for nested `nothing`
Zygote.z2d(::NTuple{<:Any,Nothing}, ::Tuple) = NoTangent()
function Zygote.z2d(t::NamedTuple, primal::T) where T
fnames = fieldnames(T)
complete_t = map(n -> get(t, n, nothing), fnames)
primals = map(n -> getfield(primal, n), fnames)
tp = map(Zygote.z2d, complete_t, primals)
return if tp isa NTuple{<:Any,NoTangent}
NoTangent()
else
canonicalize(Tangent{T, typeof(tp)}(tp))
end
end
end
if GROUP == "All" || GROUP == "ReverseDiff"
@eval using ReverseDiff
Expand Down Expand Up @@ -225,62 +239,93 @@ function test_ad(dist::DistSpec; kwargs...)
g = dist.xtrans
broken = dist.broken

# Create functions with all possible arguments
f_loglik_allargs = let f=f, g=g
function (x, θ...)
dist = f...)
xtilde = g === nothing ? x : g(x)
return loglikelihood(dist, xtilde)
end
# combine all arguments
# point `x` is not differentiable if the distribution is discrete
args = if Distributions.value_support(typeof(dist)) === Continuous
(x, θ...)
else
θ
end
f_logpdf_allargs = let f=f, g=g
function (x, θ...)
dist = f...)
xtilde = g === nothing ? x : g(x)
if dist isa UnivariateDistribution && xtilde isa AbstractArray
return sum(logpdf.(dist, xtilde))
else
return sum(logpdf(dist, xtilde))

# Create functions with all arguments
if Distributions.value_support(typeof(dist)) === Continuous
f_loglik_allargs = let f=f, g=g
function (x, θ...)
dist = f...)
xtilde = g === nothing ? x : g(x)
return loglikelihood(dist, xtilde)
end
end
end

# For all combinations of distribution parameters `θ`
for inds in powerset(2:(length(θ) + 1))
# Test only distribution parameters
if !isempty(inds)
xtest = mapreduce(vcat, inds) do i
vectorize(θ[i - 1])
f_logpdf_allargs = let f=f, g=g
function (x, θ...)
dist = f...)
xtilde = g === nothing ? x : g(x)
if dist isa UnivariateDistribution && xtilde isa AbstractArray
return sum(logpdf.(dist, xtilde))
else
return sum(logpdf(dist, xtilde))
end
end
f_loglik_test = let xorig=x, θorig=θ, inds=inds
x -> f_loglik_allargs(unpack(x, inds, xorig, θorig...)...)
end
else
gx = g === nothing ? x : g(x)
f_loglik_allargs = let f=f, gx=gx
function...)
dist = f...)
return loglikelihood(dist, gx)
end
f_logpdf_test = let xorig=x, θorig=θ, inds=inds
x -> f_logpdf_allargs(unpack(x, inds, xorig, θorig...)...)
end
f_logpdf_allargs = let f=f, gx=gx
function...)
dist = f...)
return if dist isa UnivariateDistribution && gx isa AbstractArray
sum(logpdf.(dist, gx))
else
sum(logpdf(dist, gx))
end
end
end
end

@test f_loglik_test(xtest) f_logpdf_test(xtest)

test_ad(f_loglik_test, xtest, broken; kwargs...)
test_ad(f_logpdf_test, xtest, broken; kwargs...)
# short cut: since Zygote does not use special number types with
# different dispatches etc., it is suffiient to just test derivatives of
# all differentiable arguments at once
if GROUP === "All" || GROUP === "Zygote"
@test f_loglik_allargs(args...) f_logpdf_allargs(args...)

# Zygote has type inference problems so we don't check it
try
for f in (f_loglik_allargs, f_logpdf_allargs)
test_rrule(
Zygote.ZygoteRuleConfig(), f NoTangent(), args...;
rrule_f=rrule_via_ad, check_inferred=false, kwargs...
)
end
catch
:Zygote in test_broken || rethrow()
end
end

# early exit
GROUP !== "Zygote" || return

# Test derivative with respect to location `x` as well
# if the distribution is continuous
if Distributions.value_support(typeof(dist)) === Continuous
xtest = isempty(inds) ? vectorize(x) : vcat(vectorize(x), xtest)
push!(inds, 1)
f_loglik_test = let xorig=x, θorig=θ, inds=inds
x -> f_loglik_allargs(unpack(x, inds, xorig, θorig...)...)
# For all combinations of arguments
for inds in powerset(1:length(args))
if !isempty(inds)
argstest = mapreduce(vcat, inds) do i
vectorize(args[i])
end
f_loglik_test = let args=args, inds=inds
x -> f_loglik_allargs(unpack(x, inds, args...)...)
end
f_logpdf_test = let xorig=x, θorig=θ, inds=inds
x -> f_logpdf_allargs(unpack(x, inds, xorig, θorig...)...)
f_logpdf_test = let args=args, inds=inds
x -> f_logpdf_allargs(unpack(x, inds, args...)...)
end

@test f_loglik_test(xtest) f_logpdf_test(xtest)
@test f_loglik_test(argstest) f_logpdf_test(argstest)

test_ad(f_loglik_test, xtest, broken; kwargs...)
test_ad(f_logpdf_test, xtest, broken; kwargs...)
test_ad(f_loglik_test, argstest, broken; kwargs...)
test_ad(f_logpdf_test, argstest, broken; kwargs...)
end
end
end
Expand All @@ -304,18 +349,6 @@ function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6)
end
end

if GROUP == "All" || GROUP == "Zygote"
if :Zygote in broken
@test_broken zygote_isapprox(
Zygote.gradient(f, x)[1], finitediff; rtol=rtol, atol=atol,
)
else
@test zygote_isapprox(
Zygote.gradient(f, x)[1], finitediff; rtol=rtol, atol=atol,
)
end
end

if GROUP == "All" || GROUP == "ReverseDiff"
if :ReverseDiff in broken
@test_broken ReverseDiff.gradient(f, x) finitediff rtol=rtol atol=atol
Expand All @@ -326,9 +359,3 @@ function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6)

return
end

# Handle Zygote's `nothing`
zygote_isapprox(x, expected; kwargs...) = isapprox(x, expected; kwargs...)
function zygote_isapprox(::Nothing, expected; kwargs...)
return isapprox(zero(expected), expected; kwargs...)
end

0 comments on commit 12b9d34

Please sign in to comment.