Added the wrapper Bidirectional for RNN layers #1790
base: master
Changes from 3 commits
src/layers/recurrent.jl
@@ -1,4 +1,3 @@
gate(h, n) = (1:h) .+ h*(n-1)
gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
@@ -430,3 +429,72 @@ Recur(m::GRUv3Cell) = Recur(m, m.state0)
@adjoint function Broadcast.broadcasted(f::Recur, args...)
  Zygote.∇map(__context__, f, args...)
end
""" | ||||||||||
Bidirectional{A,B} | ||||||||||
A wrapper layer that allows bidirectional recurrent layers. | ||||||||||
# Examples | ||||||||||
```jldoctest | ||||||||||
julia> BLSTM = Bidirectional(LSTM, 3, 5) | ||||||||||
Bidirectional( | ||||||||||
Recur( | ||||||||||
LSTMCell(3, 5), # 190 parameters | ||||||||||
), | ||||||||||
Recur( | ||||||||||
LSTMCell(3, 5), # 190 parameters | ||||||||||
), | ||||||||||
) # Total: 10 trainable arrays, 380 parameters, | ||||||||||
# plus 4 non-trainable, 20 parameters, summarysize 2.141 KiB. | ||||||||||
julia> BLSTM(rand(Float32, 3)) |> size | ||||||||||
(10,) | ||||||||||
julia> model = Chain(Embedding(10000, 200), Bidirectional(LSTM, 200, 128), Dense(256, 5), softmax) | ||||||||||
Chain( | ||||||||||
Embedding(10000, 200), # 2_000_000 parameters | ||||||||||
Bidirectional( | ||||||||||
Recur( | ||||||||||
LSTMCell(200, 128), # 168_704 parameters | ||||||||||
), | ||||||||||
Recur( | ||||||||||
LSTMCell(200, 128), # 168_704 parameters | ||||||||||
), | ||||||||||
), | ||||||||||
Dense(256, 5), # 1_285 parameters | ||||||||||
NNlib.softmax, | ||||||||||
) # Total: 13 trainable arrays, 2_338_693 parameters, | ||||||||||
# plus 4 non-trainable, 512 parameters, summarysize 8.922 MiB. | ||||||||||
``` | ||||||||||
mcabbott marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
""" | ||||||||||
struct Bidirectional{A,B} | ||||||||||
forward::A | ||||||||||
backward::B | ||||||||||
end | ||||||||||
# Generic constructor for every case
Bidirectional(forward, f_in::Integer, f_out::Integer, backward, b_in::Integer, b_out::Integer) =
  Bidirectional(forward(f_in, f_out), backward(b_in, b_out))

# Constructor for forward and backward having the same size
Bidirectional(forward, backward, in::Integer, out::Integer) = Bidirectional(forward(in, out), backward(in, out))

# Constructor to add the same cell as forward and backward with given input and output sizes
Bidirectional(rnn, in::Integer, out::Integer) = Bidirectional(rnn(in, out), rnn(in, out))
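As a quick illustration of how the three constructors above compose (an editorial sketch against the signatures in this diff; only the single-cell form appears in the docstring, and `GRU` here stands in for any second cell type):

```julia
# One cell type for both directions (the form used in the docstring):
blstm = Bidirectional(LSTM, 3, 5)             # LSTM(3, 5) forward and backward

# Different cell types, identical sizes in both directions:
bimix = Bidirectional(LSTM, GRU, 3, 5)        # LSTM(3, 5) forward, GRU(3, 5) backward

# Fully generic: cell type and sizes chosen per direction.
# Note both cells receive the same input, so f_in and b_in should match:
bigen = Bidirectional(LSTM, 3, 5, GRU, 3, 5)
```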
# Concatenate the forward output with the reversed backward output
function (m::Bidirectional)(x::Union{AbstractVecOrMat{T},OneHotArray}) where {T}
  return vcat(m.forward(x), reverse(m.backward(reverse(x; dims=1)); dims=1))
end
A review thread on this method:

Sorry, just found this. When applying Flux RNNs on dense sequences, the temporal dim is actually the last one; see src/layers/recurrent.jl lines 83 and 27 at dd6b70e. Suggested change: […]
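The attached suggestion itself is not preserved here, but the reviewer's point can be sketched. A hypothetical variant (an editorial illustration, not the suggestion from the thread) that reverses a dense `(features, batch, time)` sequence along its temporal, i.e. last, dimension:

```julia
# Hypothetical sketch: for dense 3-d sequences, time is dimension 3,
# so the backward layer should see the timesteps reversed along dims=3,
# and its output must be flipped back into forward time order before vcat.
function (m::Bidirectional)(x::AbstractArray{T,3}) where {T}
    return vcat(m.forward(x), reverse(m.backward(reverse(x; dims=3)); dims=3))
end
```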
I'm quite confused now, because I was expecting the variables to look like

```julia
X_onehot = [OneHotMatrix(x, vocab_size) for x in X]
y_onehot = [OneHotMatrix(x, num_tags) for x in y]
```

and I am not sure how to proceed.

https://fluxml.ai/Flux.jl/stable/models/recurrence/#Working-with-sequences has most of the nitty-gritty details about working with batched sequences for RNNs. In short, the only supported formats are a vector of (features, batch) matrices, one per timestep, or a dense array with time as the last dimension. Since this definition of Bidirectional takes a contiguous array instead of a vector of arrays, […]
But the RNN layers do not literally accept a vector of matrices. At the link this is iterated through by hand. Should method 2 of […]?

Yes, exactly.

And applied to […]?

The signatures specifically prohibit that, so it should fail. Did you have a specific edge case in mind that isn't caught by the above?

No, I just wonder about the mismatch between what LSTM accepts and what this accepts. Prohibiting things which don't make sense is good; you want to know soon.

Right, the RNN layers can work with single timesteps, but […].

Thank you for the great conversation here and the suggestions. I just tested, and:

```julia
julia> Bidirectional(LSTM, 3, 5)(randn(Float32, 3, 7)) |> size
(10, 7)
```

But about the format, I am still confused about how to preprocess the text data so that it would end up in […]. I usually follow […]. Also, by checking these formats I discovered that I should probably use […].
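To make the two formats discussed above concrete, an editorial sketch, assuming the Flux version this PR targets (where `Recur` handles dense 3-d sequences with time as the last dimension); the variable names and sizes are hypothetical:

```julia
using Flux

lstm = LSTM(3, 5)                            # 3 features in, 5 features out

# Format 1: a vector of (features, batch) matrices, one per timestep,
# iterated through "by hand" as in the linked recurrence docs:
seq  = [rand(Float32, 3, 4) for _ in 1:7]    # 7 timesteps, batch of 4
outs = [lstm(xt) for xt in seq]              # 7 outputs, each of size (5, 4)

# Format 2: one dense (features, batch, time) array, time last:
Flux.reset!(lstm)
dense = rand(Float32, 3, 4, 7)
size(lstm(dense))                            # (5, 4, 7)
```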
@functor Bidirectional
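`@functor` is what exposes both wrapped layers to Flux's parameter collection and to `gpu`/`cpu` movement. A small check, consistent with the parameter counts shown in the docstring above:

```julia
# Collects the trainable arrays of both directions: for Bidirectional(LSTM, 3, 5)
# this yields the 10 trainable arrays reported in the docstring.
ps = Flux.params(Bidirectional(LSTM, 3, 5))
length(ps)  # == 10
```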
Base.getproperty(m::Bidirectional, sym::Symbol) = getfield(m, sym)
# Show adaptations
function _big_show(io::IO, obj::Bidirectional, indent::Int=0, name=nothing)
  println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), "(")
  # then we insert names -- can this be done more generically?
  for k in propertynames(obj)
    _big_show(io, getfield(obj, k), indent+2, k)
  end
end
Base.show(io::IO, m::MIME"text/plain", x::Bidirectional) = _big_show(io, x)
One more review comment: once we settle on the constructor, this line should be the various constructors, not the type definition.