Skip to content

Commit

Permalink
simple rule for mapfoldl
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Jul 19, 2022
1 parent 4c6433a commit 9af7a64
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1.12"
ChainRulesCore = "1.15.3"
ChainRulesTestUtils = "1.5"
Compat = "3.42.0, 4"
FiniteDifferences = "0.12.20"
Expand Down
19 changes: 19 additions & 0 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,25 @@ end
end

#####
##### `mapfoldl(f, g, ::Tuple)`
#####

# For tuples there should be no harm in handling `map` first.
# This will also catch `mapreduce`.

function rrule(
cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), f::F, op::G, init, x::Tuple;
) where {F,G}
y, backmap = rrule(cfg, map, f, x)
z, backred = rrule(cfg, Base.mapfoldl_impl, identity, op, init, y)
function mapfoldl_pullback_tuple(dz)
_, _, dop, dinit, dy = backred(dz)
_, df, dx = backmap(dy)
return (NoTangent(), df, dop, dinit, dx)
end
return z, mapfoldl_pullback_tuple
end

#####
##### `foldl(f, ::Tuple)`
#####
Expand Down
5 changes: 5 additions & 0 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ const _INIT = Base._InitialValue()
test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5)))
test_rrule(mapfoldl_impl, identity, *, 1+rand(), Tuple(rand(ComplexF64, 5)))
end
@testset "mapfoldl(f, g, ::Tuple)" begin
test_rrule(mapfoldl_impl, cbrt, /, _INIT, Tuple(1 .+ rand(5)), check_inferred=false)
test_rrule(mapfoldl_impl, abs2, *, 1+rand(), Tuple(rand(ComplexF64, 5)), check_inferred=false)
# TODO make the `map(f, ::Tuple)` rule infer better!
end
end

@testset "Accumulations" begin
Expand Down

0 comments on commit 9af7a64

Please sign in to comment.