From 4c6433a525cdff4109152f8b9565cdfb66a51878 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 18 Jul 2022 21:03:26 -0400 Subject: [PATCH] fix accumulate tests --- test/rulesets/Base/mapreduce.jl | 45 ++++++++------------------------- 1 file changed, 10 insertions(+), 35 deletions(-) diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index 4be5c87c0..641475b08 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -358,10 +358,9 @@ end end end # cumprod - @testset "accumulate(f, ::Array)" begin + @testset "accumulate(f, ::Vector)" begin # `accumulate(f, A; init)` goes to `_accumulate!(op, B, A, dims::Nothing, init::Nothing)`. # The rule is now attached there, as this is the simplest way to handle `init` keyword. - @eval using Base: _accumulate! # Simple y1, b1 = rrule(CFG, _accumulate!, *, [0, 0, 0, 0], [1, 2, 3, 4], nothing, Some(1)) @@ -371,9 +370,9 @@ end @test b1([1, 1, 1, 1])[6] isa Tangent{Some{Int64}} @test b1([1, 1, 1, 1])[6].value isa ChainRulesCore.NotImplemented - y2, b2 = rrule(CFG, accumulate, /, [1 2; 3 4]) - @test y2 ≈ accumulate(/, [1 2; 3 4]) - @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 + # y2, b2 = rrule(CFG, _accumulate!, /, [0 0; 0 0], [1 2; 3 4], :, nothing) + # @test y2 ≈ accumulate(/, [1 2; 3 4.0]) + # @test b2(ones(2, 2))[3] ≈ [1.5416666 -0.104166664; -0.18055555 -0.010416667] atol=1e-6 # Test execution order c3 = Counter() @@ -403,35 +402,11 @@ end # ForwardDiff.gradient(z -> sum(accumulate((x,y)->x*y*13, z, init=3)), [5,7,11]) |> string # Finite differencing - test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand())) - test_rrule(accumulate, /, 1 .+ rand(3, 4)) - test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) + # test_rrule(accumulate, *, randn(5); fkwargs=(; init=rand())) + test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(rand())) + # test_rrule(accumulate, /, 1 .+ rand(3, 4)) + test_rrule(_accumulate!, /, randn(4) ⊢ NoTangent(), 1 .+ rand(4), nothing, nothing) + # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) + test_rrule(_accumulate!, ^, randn(6) ⊢ NoTangent(), 1 .+ rand(6), nothing, Some(rand())) end - @testset "accumulate(f, ::Tuple)" begin - # Simple - y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) - @test y1 == (1, 2, 6, 24) - @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6)) - - # Finite differencing - test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) - test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) - - test_rrule(_accumulate!, *, randn(5) ⊢ NoTangent(), randn(5), nothing, nothing) - test_rrule(_accumulate!, /, randn(5) ⊢ NoTangent(), randn(5), nothing, Some(1 + rand())) - # if VERSION >= v"1.5" - # test_rrule(accumulate, /, 1 .+ rand(3, 4)) - # test_rrule(accumulate, ^, 1 .+ rand(2, 3); fkwargs=(; init=rand())) - # end - end - # VERSION >= v"1.5" && @testset "accumulate(f, ::Tuple)" begin - # # Simple - # y1, b1 = rrule(CFG, accumulate, *, (1, 2, 3, 4); init=1) - # @test y1 == (1, 2, 6, 24) - # @test b1((1, 1, 1, 1)) == (NoTangent(), NoTangent(), Tangent{NTuple{4,Int}}(33, 16, 10, 6)) - - # # Finite differencing - # test_rrule(accumulate, *, Tuple(randn(5)); fkwargs=(; init=rand())) - # test_rrule(accumulate, /, Tuple(1 .+ rand(5)); check_inferred=false) - # end end