Adding non-mutating recur for the new chain interface. #7

Open
wants to merge 4 commits into master
4 changes: 4 additions & 0 deletions src/Fluxperimental.jl
@@ -11,6 +11,10 @@ export shinkansen!

include("chain.jl")


include("recur.jl")


include("compact.jl")

end # module Fluxperimental
13 changes: 9 additions & 4 deletions src/chain.jl
@@ -10,6 +10,11 @@ function apply(chain::Flux.Chain, x)
Flux.Chain(layers), out
end

function apply(chain::Flux.Chain, x::Union{AbstractVector{<:AbstractArray}, Base.Generator})
layers, out = _apply(chain.layers, x)
Flux.Chain(layers), out
end

function _apply(layers::NamedTuple{NMS, TPS}, x) where {NMS, TPS}
layers, out = _apply(Tuple(layers), x)
NamedTuple{NMS}(layers), out
@@ -18,7 +23,7 @@ end
function _scan(layers::AbstractVector, x)
new_layers = typeof(layers)(undef, length(layers))
for (idx, f) in enumerate(layers)
new_layers[idx], x = _apply(f, x)
new_layers[idx], x = apply(f, x)
end
new_layers, x
end
@@ -27,7 +32,7 @@ end
# example pulled from https://github.com/mcabbott/Flux.jl/blob/chain_rrule/src/cuda/cuda.jl
function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig, ::typeof(_scan), layers, x)
duo = accumulate(layers; init=((nothing, x), nothing)) do ((pl, input), _), cur_layer
out, back = ChainRulesCore.rrule_via_ad(cfg, _apply, cur_layer, input)
out, back = ChainRulesCore.rrule_via_ad(cfg, apply, cur_layer, input)
end
outs = map(first, duo)
backs = map(last, duo)
@@ -52,11 +57,11 @@ end
@generated function _apply(layers::Tuple{Vararg{<:Any,N}}, x) where {N}
x_symbols = vcat(:x, [gensym() for _ in 1:N])
l_symbols = [gensym() for _ in 1:N]
calls = [:(($(l_symbols[i]), $(x_symbols[i+1])) = _apply(layers[$i], $(x_symbols[i]))) for i in 1:N]
calls = [:(($(l_symbols[i]), $(x_symbols[i+1])) = apply(layers[$i], $(x_symbols[i]))) for i in 1:N]
push!(calls, :(return tuple($(l_symbols...)), $(x_symbols[end])))
Expr(:block, calls...)
end

_apply(layer, x) = layer, layer(x)
apply(layer, x) = layer, layer(x)
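For reference, and not part of the diff: a minimal usage sketch of the non-mutating chain `apply` defined in this file, with `apply` qualified as in the tests and illustrative layer sizes.

using Flux, Fluxperimental

m = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
x = rand(Float32, 2, 5)

# `apply` rebuilds the chain and returns it alongside the output instead of
# mutating layer state in place; for stateless layers the rebuilt chain is equivalent.
m2, y = Fluxperimental.apply(m, x)
size(y)  # (1, 5)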


97 changes: 97 additions & 0 deletions src/recur.jl
@@ -0,0 +1,97 @@
"""
NM_Recur
Non-mutating Recur. An experimental recur interface for the new chain API.
"""
struct NM_Recur{RET_SEQUENCE, T, S}
cell::T
state::S
function NM_Recur(cell, state; return_sequence::Bool=false)
new{return_sequence, typeof(cell), typeof(state)}(cell, state)
end
function NM_Recur{true}(cell, state)
new{true, typeof(cell), typeof(state)}(cell, state)
end
function NM_Recur{false}(cell, state)
new{false, typeof(cell), typeof(state)}(cell, state)
end
end

function apply(m::NM_Recur{S}, x) where S
state, y = m.cell(m.state, x)
# Rebuild the wrapper with the same return_sequence flag rather than dropping it.
return NM_Recur{S}(m.cell, state), y
end
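As a rough sketch (not part of the diff), a single step through this interface, with an illustrative cell and shapes:

using Flux, Fluxperimental

cell = Flux.RNNCell(2, 3)
r = Fluxperimental.NM_Recur(cell, cell.state0)
x_t = rand(Float32, 2)

# Each step returns a new NM_Recur carrying the updated state, plus the output;
# the original `r` is left untouched.
r2, y = Fluxperimental.apply(r, x_t)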

# This is the same way Flux.Recur handles 3-tensors.
function apply(m::NM_Recur{true}, x::AbstractArray{T, 3}) where T
# h = [m(x_t) for x_t in eachlastdim(x)]
l, h = apply(m, Flux.eachlastdim(x))
sze = size(h[1])
l, reshape(reduce(hcat, h), sze[1], sze[2], length(h))
end

function apply(m::NM_Recur{false}, x::AbstractArray{T, 3}) where T
apply(m, Flux.eachlastdim(x))
end
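A hedged sketch (not part of the diff) of the 3-tensor path above, assuming a features × batch × time layout as in Flux.Recur:

using Flux, Fluxperimental

cell = Flux.RNNCell(2, 3)
r = Fluxperimental.NM_Recur(cell, cell.state0; return_sequence = true)
x = rand(Float32, 2, 5, 7)          # features × batch × time

# The sequence of per-step 3×5 outputs is stacked back into a 3×5×7 array.
r2, y = Fluxperimental.apply(r, x)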

function apply(l::NM_Recur{false}, xs::Union{AbstractVector{<:AbstractArray}, Base.Generator})
rnn = l.cell
# carry = l.state
x_init, x_rest = Iterators.peel(xs)
(carry, y) = rnn(l.state, x_init)
for x in x_rest
(carry, y) = rnn(carry, x)
end
NM_Recur{false}(rnn, carry), y
end

# From Lux.jl: https://github.com/LuxDL/Lux.jl/pull/287/
function apply(l::NM_Recur{true}, xs::Union{AbstractVector{<:AbstractArray}, Base.Generator})
rnn = l.cell
_xs = if xs isa Base.Generator
collect(xs) # TODO: Fix. I can't figure out how to get around this for generators.
else
xs
end
x_init, _ = Iterators.peel(_xs)

(carry, out_) = rnn(l.state, x_init)

init = (typeof(out_)[out_], carry)

# foldl threads the carry through the remaining timesteps in order (foldr would
# consume the sequence back-to-front).
function recurrence_op((outputs, carry), input)
carry, out = rnn(carry, input)
return vcat(outputs, typeof(out)[out]), carry
end
results = foldl(recurrence_op, _xs[(begin+1):end]; init)
return NM_Recur{true}(rnn, results[2]), results[1]
end

Flux.@functor NM_Recur
Flux.trainable(a::NM_Recur) = (; cell = a.cell)

Base.show(io::IO, m::NM_Recur) = print(io, "Recur(", m.cell, ")")

NM_RNN(a...; return_sequence::Bool=false, ka...) = NM_Recur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence)
NM_Recur(m::Flux.RNNCell; return_sequence::Bool=false) = NM_Recur(m, m.state0; return_sequence=return_sequence)
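Not part of the diff: a sketch of the two sequence modes using the convenience constructors above, with illustrative sizes:

using Flux, Fluxperimental

cell = Flux.RNNCell(1, 1, identity)
xs = [rand(Float32, 1) for _ in 1:4]           # four timesteps

r_last = Fluxperimental.NM_Recur(cell; return_sequence = false)
_, y_last = Fluxperimental.apply(r_last, xs)   # output of the final step only

r_seq = Fluxperimental.NM_RNN(1, 1, identity; return_sequence = true)
_, y_seq = Fluxperimental.apply(r_seq, xs)     # a Vector with one output per step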

# Quick Reset functionality

struct RecurWalk <: Flux.Functors.AbstractWalk end
(::RecurWalk)(recurse, x) = x isa Fluxperimental.NM_Recur ? reset(x) : Flux.Functors.DefaultWalk()(recurse, x)

function reset(m::NM_Recur{SEQ}) where SEQ
NM_Recur{SEQ}(m.cell, m.cell.state0)
end
reset(m) = m
reset(m::Flux.Chain) = Flux.Functors.fmap(identity, m; walk=RecurWalk())
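A small sketch (not part of the diff) of how the reset walk combines with the chain interface; layer sizes are illustrative:

using Flux, Fluxperimental

model = Chain(Fluxperimental.NM_RNN(2, 3), Dense(3 => 1))
xs = [rand(Float32, 2) for _ in 1:5]

m1, y1 = Fluxperimental.apply(model, xs)   # m1 carries the updated recur state
m0 = Fluxperimental.reset(m1)              # same weights, recur state back to state0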


##
# Fallback: apply timeseries data to other layers, one timestep at a time. Likely needs to be thought through a bit more.
##

function apply(l, xs::Union{AbstractVector{<:AbstractArray}, Base.Generator})
l, [l(x) for x in xs]
end
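To illustrate the fallback (not part of the diff): a non-recurrent layer is kept as-is and simply mapped over the timesteps:

using Flux, Fluxperimental

d = Dense(2 => 3)
xs = [rand(Float32, 2) for _ in 1:5]

d2, ys = Fluxperimental.apply(d, xs)   # ys is a Vector of 5 outputs, each of length 3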
109 changes: 109 additions & 0 deletions test/recur.jl
@@ -0,0 +1,109 @@

@testset "RNN gradients-implicit" begin
cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2

nm_layer = Fluxperimental.NM_Recur(cell; return_sequence = true)
ps = Flux.params(nm_layer)
e, g = Flux.withgradient(ps) do
l, out = Fluxperimental.apply(nm_layer, x)
sum(out[2])
end

@test primal[1] ≈ e
@test ∇Wi ≈ g[ps[1]]
@test ∇Wh ≈ g[ps[2]]
@test ∇b ≈ g[ps[3]]
@test ∇state0 ≈ g[ps[4]]
end
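For reference, the closed form behind the "theoretical primal gradients" above: with identity activation and b = 0, the two-step output is h2 = Wi*x2 + Wh*(Wi*x1 + Wh*state0), which differentiates to the quantities checked in these testsets:

∂h2/∂Wi     = x2 + Wh*x1
∂h2/∂Wh     = Wi*x1 + 2*Wh*state0
∂h2/∂b      = 1 + Wh        (b enters both steps, even though it is set to 0)
∂h2/∂state0 = Wh^2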

@testset "RNN gradients-implicit-partial sequence" begin
cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2

nm_layer = Fluxperimental.NM_Recur(cell; return_sequence = false)
ps = Flux.params(nm_layer)
e, g = Flux.withgradient(ps) do
l, out = Fluxperimental.apply(nm_layer, x)
sum(out)
end

@test primal[1] ≈ e
@test ∇Wi ≈ g[ps[1]]
@test ∇Wh ≈ g[ps[2]]
@test ∇b ≈ g[ps[3]]
@test ∇state0 ≈ g[ps[4]]
end

@testset "RNN gradients-explicit partial sequence" begin


cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2



nm_layer = Fluxperimental.NM_Recur(cell; return_sequence = false)
e, g = Flux.withgradient(nm_layer) do layer
r_l = Fluxperimental.reset(layer)
l, out = Fluxperimental.apply(r_l, x)
sum(out)
end
grads = g[1][:cell]

@test primal[1] ≈ e

@test ∇Wi ≈ grads[:Wi]
@test ∇Wh ≈ grads[:Wh]
@test ∇b ≈ grads[:b]
@test ∇state0 ≈ grads[:state0]
end
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -6,6 +6,8 @@ using Flux, Fluxperimental

include("chain.jl")

include("recur.jl")

include("compact.jl")

end