Added the wrapper Bidirectional for RNN layers #1790

Open · wants to merge 9 commits into master
58 changes: 58 additions & 0 deletions src/layers/recurrent.jl
@@ -430,3 +430,61 @@ Recur(m::GRUv3Cell) = Recur(m, m.state0)
@adjoint function Broadcast.broadcasted(f::Recur, args...)
Zygote.∇map(__context__, f, args...)
end


"""
Bidirectional{A,B}
Member
Once we settle on the constructor, this line should be the various constructors, not the type definition.


A wrapper layer that allows the use of [bidirectional](https://ieeexplore.ieee.org/document/650093) layers. It contains two Flux layers, `forward` and `backward`; the output of the forward layer is concatenated with the time-reversed output of the backward layer.

It is intended to be used with recurrent layers such as `LSTM`, `GRU` or `RNN` to benefit from the sequential information that recurrent layers capture, but it will not raise an error if used with a different layer such as `Dense`, as long as the layer's outputs can be concatenated with `vcat`.

# Examples
```jldoctest
julia> BLSTM = Bidirectional(LSTM, 3, 5)
Bidirectional(
Recur(
LSTMCell(3, 5), # 190 parameters
),
Recur(
LSTMCell(3, 5), # 190 parameters
),
) # Total: 10 trainable arrays, 380 parameters,
# plus 4 non-trainable, 20 parameters, summarysize 2.141 KiB.

julia> Bidirectional(LSTM, 3, 5)(rand(Float32, 3)) |> size
(10,)

julia> model = Chain(Embedding(10000, 200), Bidirectional(LSTM, 200, 128), Dense(256, 5), softmax)
Chain(
Embedding(10000, 200), # 2_000_000 parameters
Bidirectional(
Recur(
LSTMCell(200, 128), # 168_704 parameters
),
Recur(
LSTMCell(200, 128), # 168_704 parameters
),
),
Dense(256, 5), # 1_285 parameters
NNlib.softmax,
) # Total: 13 trainable arrays, 2_338_693 parameters,
# plus 4 non-trainable, 512 parameters, summarysize 8.922 MiB.
```
"""
struct Bidirectional{A,B}
forward::A
backward::B
end

# Convenience constructor: builds two independent instances of the same layer type, one for each direction
Bidirectional(rnn, a...; ka...) = Bidirectional(rnn(a...; ka...), rnn(a...; ka...))
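# For example (illustrative only, not part of the tests), the following two calls build
# equivalent models, each direction getting its own independently initialised LSTM:
#   Bidirectional(LSTM, 3, 5)
#   Bidirectional(LSTM(3, 5), LSTM(3, 5))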


# Concatenate the forward output with the time-reversed output of the backward layer
function (m::Bidirectional)(x::Union{AbstractVecOrMat{T},OneHotArray}) where {T}
return vcat(m.forward(x), reverse(m.backward(reverse(x; dims=1)); dims=1))
Member
@ToucheSir ToucheSir Nov 28, 2021

Suggested change
return vcat(m.forward(x), reverse(m.backward(reverse(x; dims=1)); dims=1))
return vcat(m.forward(x), reverse(m.backward(reverse(x; dims=3)); dims=3))

Sorry, just found this. When applying Flux RNNs on dense sequences, the temporal dim is actually the last one. See `function (m::Recur)(x::AbstractArray{T, 3}) where T` and "Folding over a 3d Array of dimensions `(features, batch, time)` is also supported".

Author

I'm quite confused now, because I was expecting the variable x to be the result of a OneHotMatrix applied to a sentence, i.e. a matrix. In my experiments, I was using an array of one-hot encoded sentences (padded to a fixed size), like this:

X_onehot = [OneHotMatrix(x, vocab_size) for x in X]
y_onehot = [OneHotMatrix(x, num_tags) for x in y]

where X is an array of padded sentences and y holds the corresponding labels for each word in the sentence (I am experimenting with Named Entity Recognition models). So the input data would be in (batch, (time, features)) format and not in the (features, batch, time) format.

I am not sure how to proceed.

Member

https://fluxml.ai/Flux.jl/stable/models/recurrence/#Working-with-sequences has most of the nitty-gritty details about working with batched sequences for RNNs. In short, the only supported formats are ((features, batch), time) and (features, batch, time). Unlike Python frameworks, Flux puts the batch dim last for all layers because of column major layout. RNNs just add another time dim after that.

Since this definition of Bidirectional takes a contiguous array instead of a vector of arrays, m.forward() and m.backward() dispatch to (m::Recur)(x::AbstractArray{T, 3}). To support both, you'd need something like the following (note: untested!):

  1. (m::Bidirectional)(x::AbstractArray{T, 3}) where T for (features, batch, time)
  2. (m::Bidirectional)(x::Vector{<:AbstractVecOrMat{T}}) where T for ((features, batch), time)
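A rough, untested sketch of what those two methods could look like, combining the dims=3 suggestion above with the vector-of-timesteps iteration (purely illustrative, not part of this PR):

```julia
# (features, batch, time): reverse the dense sequence along the time (last) dimension
function (m::Bidirectional)(x::AbstractArray{T, 3}) where {T}
    return vcat(m.forward(x), reverse(m.backward(reverse(x; dims=3)); dims=3))
end

# ((features, batch), time): feed the backward layer the timesteps in reverse order,
# then flip its outputs back so they align with the forward outputs
function (m::Bidirectional)(xs::AbstractVector{<:AbstractVecOrMat})
    fwd = [m.forward(x) for x in xs]
    bwd = reverse([m.backward(x) for x in reverse(xs)])
    return vcat.(fwd, bwd)
end
```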

Member
@mcabbott mcabbott Nov 29, 2021

But the RNN layers do not literally accept a vector of matrices; at the link this is iterated through by hand. Should method 2 of Bidirectional(fwd, rev)(x) handle that?

julia> LSTM(3,5)(randn(Float32, 3,)) |> size
(5,)

julia> LSTM(3,5)(randn(Float32, 3,7)) |> size
(5, 7)

julia> LSTM(3,5)(randn(Float32, 3,7,11)) |> size
(5, 7, 11)

julia> LSTM(3,5)([randn(Float32, 3,7) for _ in 1:11]) |> size
ERROR: MethodError

julia> function (b::Bidirectional)(xs::AbstractVector{<:AbstractVecOrMat})
         top = [b.forward(x) for x in xs]
         bot = reverse([b.backward(x) for x in reverse(xs)])
         vcat.(top, bot)
       end

Member

Should method 2 of Bidirectional(fwd, rev)(x) handle that?

Yes, exactly.

Member

And applied to randn(Float32, 3,7), should it fail, or still reverse in the 3rd dimension?

Member

The signatures specifically prohibit that, so it should fail. Did you have a specific edge case in mind that isn't caught by the above?

Member

No, I just wonder about the mismatch between what LSTM accepts and what this accepts. Prohibiting things which don't make sense is good; you want to find out early.

Member

Right, the RNN layers can work with single timesteps but Bidirectional can't, as it needs to see the whole sequence up-front in order to reverse it. If anything the former should be brought in line with the latter, but that's a conversation for another issue (#1678 probably).

Author

Thank you for the great conversation here and the suggestions. I just tested, and Bidirectional as it stands is able to process randn(Float32, 3, 7):

Bidirectional(LSTM, 3, 5)(randn(Float32, 3, 7)) |> size
(10, 7)

But about the format, I am still confused about how to preprocess the text data so that it ends up in (seq_length, (features, samples)) format. It seems counterintuitive to me.

I usually follow: read the data -> split sentences -> split words -> pad -> one-hot. So my data would be an array of N sentences, where every sentence is described as a one-hot matrix of its words. In this way, I end up with the (samples, (features, seq_length)) format. How should I preprocess the text data so that I end up with (seq_length, (features, samples))?

Also, by checking these formats I discovered that I should probably pass dims=2 to reverse in my formulation (rather than the default, which reverses all dimensions): it should reverse only the time dimension, which in my case is the second dimension of the one-hot matrix.
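In my per-sentence setup, where x is a single (features, seq_length) one-hot matrix, that would look roughly like this (untested, just adapting the line under discussion):

```julia
# reverse only the time dimension (dim 2 of the one-hot matrix), not all dimensions
vcat(m.forward(x), reverse(m.backward(reverse(x; dims=2)); dims=2))
```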

end

@functor Bidirectional


Flux.@layer here?

2 changes: 1 addition & 1 deletion src/layers/show.jl
@@ -1,6 +1,6 @@

for T in [
:Chain, :Parallel, :SkipConnection, :Recur # container types
:Chain, :Parallel, :SkipConnection, :Recur, :Bidirectional # container types
]
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL