Skip to content

Commit

Permalink
separate rule for foldl(::Tuple)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Jul 19, 2022
1 parent 87b4ea4 commit fff84b5
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 12 deletions.
71 changes: 63 additions & 8 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -417,17 +417,73 @@ end
end

#####
##### `foldl`
#####
##### `foldl(f, ::Tuple)`
#####

# `foldl` guarantees to execute `f` in order, left to right. So it makes sense even when
# this `f` is stateful, in which case the gradient must be calculated in the reverse order.
# this `f` is stateful, in which case the gradient must be calculated in the reverse order.

# The rule is attached to `Base.mapfoldl_impl` because this gets the `init` keyword as an argument,
# which is handled below. For tuples, `reduce` also comes here.

function rrule(
config::RuleConfig{>:HasReverseMode},
::typeof(Base.mapfoldl_impl),
::typeof(identity),
op::G,
init::Base._InitialValue,
x::Tuple;
) where {G}
hobbits = accumulate(Base.tail(x); init=(first(x), nothing)) do (a, _), b
# Here `a` is what we would normally cary forward, and `_` ignores
# the previous iteration's pullback function (needed later),
# while `b` is the fresh input from `list` as usual.
c, back = rrule_via_ad(config, op, a, b)
# We don't really need to store every `c`, last one is `foldl` output.
# (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.)
end
y = first(last(hobbits))
project = ProjectTo(x)
function foldl_pullback_tuple(dy)
trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
ds, da, db = back(dc)
# Don't need to store every `da`, need one for the next iteration + the last.
end
dop = sum(first, trio)
dx = (trio[end][2], reverse(map(last, trio))...)
return (NoTangent(), NoTangent(), ProjectTo(op)(dop), NoTangent(), project(dx))
end
return y, foldl_pullback_tuple
end

function rrule(
config::RuleConfig{>:HasReverseMode},
::typeof(Base.mapfoldl_impl),
::typeof(identity),
op::G,
init,
x::Tuple;
) where {G}
# Treat `init` by simply appending it to the `x`:
y, back = rrule(config, Base.mapfoldl_impl, identity, op, Base._InitialValue(), (init, x...))
project_x = ProjectTo(x)
project_in = ProjectTo(init)
function foldl_pullback_tuple_init(dy)
_, _, dop, _, dxplus = back(dy)
return (NoTangent(), NoTangent(), dop, project_in(first(dxplus)), project_x(Base.tail(dxplus)))
end
return y, foldl_pullback_tuple_init
end

# The implementation aims to be efficient for both tuples and arrays, although using accumulate
# to carry intermediate results along creates arrays of tuples which could be avoided; using a
# loop can be a few times faster. Note also that it does not return a gradient for `init`.
#####
##### `foldl(f, ::Array)`
#####

# Maybe that's a problem. Let's move the rule to `mapfoldr_impl(f, op, init, itr)`, where it's easier?
# The implementation was originally for both tuples and arrays, although using accumulate
# to carry intermediate results along creates arrays of tuples which could be avoided.
# Using a loop can be a few times faster, this should be replaced.
# Note also that it does not return a gradient for `init`.

function rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple};
Expand Down Expand Up @@ -486,8 +542,7 @@ _reverse1(x::Tuple) = reverse(x)
_drop1(x::Tuple) = Base.tail(x)
_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N)

# struct _InitialValue end # Old versions don't have `Base._InitialValue`
const _INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()
const _INIT = Base._InitialValue()

_vcat1(x, ys::AbstractVector) = vcat(x, ys)
_vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys)
Expand Down
30 changes: 26 additions & 4 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ struct SumRuleConfig <: RuleConfig{Union{HasReverseMode}} end

const CFG = ChainRulesTestUtils.ADviaRuleConfig()

using Base: mapfoldl_impl, _accumulate! # for foldl & accumulate rules
const _INIT = Base._InitialValue()

@testset "Reductions" begin
@testset "sum(::Tuple)" begin
test_frule(sum, Tuple(rand(5)))
Expand Down Expand Up @@ -216,8 +219,6 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
@testset "foldl(f, ::Array)" begin
# `foldl(op, itr; init)` goes to `mapfoldr_impl(identity, op, init, itr)`. The rule is
# now attached there, as this is the simplest way to handle `init` keyword.
@eval using Base: mapfoldl_impl
_INIT = VERSION >= v"1.5" ? Base._InitialValue() : NamedTuple()

# Simple
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, [1, 2, 3])
Expand Down Expand Up @@ -268,18 +269,39 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
test_rrule(mapfoldl_impl, identity, max, 999, rand(3))
end
@testset "foldl(f, ::Tuple)" begin
y1, b1 = rrule(CFG, foldl, *, (1,2,3); init=1)
y1, b1 = rrule(CFG, mapfoldl_impl, identity, *, 1, (1,2,3))
@test y1 == 6
@test b1(7)[5] == Tangent{NTuple{3,Int}}(42, 21, 14)

y2, b2 = rrule(CFG, mapfoldl_impl, identity, *, _INIT, (1, 2, 0, 4))
@test y2 == 0
@test b2(8)[5] == Tangent{NTuple{4,Int}}(0, 0, 64, 0)

# Test execution order
c5 = Counter()
y5, b5 = rrule(CFG, mapfoldl_impl, identity, c5, _INIT, (5, 7, 11))
@test c5 == Counter(2)
@test y5 == ((5 + 7)*1 + 11)*2 == foldl(Counter(), (5, 7, 11))
@test collect(b5(1)[5]) == [12*32, 12*42, 22]
@test c5 == Counter(42)

c6 = Counter()
y6, b6 = rrule(CFG, mapfoldl_impl, identity, c6, 3, (5, 7, 11))
@test c6 == Counter(3)
@test y6 == (((3 + 5)*1 + 7)*2 + 11)*3 == foldl(Counter(), (5, 7, 11), init=3)
@test collect(b6(1)[5]) == [63*33*13, 43*13, 23]
@test c6 == Counter(63)

# Test gradient of function
y7, b7 = rrule(CFG, mapfoldl_impl, identity, Multiplier(3), _INIT, (5, 7, 11))
@test y7 == foldl((x,y)->x*y*3, (5, 7, 11))
b7_1 = b7(1)
@test b7_1[3] == Tangent{Multiplier{Int}}(x = 2310,)
@test collect(b7_1[5]) == [693, 495, 315]

# Finite differencing
test_rrule(mapfoldl_impl, identity, /, _INIT, Tuple(1 .+ rand(5)))
test_rrule(mapfoldl_impl, identity, *, _INIT, Tuple(rand(ComplexF64, 5)))
test_rrule(mapfoldl_impl, identity, *, 1+rand(), Tuple(rand(ComplexF64, 5)))
end
end

Expand Down

0 comments on commit fff84b5

Please sign in to comment.