Skip to content

Commit

Permalink
Merge pull request #288 from nomadbl/add_massages
Browse files Browse the repository at this point in the history
added helpful failure massages for tests
  • Loading branch information
oxinabox authored Dec 15, 2023
2 parents 49a0324 + f7388ec commit 64971e7
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 22 deletions.
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ makedocs(;
"ChainRulesTestUtils" => "index.md",
"API" => "api.md",
],
strict=true,
checkdocs=:exports,
# doctest=:fix
)

const repo = "github.com/JuliaDiff/ChainRulesTestUtils.jl.git"
Expand Down
32 changes: 20 additions & 12 deletions docs/src/index.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
```@meta
DocTestFilters = [r"[0-9\.]+s",r"isapprox\(.*\)"]
```
# ChainRulesTestUtils

[![CI](https://github.com/JuliaDiff/ChainRulesTestUtils.jl/workflows/CI/badge.svg?branch=main)](https://github.com/JuliaDiff/ChainRulesTestUtils.jl/actions?query=workflow%3ACI)
Expand Down Expand Up @@ -38,12 +41,12 @@ end
# output
```
and `rrule`
and `rrule` which contains a mistake in the first cotangent
```jldoctest ex
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
y = two2three(x1, x2)
function two2three_pullback(Ȳ)
return (NoTangent(), 2.0*Ȳ[2], 3.0*Ȳ[3])
return (NoTangent(), 2.1*Ȳ[2], 3.0*Ȳ[3])
end
return y, two2three_pullback
end
Expand All @@ -65,23 +68,28 @@ Keep this in mind when testing discontinuous rules for functions like [ReLU](htt
julia> using ChainRulesTestUtils;
julia> test_frule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_frule: two2three on Float64,Float64 | 6 6
Test Summary: | Pass Total Time
test_frule: two2three on Float64,Float64 | 6 6 2.7s
```

### Testing the `rrule`

[`test_rrule`](@ref) takes in the function `f`, and primal inputsr `x`.
[`test_rrule`](@ref) takes in the function `f`, and primal inputs `x`.
The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain.

```jldoctest ex
julia> test_rrule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_rrule: two2three on Float64,Float64 | 9 9
test_rrule: two2three on Float64,Float64: Test Failed at /home/lior/.julia/dev/ChainRulesTestUtils/src/check_result.jl:24
Expression: isapprox(actual, expected; kwargs...)
Problem: cotangent for input 2, Float64
Evaluated: isapprox(-4.032, -3.840000000001641; rtol = 1.0e-9, atol = 1.0e-9)
[...]
```

The output of the test indicates to us the cause of the failure under "Problem:" with the expected (`rrule` derived) and actual finite difference results.
The Problem lies with the cotangent corresponding to input 2 of `rrule`, which is the first cotangent as expected.

## Scalar example

For functions with a single argument and a single output, such as e.g. ReLU,
Expand All @@ -105,13 +113,13 @@ with the `frule` and `rrule` defined with the help of `@scalar_rule` macro
call.
```jldoctest ex
julia> test_scalar(relu, 0.5);
Test Summary: | Pass Total
test_scalar: relu at 0.5 | 11 11
Test Summary: | Pass Total Time
test_scalar: relu at 0.5 | 12 12 1.2s
julia> test_scalar(relu, -0.5);
Test Summary: | Pass Total
test_scalar: relu at -0.5 | 11 11
Test Summary: | Pass Total Time
test_scalar: relu at -0.5 | 12 12 0.0s
```

Expand Down
23 changes: 14 additions & 9 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ function test_rrule(
res === nothing && throw(MethodError(rrule_f, Tuple{Core.Typeof.(primals)...}))
y_ad, pullback = res
y = call(primals...)
test_approx(y_ad, y; isapprox_kwargs...) # make sure primal is correct
test_approx(y_ad, y, "Failed primal value check"; isapprox_kwargs...) # make sure primal is correct

ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent

Expand All @@ -231,7 +231,8 @@ function test_rrule(
# Correctness testing via finite differencing.
is_ignored = isa.(accum_cotangents, NoTangent)
fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored)
foreach(accum_cotangents, ad_cotangents, fd_cotangents) do args...
msgs = ntuple(i->"cotangent for input $i, $(summary(fd_cotangents[i]))", length(fd_cotangents))
foreach(accum_cotangents, ad_cotangents, fd_cotangents, msgs) do args...
_test_cotangent(args...; check_inferred=check_inferred, isapprox_kwargs...)
end

Expand Down Expand Up @@ -282,14 +283,16 @@ function _is_inferrable(f, args...; kwargs...)
end

"""
_test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent; kwargs...)
_test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent[, msg]; kwargs...)
Check if the cotangent `ad_cotangent` from `rrule` is consistent with `accum_tangent` and
approximately equal to the cotangent `fd_cotangent` obtained with finite differencing.
If `accum_cotangent` is `NoTangent()`, i.e., the argument was marked as non-differentiable,
`ad_cotangent` and `fd_cotangent` should be `NoTangent()` as well.
If a msg string is given, it is emmited on test failure.
# Keyword arguments
- If `check_inferred=true` (the default) and `ad_cotangent` is a thunk, then it is checked if
its content can be inferred.
Expand All @@ -298,22 +301,23 @@ If `accum_cotangent` is `NoTangent()`, i.e., the argument was marked as non-diff
function _test_cotangent(
accum_cotangent,
ad_cotangent,
fd_cotangent;
fd_cotangent,
msg="";
check_inferred=true,
kwargs...,
)
ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent)

# The main test of the actual derivative being correct:
test_approx(ad_cotangent, fd_cotangent; kwargs...)
test_approx(ad_cotangent, fd_cotangent, msg; kwargs...)
_test_add!!_behaviour(accum_cotangent, ad_cotangent; kwargs...)
end

# we marked the argument as non-differentiable
function _test_cotangent(::NoTangent, ad_cotangent, ::NoTangent; kwargs...)
function _test_cotangent(::NoTangent, ad_cotangent, ::NoTangent, msg=""; kwargs...)
@test ad_cotangent isa NoTangent
end
function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent; kwargs...)
function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent, msg=""; kwargs...)
error(
"The pullback in the rrule should use NoTangent()" *
" rather than ZeroTangent() for non-perturbable arguments."
Expand All @@ -322,7 +326,8 @@ end
function _test_cotangent(
::NoTangent,
ad_cotangent::ChainRulesCore.NotImplemented,
::NoTangent;
::NoTangent,
msg="";
kwargs...,
)
# this situation can occur if a cotangent is not implemented and
Expand All @@ -332,6 +337,6 @@ function _test_cotangent(
# https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217
@test_broken ad_cotangent isa NoTangent
end
function _test_cotangent(::NoTangent, ad_cotangent, fd_cotangent; kwargs...)
function _test_cotangent(::NoTangent, ad_cotangent, fd_cotangent, msg=""; kwargs...)
error("cotangent obtained with finite differencing has to be NoTangent()")
end

0 comments on commit 64971e7

Please sign in to comment.