Skip to content

Commit

Permalink
tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Aug 30, 2022
1 parent 3c185ac commit 15858bf
Showing 1 changed file with 31 additions and 65 deletions.
96 changes: 31 additions & 65 deletions src/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ function rrule(
y = first(last(hobbits))
project = ProjectTo(x)
function foldl_pullback_tuple(dy)
trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
trio = accumulate(reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
ds, da, db = back(dc)
# Don't need to store every `da`, need one for the next iteration + the last.
end
Expand Down Expand Up @@ -501,78 +501,43 @@ end

# The implementation was originally for both tuples and arrays, although using accumulate
# to carry intermediate results along creates arrays of tuples which could be avoided.
# Using a loop can be a few times faster, this should be replaced.
# Note also that it does not return a gradient for `init`.
# Using a loop can be a few times faster, this should be replaced:
# https://github.com/FluxML/Zygote.jl/issues/644#issuecomment-628762305

# Note also that it does not return a gradient for `init`, now marked `@not_implemented`.

function rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple};
config::RuleConfig{>:HasReverseMode}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::G, init, x::Union{AbstractArray, Tuple};
) where {G}
list, start = if init === _INIT
_drop1(x), first(x)
start, list = if init === Base._InitialValue()
Iterators.peel(x)
else
# Case with init keyword is simpler to understand first!
_reshape1(x, :), init # (vec is for Julia 1.0, accumulate is fussy)
init, x
end
hobbits = accumulate(list; init=(start, nothing)) do (a,_), b
# Here `a` is what we would normally cary forward, and `_` ignores
# the previous iteration's pullback function (needed later),
# while `b` is the fresh input from `list` as usual.
c, back = rrule_via_ad(config, op, a, b) # LHS is just documentation here!
# We don't really need to store every `c`, last one is `foldl` output.
# (The name, BTW, is because "there and back again" is the subtitle of Tolkien's book.)
hobbits = accumulate(list; init=(start, nothing)) do (a, _), b
c, back = rrule_via_ad(config, op, a, b)
end
y = first(last(hobbits))
axe = axes(x)
project = ProjectTo(x)
function unfoldl(dy)
trio = accumulate(_reverse1(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
ds, da, db = back(dc)
# Don't need to store every `da`, need one for the next iteration + maybe last
end
dop = sum(first, trio)
dx = map(last, _reverse1(trio))
if init === _INIT
# `hobbits` is one short
dx = map(last, Iterators.reverse(trio))
if init === Base._InitialValue() # `hobbits` is one short
dx = _vcat1(trio[end][2], dx)
end
d_init = @not_implemented "gradient for foldl does not at present include init, sorry"
return (NoTangent(), NoTangent(), dop, d_init, project(_reshape1(dx, axe)))
return (NoTangent(), NoTangent(), dop, d_init, project(reshape(dx, axe)))
end
return y, unfoldl
end


#####
##### Iterator-or-Tuple functions
#####

# This zoo of underscore functions helps `foldl` & `accumulate` handle both tuples and arrays,
# and also provides some alternatives for versions of Julia where iterators weren't supported.
# Inspired by `Base._reverse`, used in defn of `foldr`.

# To support 2nd derivatives, some may need their own gradient rules. And _drop1 should perhaps
# be replaced by _peel1 like Iterators.peel

_reverse1(x) = Iterators.reverse(x)
_drop1(x) = Iterators.drop(x, 1)
_zip2(x, y) = zip(x, y) # for `accumulate`, below

_reverse1(x::Tuple) = reverse(x)
_drop1(x::Tuple) = Base.tail(x)
_zip2(x::Tuple{Vararg{Any,N}}, y::Tuple{Vararg{Any,N}}) where N = ntuple(i -> (x[i],y[i]), N)

const _INIT = Base._InitialValue()

_vcat1(x, ys::AbstractVector) = vcat(x, ys)
_vcat1(x::AbstractArray, ys::AbstractVector) = vcat([x], ys)
_vcat1(x, ys::Tuple) = (x, ys...)

_reshape1(x::AbstractArray, axe) = reshape(x, axe)
_reshape1(x::Tuple, axe) = x

_no_tuple_tangent(dx::Tangent) = ChainRulesCore.backing(dx)
_no_tuple_tangent(dx) = dx


#####
##### `accumulate`
Expand All @@ -584,13 +549,18 @@ _no_tuple_tangent(dx) = dx
# Move it down to: `_accumulate!(op, B, A::AbstractVector, dims::Nothing, init::Nothing)`

function rrule(
config::RuleConfig{>:HasReverseMode}, ::typeof(Base._accumulate!), op::G, y, x::AbstractVector, dims::Nothing, init,
config::RuleConfig{>:HasReverseMode},
::typeof(Base._accumulate!),
op::G, y::AbstractVector,
x::AbstractVector,
dims::Nothing,
init,
) where {G}

list, start = if init === nothing
_drop1(x), first(x)
start, list = if init === nothing
Iterators.peel(x)
else
x, something(init)
something(init), x
end
hobbits = accumulate(list; init = (start, nothing)) do (a, _), b
c, back = rrule_via_ad(config, op, a, b)
Expand All @@ -607,28 +577,24 @@ function rrule(
axe = axes(x)
project = ProjectTo(x)
function decumulate(dy)
dy_plain = _no_tuple_tangent(unthunk(dy))
rev_list = if init === nothing
# Here we rely on `zip` to stop early. Begin explicit with _reverse1(_drop1(...))
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
_zip2(_reverse1(hobbits), _reverse1(dy_plain))
else
_zip2(_reverse1(hobbits), _reverse1(dy_plain))
end
dy_plain = unthunk(dy)
rev_list = zip(Iterators.reverse(hobbits), Iterators.reverse(dy_plain))
# Here we rely on `zip` to stop early when init === nothing. Begin explicit with Iterators.reverse(Iterators.drop(..., 1))
# gets "no method matching iterate(::Base.Iterators.Reverse{Base.Iterators.Drop{Array{"
trio = accumulate(rev_list; init=(0, ZeroTangent(), 0)) do (_, dc, _), ((_, back), dz)
ds, da, db = back(dc + dz)
# Don't need to store every 'da', but need for next iteration, and the last one.
end
dop = sum(first, trio)
dx = map(last, _reverse1(trio))
dx = map(last, Iterators.reverse(trio))
if init == nothing
# `hobbits` is one short, and the first one is weird
dx = _vcat1(trio[end][2] + dy_plain[1], dx)
end
dy = @not_implemented "no gradient for `B` in `accumulate!(f, B, A)`, the rule intends to support `accumulate` only"
d_init_not = @not_implemented "gradient for accumulate does not at present include init, sorry"
d_init = init === nothing ? NoTangent() : Tangent{typeof(init)}(; value = d_init_not)
return (NoTangent(), dop, dy, project(_reshape1(dx, axe)), NoTangent(), d_init)
return (NoTangent(), dop, dy, project(reshape(dx, axe)), NoTangent(), d_init)
end
return _reshape1(y, axe), decumulate
return reshape(y, axe), decumulate
end

0 comments on commit 15858bf

Please sign in to comment.