Added the wrapper Bidirectional for RNN layers #1790
base: master
Changes from all commits: fe14252, 01daa36, dd6b70e, 808cabc, 398e533, bb41190, c107916, 0ec5d82, 18b406a
@@ -430,3 +430,106 @@ Recur(m::GRUv3Cell) = Recur(m, m.state0)
@adjoint function Broadcast.broadcasted(f::Recur, args...)
  Zygote.∇map(__context__, f, args...)
end

""" | ||||||||||
Bidirectional{A,B} | ||||||||||
|
||||||||||
A wrapper layer that allows the use of [bidirectional](https://ieeexplore.ieee.org/document/650093) layers. It contains two parts that are Flux layers: `forward` and `backward` where | ||||||||||
the forward layer weights are concatenated with the reversed order of the backward layer weights. | ||||||||||
|
||||||||||
It is intended to be used with recurrent layers such as `LSTM`, `GRU` or `RNN` to benefit from the sequential information that recurrent | ||||||||||
layers have, but it will not raise an error if used with a different layer such as `Dense`, as long as the layer is compatible with the concatenation function `vcat`. | ||||||||||
|
||||||||||
For flexibility, it is possible to use the contructor `Bidirectional(rnn, in::Int, out::Int, a...; ka...)` by passing the input and output dimensions | ||||||||||
together with the desired recurrent layers (one of `LSTM`, `GRU`, or `RNN`). Check the examples below for more details. | ||||||||||
|
||||||||||
# Examples | ||||||||||
|
||||||||||
1. Using the flexible constructor to create a bidirectional LSTM layer (BiLSTM):
```jldoctest
julia> BiLSTM = Bidirectional(LSTM, 3, 5)
Bidirectional(
  Recur(
    LSTMCell(3, 5),                     # 190 parameters
  ),
  Recur(
    LSTMCell(3, 5),                     # 190 parameters
  ),
)         # Total: 10 trainable arrays, 380 parameters,
          # plus 4 non-trainable, 20 parameters, summarysize 2.141 KiB.
```
2. Checking the output size after running the bidirectional layer on an input vector. The output dimension is twice the output dimension of the wrapped layers (here 2 × 5 = 10):
```jldoctest
julia> Bidirectional(LSTM, 3, 5)(rand(Float32, 3)) |> size
(10,)
```
3. It is possible to use the bidirectional layer inside the `Chain` container:
```
julia> model = Chain(Embedding(10000, 200), Bidirectional(LSTM, 200, 128), Dense(256, 5), softmax)
Chain(
  Embedding(10000, 200),                # 2_000_000 parameters
  Bidirectional(
    Recur(
      LSTMCell(200, 128),               # 168_704 parameters
    ),
    Recur(
      LSTMCell(200, 128),               # 168_704 parameters
    ),
  ),
  Dense(256, 5),                        # 1_285 parameters
  NNlib.softmax,
)         # Total: 13 trainable arrays, 2_338_693 parameters,
          # plus 4 non-trainable, 512 parameters, summarysize 8.922 MiB.
```
4. It is also possible to use the default constructor, passing both layers explicitly:
```jldoctest
julia> BiGRU = Bidirectional(GRU(3, 5), GRU(3, 5))
Bidirectional(
  Recur(
    GRUCell(3, 5),                      # 140 parameters
  ),
  Recur(
    GRUCell(3, 5),                      # 140 parameters
  ),
)         # Total: 8 trainable arrays, 280 parameters,
          # plus 2 non-trainable, 10 parameters, summarysize 1.562 KiB.
```
5. Other arguments accepted by the recurrent layers can be passed through as well:
```jldoctest
julia> model = Bidirectional(RNN(3, 5, tanh; init=glorot_normal), LSTM(3, 5; initb=zeros32, init_state=zeros32))
Bidirectional(
  Recur(
    RNNCell(3, 5, tanh),                # 50 parameters
  ),
  Recur(
    LSTMCell(3, 5),                     # 190 parameters
  ),
)         # Total: 9 trainable arrays, 240 parameters,
          # plus 3 non-trainable, 15 parameters, summarysize 1.500 KiB.
```
""" | ||||||||||
struct Bidirectional{A,B} | ||||||||||
forward::A | ||||||||||
backward::B | ||||||||||
end | ||||||||||

# Constructor that creates a Bidirectional using the same kind of layer for forward and backward.
# `in` needs to be declared explicitly to avoid dispatch conflicts with the default constructor.
Bidirectional(rnn, in::Integer, args...; ka...) = Bidirectional(rnn(in, args...; ka...), rnn(in, args...; ka...))
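As a quick illustration of what this constructor does (a sketch, not part of the diff; it assumes the `Wi` field name used by Flux 0.12's recurrent cells), note that it calls `rnn(in, args...; ka...)` twice, so the two directions get independent weights:

```julia
using Flux

# Hypothetical check, not part of this PR: the flexible constructor builds the
# forward and backward layers from two separate calls, so no weights are shared.
b = Bidirectional(GRU, 3, 5)
b.forward.cell.Wi === b.backward.cell.Wi   # false: each direction has its own parameters
```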

# Concatenate the forward output with the output of the backward layer run on the reversed input, flipped back
function (m::Bidirectional)(x::Union{AbstractVecOrMat{T},OneHotArray}) where {T}
  return vcat(m.forward(x), reverse(m.backward(reverse(x; dims=1)); dims=1))
end

Review thread on this method:

- Sorry, just found this. When applying Flux RNNs on dense sequences, the temporal dim is actually the last one. See Flux.jl/src/layers/recurrent.jl line 83 in dd6b70e and Flux.jl/src/layers/recurrent.jl line 27 in dd6b70e.

- I'm quite confused now, because I was expecting the variable […] `X_onehot = [OneHotMatrix(x, vocab_size) for x in X]`, `y_onehot = [OneHotMatrix(x, num_tags) for x in y]`, where I am not sure how to proceed.

- https://fluxml.ai/Flux.jl/stable/models/recurrence/#Working-with-sequences has most of the nitty-gritty details about working with batched sequences for RNNs. In short, the only supported formats are […] Since this definition of Bidirectional takes a contiguous array instead of a vector of arrays, […]

- But the RNN layers do not literally accept a vector of matrices. At the link this is iterated through by hand. Should method 2 of […]

- Yes, exactly.

- And applied to […]

- The signatures specifically prohibit that, so it should fail. Did you have a specific edge case in mind that isn't caught by the above?

- No, I just wonder about the mismatch between what LSTM accepts and what this accepts. Prohibiting things which don't make sense is good; you want to know soon.

- Right, the RNN layers can work with single timesteps, but […]

- Thank you for the great conversation here and the suggestions. I just tested, and `Bidirectional(LSTM, 3, 5)(randn(Float32, 3, 7)) |> size` gives `(10, 7)`. But about the format, I am still confused on how to preprocess the text data so that it would end in […] I usually follow: […] Also, by checking these formats I discovered that I should probably use […]
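The thread above turns on how Flux expects sequences to be laid out. Below is a minimal sketch (not part of this PR) of the vector-of-timesteps format described in the linked recurrence docs, assuming the Flux 0.12-era API; by contrast, the `Bidirectional` method defined above takes a single contiguous `AbstractVecOrMat`, which is what makes the author's `(10, 7)` test work.

```julia
using Flux

# Sketch of the batched-sequence format from
# https://fluxml.ai/Flux.jl/stable/models/recurrence/#Working-with-sequences:
# the sequence is a Vector of timesteps, each a (features, batch) matrix.
seq = [rand(Float32, 3, 16) for _ in 1:7]   # 7 timesteps, 3 features, batch of 16

rnn = LSTM(3, 5)
Flux.reset!(rnn)                            # clear the hidden state before a new sequence
outs = [rnn(x) for x in seq]                # stateful: apply once per timestep, in order
size(outs[end])                             # (5, 16)
```

Note that when a recurrent cell receives a matrix in a single call, as in `Bidirectional(LSTM, 3, 5)(randn(Float32, 3, 7))`, the columns are treated as a batch within one timestep, so reversing along `dims=1` flips features rather than time, which appears to be the reviewer's concern.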

@functor Bidirectional
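Because `@functor` is what lets Flux traverse the wrapper, a simple sanity check (not part of the diff) is that parameter collection reaches both directions; the expected count below is taken from the "10 trainable arrays" reported in the first docstring example.

```julia
using Flux

# Hypothetical check, not from this PR: `@functor Bidirectional` exposes the
# `forward` and `backward` fields, so `Flux.params`, `gpu`, etc. traverse both.
b = Bidirectional(LSTM, 3, 5)
length(Flux.params(b))   # expected 10, matching the "10 trainable arrays" shown above
```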

function Base.show(io::IO, b::Bidirectional)
  print(io, "Bidirectional(")
  show(io, b.forward.cell)
  print(io, ", ")
  show(io, b.backward.cell)
  print(io, ")")
end
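For reference, this two-argument `show` prints only the two cells, so the compact form would look roughly like the sketch below (assuming Flux 0.12's compact cell printing; this is not a doctest from the PR):

```julia
julia> print(Bidirectional(LSTM, 3, 5))
Bidirectional(LSTMCell(3, 5), LSTMCell(3, 5))
```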

Review comment: Once we settle on the constructor, this line should be the various constructors, not the type definition.