Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Aug 30, 2022
1 parent 15858bf commit c0a1aeb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,13 @@ end
##### `mapfoldl(f, g, ::Tuple)`
#####

using Base: mapfoldl_impl

# For tuples there should be no harm in handling `map` first.
# This will also catch `mapreduce`.

function rrule(
cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), f::F, op::G, init, x::Tuple;
cfg::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), f::F, op::G, init, x::Tuple;
) where {F,G}
y, backmap = rrule(cfg, map, f, x)
z, backred = rrule(cfg, Base.mapfoldl_impl, identity, op, init, y)
Expand All @@ -436,6 +438,11 @@ function rrule(
return z, mapfoldl_pullback_tuple
end

function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), f, op, init, x::Tuple{})
foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent())
return init, foldl_pullback_empty
end

#####
##### `foldl(f, ::Tuple)`
#####
Expand Down Expand Up @@ -495,6 +502,12 @@ function rrule(
return y, foldl_pullback_tuple_init
end

# Base.tail doesn't work on (), trivial case:
function rrule(::RuleConfig{>:HasReverseMode}, ::typeof(mapfoldl_impl), ::typeof(identity), op, init, x::Tuple{})
foldl_pullback_empty(dy) = (NoTangent(), NoTangent(), NoTangent(), dy, NoTangent())
return init, foldl_pullback_empty
end

#####
##### `foldl(f, ::Array)`
#####
Expand Down
4 changes: 4 additions & 0 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ const _INIT = Base._InitialValue()
# Finite differencing
test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5)))
test_rrule(mapfoldl_impl, identity, *, 1+rand(), Tuple(rand(ComplexF64, 5)))

# Trivial case
test_rrule(mapfoldl_impl, identity, /, 2pi, ())
test_rrule(mapfoldl_impl, sqrt, /, 2pi, ())
end
@testset "mapfoldl(f, g, ::Tuple)" begin
test_rrule(mapfoldl_impl, cbrt, /, _INIT, Tuple(1 .+ rand(5)), check_inferred=false)
Expand Down

0 comments on commit c0a1aeb

Please sign in to comment.