From b59aa7b92a47bda972a6f3f7d6290440ba7f8ada Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 13 Jul 2020 09:42:01 -0700 Subject: [PATCH] Check that if only one adjoint exists, it is not thunked (#53) * Test no thunking when unnecessary * Remove test that now fails by design * Add note that InplaceableThunk is allowed * Increment version number --- Project.toml | 2 +- src/testers.jl | 10 +++++++++- test/testers.jl | 5 +---- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index d55681a9..5d6d839c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesTestUtils" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "0.4.3" +version = "0.5.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/testers.jl b/src/testers.jl index f67c941d..a75e4632 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -192,8 +192,9 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm x̄s_ad = ∂s[2:end] @test ∂self === NO_FIELDS # No internal fields + x̄s_is_dne = x̄s .== nothing # Correctness testing via finite differencing. - x̄s_fd = _make_fdm_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s .== nothing) + x̄s_fd = _make_fdm_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne) for (x̄_ad, x̄_fd) in zip(x̄s_ad, x̄s_fd) if x̄_fd === nothing # The way we've structured the above, this tests the propagator is returning a DoesNotExist @@ -202,4 +203,11 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) end end + + if count(!, x̄s_is_dne) == 1 + # for functions with pullbacks that only produce a single non-DNE adjoint, that + # single adjoint should not be `Thunk`ed. InplaceableThunk is fine. + i = findfirst(!, x̄s_is_dne) + @test !(isa(x̄s_ad[i], Thunk)) + end end diff --git a/test/testers.jl b/test/testers.jl index 87645a31..72375ff2 100644 --- a/test/testers.jl +++ b/test/testers.jl @@ -40,10 +40,7 @@ primalapprox(x) = x # define rrule using ChainRulesCore's v0.9.0 convention, conjugating the derivative # in the rrule function ChainRulesCore.rrule(::typeof(sinconj), x) - # usually we would not thunk for a single output, because it will of course be - # used, but we do here to ensure that test_scalar works even if a scalar rrule - # thunks - sinconj_pullback(ΔΩ) = (NO_FIELDS, @thunk(conj(cos(x)) * ΔΩ)) + sinconj_pullback(ΔΩ) = (NO_FIELDS, conj(cos(x)) * ΔΩ) return sin(x), sinconj_pullback end