Merge #1686
1686: Adding support for folding RNNs over 3d arrays r=DhairyaLGandhi a=mkschleg

From #1678: this adds a `Recur`-like interface for a folded operation over 3-dimensional arrays. This is how many users familiar with PyTorch and TensorFlow expect RNNs to work, and there seems to be some desire for this feature per the discussion in #1671 and feedback from `@jeremiedb`. It will also streamline the push toward supporting the CuDNN versions of RNNs/GRUs/LSTMs, since this is the data layout that API expects.

This is a barebones implementation, so we can start iterating on the API.
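
A minimal sketch of the intended usage (the layer sizes and data below are made up for illustration):

```julia
using Flux

rnn = RNN(3, 5)               # 3 input features, 5 hidden units
x = rand(Float32, 3, 16, 10)  # (features, batch, time)
Flux.reset!(rnn)
y = rnn(x)                    # folds over the time dimension; size(y) == (5, 16, 10)
```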

There are several lingering questions I have about this interface:
- ~~Should we support different modes where we return all or only the last hidden state? Is there a better way to do the concat of the hidden states?~~
- What kind of tests should we have? Just follow what we currently do for RNNs/LSTMs/GRUs?
- ~~For the CPU version, does it make sense not to specialize on the different rnn types? We might be able to take more advantage of BLAS if we specialized on, say, `Folded{GRU}`.~~
- ~~Do we want to force the temporal dimension to be the 2nd?~~
- ~~Do we want this to be stateful? (i.e., allow the user to change what the starting hidden state is, rather than `state0`.)~~

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] API changes require approval from a committer (different from the author, if applicable)


Co-authored-by: Matthew Schlegel <mkschleg@gmail.com>
Co-authored-by: Matthew Schlegel <mkschleg@users.noreply.github.com>
Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
4 people authored Sep 14, 2021
2 parents 9a395b2 + d1b1daf commit f1dbc97
Showing 3 changed files with 77 additions and 17 deletions.
20 changes: 20 additions & 0 deletions src/layers/recurrent.jl
@@ -24,6 +24,19 @@ rnn.state # 5
rnn.(1:10) # apply to a sequence
rnn.state # 60
```
Folding over a 3d Array of dimensions `(features, batch, time)` is also supported:
```julia
accum(h, x) = (h .+ x, x)
rnn = Flux.Recur(accum, zeros(Int, 1, 1))
rnn([2]) # 2
rnn([3]) # 3
rnn.state # 5
rnn(reshape(1:10, 1, 1, :)) # apply to a sequence of (features, batch, time)
rnn.state # 60
```
"""
mutable struct Recur{T,S}
cell::T
@@ -53,6 +66,7 @@ rnn.state = hidden(rnn.cell)
reset!(m::Recur) = (m.state = m.cell.state0)
reset!(m) = foreach(reset!, functor(m)[1])


# TODO remove in v0.13
function Base.getproperty(m::Recur, sym::Symbol)
if sym === :init
@@ -67,6 +81,12 @@ end

flip(f, xs) = reverse(f.(reverse(xs)))

function (m::Recur)(x::AbstractArray{T, 3}) where T
  # apply the wrapped cell to each (features, batch) slice along the time dimension
  h = [m(view(x, :, :, i)) for i in 1:size(x, 3)]
  sze = size(h[1])
  # stack the per-step outputs back into a (features, batch, time) array
  reshape(reduce(hcat, h), sze[1], sze[2], length(h))
end

# Vanilla RNN

struct RNNCell{F,A,V,S}
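For intuition, the folded call is equivalent to stepping the `Recur` manually and stacking the per-step outputs; a quick sanity check along those lines (a sketch, not part of this diff):

```julia
using Flux

rnn = RNN(2, 3)
x = rand(Float32, 2, 4, 5)  # (features, batch, time)

Flux.reset!(rnn)
y3d = rnn(x)                # folded call added by this PR

Flux.reset!(rnn)
ys = [rnn(x[:, :, t]) for t in 1:size(x, 3)]  # manual stepping
@assert all(y3d[:, :, t] ≈ ys[t] for t in 1:size(x, 3))
```

Because `hcat` of `(features, batch)` matrices is column-major, the `reshape` recovers the `(features, batch, time)` layout without reordering elements.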
11 changes: 10 additions & 1 deletion test/cuda/curnn.jl
@@ -53,6 +53,15 @@ end
y = (rnn(ohx); rnn(ohx))

cuy = (curnn(cuohx); curnn(cuohx))
@test y ≈ collect(cuy)

Flux.reset!(rnn)
Flux.reset!(curnn)
fx = rand(Float32, 10, batch_size, 3)
cufx = gpu(fx)
fy = (rnn(fx); rnn(fx))

cufy = (curnn(cufx); curnn(cufx))
@test fy ≈ collect(cufy)
end
end
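
The same pattern can be checked interactively: move the model and a 3d batch to the GPU and compare against the CPU result (a sketch, assuming a CUDA-capable setup):

```julia
using Flux, CUDA

rnn = RNN(10, 5)
curnn = rnn |> gpu          # move parameters to the GPU

x = rand(Float32, 10, 8, 3) # (features, batch, time)
Flux.reset!(rnn); Flux.reset!(curnn)
@assert rnn(x) ≈ collect(curnn(gpu(x)))
```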
63 changes: 47 additions & 16 deletions test/layers/recurrent.jl
@@ -42,26 +42,57 @@ end
end
end

@testset "BPTT-3D" begin
seq = rand(Float32, (2, 1, 3))
rnn = RNN(2, 3)
Flux.reset!(rnn)
grads_seq = gradient(Flux.params(rnn)) do
sum(rnn(seq)[:, :, 3])
end
Flux.reset!(rnn);
bptt = gradient(rnn.cell.Wh) do Wh
# calculate state 1
s1 = tanh.(rnn.cell.Wi * seq[:, :, 1] +
Wh * rnn.cell.state0 +
rnn.cell.b)
# calculate state 2
s2 = tanh.(rnn.cell.Wi * seq[:, :, 2] +
Wh * s1 +
rnn.cell.b)
# calculate state 3
s3 = tanh.(rnn.cell.Wi * seq[:, :, 3] +
Wh * s2 +
rnn.cell.b)
sum(s3) # loss is sum of state 3
end
@test grads_seq[rnn.cell.Wh] ≈ bptt[1]
end
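
For reference, the manual `bptt` closure above unrolls the vanilla RNN recurrence for three steps,

$$h_t = \tanh\left(W_i\, x_t + W_h\, h_{t-1} + b\right), \qquad \mathcal{L} = \textstyle\sum_k (h_3)_k,$$

and the test checks that the AD gradient through the folded 3d call matches this hand-unrolled gradient with respect to `Wh`.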

@testset "RNN-shapes" begin
@testset for R in [RNN, GRU, LSTM, GRUv3]
m1 = R(3, 5)
m2 = R(3, 5)
m3 = R(3, 5)
x1 = rand(Float32, 3)
x2 = rand(Float32, 3, 1)
x3 = rand(Float32, 3, 1, 2)
Flux.reset!(m1)
Flux.reset!(m2)
Flux.reset!(m3)
@test size(m1(x1)) == (5,)
@test size(m1(x1)) == (5,) # repeat in case of effect from change in state shape
@test size(m2(x2)) == (5, 1)
@test size(m2(x2)) == (5, 1)
@test size(m3(x3)) == (5, 1, 2)
@test size(m3(x3)) == (5, 1, 2)
end
end

@testset "RNN-input-state-eltypes" begin
@testset for R in [RNN, GRU, LSTM, GRUv3]
m = R(3, 5)
x = rand(Float64, 3, 1)
Flux.reset!(m)
@test_throws MethodError m(x)
end
end
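
As this test reflects, these layers carry `Float32` parameters by default, so `Float64` input raises a `MethodError` rather than being silently promoted; converting the input eltype is the usual fix (a sketch):

```julia
using Flux

m = RNN(3, 5)             # Float32 parameters by default
x64 = rand(Float64, 3, 1)
y = m(Float32.(x64))      # convert the input eltype to match the model
```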
