NewRecur experimental interface #11

Merged · 13 commits · Aug 9, 2023
1 change: 1 addition & 0 deletions Project.toml
@@ -3,6 +3,7 @@ uuid = "3102ee7a-c841-4564-8f7f-ec69bd4fd658"
version = "0.1.2"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
2 changes: 2 additions & 0 deletions src/Fluxperimental.jl
@@ -13,4 +13,6 @@ include("chain.jl")

include("compact.jl")

include("new_recur.jl")

end # module Fluxperimental
140 changes: 140 additions & 0 deletions src/new_recur.jl
@@ -0,0 +1,140 @@
import Flux: ChainRulesCore
import Compat: stack

##### Helper scan function which can likely be put into NNlib. #####
"""
    scan_full(func, init_carry, xs)

Recreates `jax.lax.scan` functionality in Julia. Takes a function, an initial carry, and a sequence, applies the function from left to right over the sequence, and returns the full output sequence together with the final carry. See `scan_partial` to return only the final output of the sequence.
"""
function scan_full(func, init_carry, xs::AbstractVector{<:AbstractArray})
# Recurrence operation used in the fold. Takes the state of the
# fold and the next input, returns the new state.
function recurrence_op((carry, outputs), input)
carry, out = func(carry, input)
return carry, vcat(outputs, [out])
end
# Fold left to right.
return Base.mapfoldl_impl(identity, recurrence_op, (init_carry, empty(xs)), xs)
end

function scan_full(func, init_carry, x_block)
# x_block is an AbstractArray; scan over its last (time) dimension.
xs_ = Flux.eachlastdim(x_block)

# This is needed due to a bug in eachlastdim, which produces a vector in a
# gradient context but a generator otherwise; collect the generator so both
# cases yield an indexable vector.
xs = if xs_ isa Base.Generator
collect(xs_)
else
xs_
end
scan_full(func, init_carry, xs)
end
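
# A minimal usage sketch for `scan_full` (illustrative only; the toy `cell`,
# `h0` and `xs` below are hypothetical and not part of this package):
#
#     cell(h, x) = (h .+ x, 2 .* (h .+ x))       # returns (new_carry, output)
#     xs = [rand(Float32, 3) for _ in 1:4]       # sequence of 4 steps
#     h0 = zeros(Float32, 3)
#     final_carry, ys = scan_full(cell, h0, xs)  # ys is a 4-element vector of per-step outputs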

# Chain rule for `Base.mapfoldl_impl`, which `scan_full` uses for the left fold.
function ChainRulesCore.rrule(
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
::typeof(Base.mapfoldl_impl),
::typeof(identity),
op::G,
init,
x::Union{AbstractArray, Tuple};
) where {G}
hobbits = Vector{Any}(undef, length(x)) # Unfortunately Zygote needs this
accumulate!(hobbits, x; init=(init, nothing)) do (a, _), b
c, back = ChainRulesCore.rrule_via_ad(config, op, a, b)
end
y = first(last(hobbits))
axe = axes(x)
project = ChainRulesCore.ProjectTo(x)
function unfoldl(dy)
trio = accumulate(Iterators.reverse(hobbits); init=(0, dy, 0)) do (_, dc, _), (_, back)
ds, da, db = back(dc)
end
dop = sum(first, trio)
dx = map(last, Iterators.reverse(trio))
d_init = trio[end][2]
return (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dop, d_init, project(reshape(dx, axe)))
end
return y, unfoldl
end
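
# A quick sanity check of this rule (illustrative only; assumes Zygote is
# loaded, and the toy `cell`, `h0` and `xs` are hypothetical):
#
#     using Zygote
#     cell(h, x) = (h .* x, h .* x)              # returns (new_carry, output)
#     xs = [rand(Float32, 3) for _ in 1:4]
#     h0 = ones(Float32, 3)
#     # gradient w.r.t. the initial carry flows through the rrule above
#     g = Zygote.gradient(h -> sum(scan_full(cell, h, xs)[1]), h0)[1]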


"""
    scan_partial(func, init_carry, xs)

Recreates `jax.lax.scan` functionality in Julia. Takes a function, an initial carry, and a sequence, applies the function from left to right over the sequence, and returns only the final output of the sequence together with the final carry. See `scan_full` to return the entire output sequence.
"""
function scan_partial(func, init_carry, xs::AbstractVector{<:AbstractArray})
x_init, x_rest = Iterators.peel(xs)
(carry, y) = func(init_carry, x_init)
for x in x_rest
(carry, y) = func(carry, x)
end
carry, y
end

function scan_partial(func, init_carry, x_block)
# x_block is an AbstractArray; scan over its last (time) dimension.
xs_ = Flux.eachlastdim(x_block)

# This is needed due to a bug in eachlastdim, which produces a vector in a
# gradient context but a generator otherwise; collect the generator so both
# cases yield an indexable vector.
xs = if xs_ isa Base.Generator
collect(xs_)
else
xs_
end
scan_partial(func, init_carry, xs)
end
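
# With the same toy `cell`, `h0` and `xs` as in the `scan_full` sketch above
# (illustrative only), only the last step's output is returned:
#
#     final_carry, y_last = scan_partial(cell, h0, xs)  # y_last comes from the final step only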


"""
    NewRecur

An experimental `Recur` interface for removing statefulness in recurrent architectures in Flux. The struct has two type parameters; the first, `RET_SEQUENCE`, is a `Bool` which determines whether `scan_full` (`RET_SEQUENCE=true`) or `scan_partial` (`RET_SEQUENCE=false`) is used to scan through the sequence. The structure holds no internal state and is used as follows:

```julia
l = NewRNN(1, 2)
xs          # some input array of size Input × BatchSize × Time
init_carry  # the initial carry of the cell
l(xs)             # -> returns the output of the RNN, using cell.state0 as init_carry
l(init_carry, xs) # -> returns (final_carry, output), where the size of output is determined by RET_SEQUENCE
```
"""
struct NewRecur{RET_SEQUENCE, T}
cell::T
# state::S
function NewRecur(cell; return_sequence::Bool=false)
new{return_sequence, typeof(cell)}(cell)
end
function NewRecur{true}(cell)
new{true, typeof(cell)}(cell)
end
function NewRecur{false}(cell)
new{false, typeof(cell)}(cell)
end
end

Flux.@functor NewRecur
Flux.trainable(a::NewRecur) = (; cell = a.cell)
Base.show(io::IO, m::NewRecur) = print(io, "NewRecur(", m.cell, ")")
NewRNN(a...; return_sequence::Bool=false, ka...) = NewRecur(Flux.RNNCell(a...; ka...); return_sequence=return_sequence)

# 1D/2D inputs are ambiguous (there is no explicit time dimension), so reject them.
(l::NewRecur)(init_carry, x_mat::AbstractMatrix) = throw(MethodError(l, (init_carry, x_mat)))
(l::NewRecur)(init_carry, x_vec::AbstractVector{T}) where {T<:Number} = throw(MethodError(l, (init_carry, x_vec)))

function (l::NewRecur)(xs::AbstractArray)
results = l(l.cell.state0, xs)
results[2] # Only return the output here.
end

function (l::NewRecur{false})(init_carry, xs)
results = scan_partial(l.cell, init_carry, xs)
results[1], results[2]
end

function (l::NewRecur{true})(init_carry, xs)
results = scan_full(l.cell, init_carry, xs)
results[1], stack(results[2], dims=3)
end
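
# End-to-end usage sketch (illustrative only; the sizes below are hypothetical):
#
#     m = NewRNN(2, 4; return_sequence = true)
#     xs = rand(Float32, 2, 8, 10)         # features × batch × time
#     ys = m(xs)                           # 4 × 8 × 10: the full output sequence
#     carry, ys = m(m.cell.state0, xs)     # same, but with an explicit initial carry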
188 changes: 188 additions & 0 deletions test/new_recur.jl
@@ -0,0 +1,188 @@
@testset "NewRecur RNN" begin
@testset "Forward Pass" begin
# tanh is needed for the forward check to be sensitive to the ordering of the inputs.
cell = Flux.RNNCell(1, 1, tanh)
layer = Fluxperimental.NewRecur(cell; return_sequence=true)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = reshape([2.0f0, 3.0f0], 1, 1, 2)

# Let's make sure the output is correct.
h = cell.state0
h, out = cell(h, [2.0f0])
h, out = cell(h, [3.0f0])

@test eltype(layer(x)) <: Float32
@test size(layer(x)) == (1, 1, 2)
@test layer(x)[1, 1, 2] ≈ out[1,1]

@test length(layer(cell.state0, x)) == 2 # should return a tuple; a better test may be needed.
@test layer(cell.state0, x)[2][1,1,2] ≈ out[1,1]

@test_throws MethodError layer([2.0f0])
@test_throws MethodError layer([2.0f0;; 3.0f0])
end

@testset "gradients-implicit" begin
cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal and gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2
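# Derivation (identity activation, b = 0, two steps):
#   h1 = Wh*state0 + Wi*x1 + b
#   y  = h2 = Wh*h1 + Wi*x2 + b
# hence
#   ∂y/∂Wi = Wh*x1 + x2        ∂y/∂Wh     = 2*Wh*state0 + Wi*x1
#   ∂y/∂b  = Wh + 1            ∂y/∂state0 = Wh^2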

nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true)
ps = Flux.params(nm_layer)
x_block = reshape(vcat(x...), 1, 1, length(x))
e, g = Flux.withgradient(ps) do
out = nm_layer(x_block)
sum(out[1, 1, 2])
end

@test primal[1] ≈ e
@test ∇Wi ≈ g[ps[1]]
@test ∇Wh ≈ g[ps[2]]
@test ∇b ≈ g[ps[3]]
@test ∇state0 ≈ g[ps[4]]
end

@testset "gradients-explicit" begin

cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal and gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2


x_block = reshape(vcat(x...), 1, 1, length(x))
nm_layer = Fluxperimental.NewRecur(cell; return_sequence = true)
e, g = Flux.withgradient(nm_layer) do layer
out = layer(x_block)
sum(out[1, 1, 2])
end
grads = g[1][:cell]

@test primal[1] ≈ e
@test ∇Wi ≈ grads[:Wi]
@test ∇Wh ≈ grads[:Wh]
@test ∇b ≈ grads[:b]
@test ∇state0 ≈ grads[:state0]
end
end

@testset "New Recur RNN Partial Sequence" begin
@testset "Forward Pass" begin
cell = Flux.RNNCell(1, 1, identity)
layer = Fluxperimental.NewRecur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = reshape([2.0f0, 3.0f0], 1, 1, 2)

h = cell.state0
h, out = cell(h, [2.0f0])
h, out = cell(h, [3.0f0])

@test eltype(layer(x)) <: Float32
@test size(layer(x)) == (1, 1)
@test layer(x)[1, 1] ≈ out[1,1]

@test length(layer(cell.state0, x)) == 2
@test layer(cell.state0, x)[2][1,1] ≈ out[1,1]

@test_throws MethodError layer([2.0f0])
@test_throws MethodError layer([2.0f0;; 3.0f0])
end

@testset "gradients-implicit" begin
cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal and gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2

nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false)
ps = Flux.params(nm_layer)
x_block = reshape(vcat(x...), 1, 1, length(x))
e, g = Flux.withgradient(ps) do
out = nm_layer(x_block)
sum(out)
end

@test primal[1] ≈ e
@test ∇Wi ≈ g[ps[1]]
@test ∇Wh ≈ g[ps[2]]
@test ∇b ≈ g[ps[3]]
@test ∇state0 ≈ g[ps[4]]
end

@testset "gradients-explicit" begin
cell = Flux.RNNCell(1, 1, identity)
layer = Flux.Recur(cell)
layer.cell.Wi .= 5.0
layer.cell.Wh .= 4.0
layer.cell.b .= 0.0f0
layer.cell.state0 .= 7.0
x = [[2.0f0], [3.0f0]]

# theoretical primal and gradients
primal =
layer.cell.Wh .* (layer.cell.Wh * layer.cell.state0 .+ x[1] .* layer.cell.Wi) .+
x[2] .* layer.cell.Wi
∇Wi = x[1] .* layer.cell.Wh .+ x[2]
∇Wh = 2 .* layer.cell.Wh .* layer.cell.state0 .+ x[1] .* layer.cell.Wi
∇b = layer.cell.Wh .+ 1
∇state0 = layer.cell.Wh .^ 2

x_block = reshape(vcat(x...), 1, 1, length(x))
nm_layer = Fluxperimental.NewRecur(cell; return_sequence = false)
e, g = Flux.withgradient(nm_layer) do layer
out = layer(x_block)
sum(out)
end
grads = g[1][:cell]

@test primal[1] ≈ e
@test ∇Wi ≈ grads[:Wi]
@test ∇Wh ≈ grads[:Wh]
@test ∇b ≈ grads[:b]
@test ∇state0 ≈ grads[:state0]

end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
@@ -8,4 +8,6 @@ using Flux, Fluxperimental

include("compact.jl")

include("new_recur.jl")

end