From 596ea66b2a5fedb0433cdf56bac5f39d8932cb8c Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Fri, 24 Nov 2023 00:08:30 +0000 Subject: [PATCH 1/5] added helpful failure massages for tests --- src/testers.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/testers.jl b/src/testers.jl index 7725a67..c750a8f 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -213,7 +213,7 @@ function test_rrule( res === nothing && throw(MethodError(rrule_f, Tuple{Core.Typeof.(primals)...})) y_ad, pullback = res y = call(primals...) - test_approx(y_ad, y; isapprox_kwargs...) # make sure primal is correct + test_approx(y_ad, y, "Failed primal value check"; isapprox_kwargs...) # make sure primal is correct ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent @@ -231,7 +231,8 @@ function test_rrule( # Correctness testing via finite differencing. is_ignored = isa.(accum_cotangents, NoTangent) fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored) - foreach(accum_cotangents, ad_cotangents, fd_cotangents) do args... + msgs = ntuple(i->"cotangent for input $i, $(summary(fd_cotangents[i]))", length(fd_cotangents)) + foreach(accum_cotangents, ad_cotangents, fd_cotangents, msgs) do args... _test_cotangent(args...; check_inferred=check_inferred, isapprox_kwargs...) end @@ -298,14 +299,15 @@ If `accum_cotangent` is `NoTangent()`, i.e., the argument was marked as non-diff function _test_cotangent( accum_cotangent, ad_cotangent, - fd_cotangent; + fd_cotangent, + msg=""; check_inferred=true, kwargs..., ) ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent) # The main test of the actual derivative being correct: - test_approx(ad_cotangent, fd_cotangent; kwargs...) + test_approx(ad_cotangent, fd_cotangent, msg; kwargs...) _test_add!!_behaviour(accum_cotangent, ad_cotangent; kwargs...) end From 0cf721a12001c6462b96866a00888584544caccb Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Mon, 27 Nov 2023 17:28:40 +0000 Subject: [PATCH 2/5] update doctests and fix matching on timing --- docs/make.jl | 2 +- docs/src/index.md | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 06827ad..0bc7152 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -10,8 +10,8 @@ makedocs(; "ChainRulesTestUtils" => "index.md", "API" => "api.md", ], - strict=true, checkdocs=:exports, + # doctest=:fix ) const repo = "github.com/JuliaDiff/ChainRulesTestUtils.jl.git" diff --git a/docs/src/index.md b/docs/src/index.md index 0d1355c..6d30f47 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,3 +1,6 @@ +```@meta +DocTestFilters = r"[0-9\.]+s" +``` # ChainRulesTestUtils [![CI](https://github.com/JuliaDiff/ChainRulesTestUtils.jl/workflows/CI/badge.svg?branch=main)](https://github.com/JuliaDiff/ChainRulesTestUtils.jl/actions?query=workflow%3ACI) @@ -65,8 +68,8 @@ Keep this in mind when testing discontinuous rules for functions like [ReLU](htt julia> using ChainRulesTestUtils; julia> test_frule(two2three, 3.33, -7.77); -Test Summary: | Pass Total -test_frule: two2three on Float64,Float64 | 6 6 +Test Summary: | Pass Total Time +test_frule: two2three on Float64,Float64 | 6 6 2.4s ``` @@ -77,8 +80,8 @@ The call will test the `rrule` for function `f` at the point `x`, and similarly ```jldoctest ex julia> test_rrule(two2three, 3.33, -7.77); -Test Summary: | Pass Total -test_rrule: two2three on Float64,Float64 | 9 9 +Test Summary: | Pass Total Time +test_rrule: two2three on Float64,Float64 | 10 10 0.9s ``` @@ -105,13 +108,13 @@ with the `frule` and `rrule` defined with the help of `@scalar_rule` macro call. ```jldoctest ex julia> test_scalar(relu, 0.5); -Test Summary: | Pass Total -test_scalar: relu at 0.5 | 11 11 +Test Summary: | Pass Total Time +test_scalar: relu at 0.5 | 12 12 1.0s julia> test_scalar(relu, -0.5); -Test Summary: | Pass Total -test_scalar: relu at -0.5 | 11 11 +Test Summary: | Pass Total Time +test_scalar: relu at -0.5 | 12 12 0.0s ``` From c36d81c872c6ca8a2e12e7d9751ddabfdfd2ca3e Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Mon, 27 Nov 2023 18:23:05 +0000 Subject: [PATCH 3/5] include debugging example --- docs/src/index.md | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 6d30f47..77f6c8e 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,5 +1,5 @@ ```@meta -DocTestFilters = r"[0-9\.]+s" +DocTestFilters = [r"[0-9\.]+s",r"isapprox\(.*\)"] ``` # ChainRulesTestUtils @@ -41,12 +41,12 @@ end # output ``` -and `rrule` +and `rrule` which contains a mistake in the first cotangent ```jldoctest ex function ChainRulesCore.rrule(::typeof(two2three), x1, x2) y = two2three(x1, x2) function two2three_pullback(Ȳ) - return (NoTangent(), 2.0*Ȳ[2], 3.0*Ȳ[3]) + return (NoTangent(), 2.1*Ȳ[2], 3.0*Ȳ[3]) end return y, two2three_pullback end @@ -69,22 +69,27 @@ julia> using ChainRulesTestUtils; julia> test_frule(two2three, 3.33, -7.77); Test Summary: | Pass Total Time -test_frule: two2three on Float64,Float64 | 6 6 2.4s +test_frule: two2three on Float64,Float64 | 6 6 2.7s ``` ### Testing the `rrule` -[`test_rrule`](@ref) takes in the function `f`, and primal inputsr `x`. +[`test_rrule`](@ref) takes in the function `f`, and primal inputs `x`. The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain. ```jldoctest ex julia> test_rrule(two2three, 3.33, -7.77); -Test Summary: | Pass Total Time -test_rrule: two2three on Float64,Float64 | 10 10 0.9s - +test_rrule: two2three on Float64,Float64: Test Failed at /home/lior/.julia/dev/ChainRulesTestUtils/src/check_result.jl:24 + Expression: isapprox(actual, expected; kwargs...) + Problem: cotangent for input 2, Float64 + Evaluated: isapprox(-4.032, -3.840000000001641; rtol = 1.0e-9, atol = 1.0e-9) +[...] ``` +The output of the test indicates to us the cause of the failure under "Problem:" with the expected (`rrule` derived) and actual finite difference results. +The Problem lies with the cotangent corresponding to input 2 of `rrule`, which is the first cotangent as expected. + ## Scalar example For functions with a single argument and a single output, such as e.g. ReLU, @@ -109,7 +114,7 @@ call. ```jldoctest ex julia> test_scalar(relu, 0.5); Test Summary: | Pass Total Time -test_scalar: relu at 0.5 | 12 12 1.0s +test_scalar: relu at 0.5 | 12 12 1.2s julia> test_scalar(relu, -0.5); From cb53974bb75ba399ca22b5501283f7b73e746311 Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Thu, 14 Dec 2023 20:03:41 +0200 Subject: [PATCH 4/5] fix missing parameter --- src/testers.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/testers.jl b/src/testers.jl index c750a8f..be9607b 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -312,10 +312,10 @@ function _test_cotangent( end # we marked the argument as non-differentiable -function _test_cotangent(::NoTangent, ad_cotangent, ::NoTangent; kwargs...) +function _test_cotangent(::NoTangent, ad_cotangent, ::NoTangent, msg=""; kwargs...) @test ad_cotangent isa NoTangent end -function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent; kwargs...) +function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent, msg=""; kwargs...) error( "The pullback in the rrule should use NoTangent()" * " rather than ZeroTangent() for non-perturbable arguments." @@ -324,7 +324,8 @@ end function _test_cotangent( ::NoTangent, ad_cotangent::ChainRulesCore.NotImplemented, - ::NoTangent; + ::NoTangent, + msg=""; kwargs..., ) # this situation can occur if a cotangent is not implemented and @@ -334,6 +335,6 @@ function _test_cotangent( # https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217 @test_broken ad_cotangent isa NoTangent end -function _test_cotangent(::NoTangent, ad_cotangent, fd_cotangent; kwargs...) +function _test_cotangent(::NoTangent, ad_cotangent, fd_cotangent, msg=""; kwargs...) error("cotangent obtained with finite differencing has to be NoTangent()") end From f7388ec5a7a8df3513ce9acdf1be4950567e14d7 Mon Sep 17 00:00:00 2001 From: Lior Blech Date: Thu, 14 Dec 2023 22:38:28 +0200 Subject: [PATCH 5/5] reflect changes to tester in docs --- src/testers.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/testers.jl b/src/testers.jl index be9607b..67dbbf3 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -283,7 +283,7 @@ function _is_inferrable(f, args...; kwargs...) end """ - _test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent; kwargs...) + _test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent[, msg]; kwargs...) Check if the cotangent `ad_cotangent` from `rrule` is consistent with `accum_tangent` and approximately equal to the cotangent `fd_cotangent` obtained with finite differencing. @@ -291,6 +291,8 @@ approximately equal to the cotangent `fd_cotangent` obtained with finite differe If `accum_cotangent` is `NoTangent()`, i.e., the argument was marked as non-differentiable, `ad_cotangent` and `fd_cotangent` should be `NoTangent()` as well. +If a msg string is given, it is emmited on test failure. + # Keyword arguments - If `check_inferred=true` (the default) and `ad_cotangent` is a thunk, then it is checked if its content can be inferred.