Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: rrule dot test #208

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 49 additions & 22 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ function test_rrule(
config::RuleConfig,
f,
args...;
output_tangent=Auto(),
output_cotangent=Auto(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I don't forget if someone does come back to this.
This can't be done.
It is a breaking change.

In general we have decide to use tangent as the word to refer to both tangents and cotangents

check_thunked_output_tangent=true,
fdm=_fdm,
rrule_f=ChainRulesCore.rrule,
Expand All @@ -188,6 +188,8 @@ function test_rrule(
# and define helper closure over fkwargs
call(f, xs...) = f(xs...; fkwargs...)

call_on_copy(f, xs...) = deepcopy(f)(deepcopy(xs)...; deepcopy(fkwargs)...)

@testset "test_rrule: $f on $(_string_typeof(args))" begin

# Check correctness of evaluation.
Expand Down Expand Up @@ -219,30 +221,55 @@ function test_rrule(

# Correctness testing via finite differencing.
is_ignored = isa.(accum_cotangents, NoTangent)
fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored)
fd_output_tangent = _make_jvp_call(
fdm, call_on_copy, y, primals, tangents, is_ignored,
)

for (accum_cotangent, ad_cotangent, fd_cotangent) in zip(
accum_cotangents, ad_cotangents, fd_cotangents
# Current implementation assumes that is_ignored is always false. Easy fix though.
# More consistent names for variables in this context.
inputs = primals
inputs_tangents = accum_cotangents
inputs_cotangents = ad_cotangents
output = y
output_tangent = fd_output_tangent
output_cotangent = ȳ
@test isapprox(
dot(output_cotangent, output_tangent),
dot(inputs_cotangents, inputs_tangents),
)
if accum_cotangent isa NoTangent # then we marked this argument as not differentiable
@assert fd_cotangent === NoTangent()
ad_cotangent isa ZeroTangent && error(
"The pullback in the rrule should use NoTangent()" *
" rather than ZeroTangent() for non-perturbable arguments.",
)
@test ad_cotangent isa NoTangent # we said it wasn't differentiable.
else
ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent)

# The main test of the actual derivative being correct:
test_approx(ad_cotangent, fd_cotangent; isapprox_kwargs...)
_test_add!!_behaviour(accum_cotangent, ad_cotangent; isapprox_kwargs...)
end
end

if check_thunked_output_tangent
test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:")
end
# Alternatively:
# x = primals
# ẋ = accum_cotangents
# x̄ = ad_cotangents
# y = y
# ẏ = fd_output_tangent
# ȳ = ȳ
# @test dot(ȳ, ẏ) ≈ dot(x̄, ẋ)


# for (accum_cotangent, ad_cotangent, fd_cotangent) in zip(
# accum_cotangents, ad_cotangents, fd_cotangents
# )
# if accum_cotangent isa NoTangent # then we marked this argument as not differentiable
# @assert fd_cotangent === NoTangent()
# ad_cotangent isa ZeroTangent && error(
# "The pullback in the rrule should use NoTangent()" *
# " rather than ZeroTangent() for non-perturbable arguments.",
# )
# @test ad_cotangent isa NoTangent # we said it wasn't differentiable.
# else
# ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent)

# # The main test of the actual derivative being correct:
# test_approx(ad_cotangent, fd_cotangent; isapprox_kwargs...)
# _test_add!!_behaviour(accum_cotangent, ad_cotangent; isapprox_kwargs...)
# end
# end

# if check_thunked_output_tangent
# test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:")
# end
end # top-level testset
end

Expand Down