From 633387997209a8288b20f5621d409b5ccbd90ea9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 29 Aug 2022 20:33:44 -0400 Subject: [PATCH] avoid ambiguities --- src/rulesets/Base/mapreduce.jl | 15 ++++----------- test/rulesets/Base/mapreduce.jl | 1 - 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index 83ed130c4..80bca8cdf 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -438,11 +438,6 @@ function rrule( return z, mapfoldl_pullback_tuple end -function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), f, op, init, x::Tuple{}) - foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent()) - return init, foldl_pullback_empty -end - ##### ##### `foldl(f, ::Tuple)` ##### @@ -491,6 +486,10 @@ function rrule( init, x::Tuple; ) where {G} + # Trivial case handled here to avoid ambiguities (and necc. because of Base.tail below) + foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent()) + isempty(x) && return init, foldl_pullback_empty + # Treat `init` by simply appending it to the `x`: y, back = rrule(config, Base.mapfoldl_impl, identity, op, Base._InitialValue(), (init, x...)) project_x = ProjectTo(x) @@ -502,12 +501,6 @@ function rrule( return y, foldl_pullback_tuple_init end -# Base.tail doesn't work on (), trivial case: -function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), ::typeof(identity), op, init, x::Tuple{}) - foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent()) - return init, foldl_pullback_empty -end - ##### ##### `foldl(f, ::Array)` ##### diff --git a/test/rulesets/Base/mapreduce.jl b/test/rulesets/Base/mapreduce.jl index bd3031ed5..494639f80 100644 --- a/test/rulesets/Base/mapreduce.jl +++ b/test/rulesets/Base/mapreduce.jl @@ -306,7 +306,6 @@ const _INIT = Base._InitialValue() # Trivial case test_rrule(mapfoldl_impl, identity, /, 2pi, ()) - test_rrule(mapfoldl_impl, sqrt, /, 2pi, ()) end @testset "mapfoldl(f, g, ::Tuple)" begin test_rrule(mapfoldl_impl, cbrt, /, _INIT, Tuple(1 .+ rand(5)), check_inferred=false)