Merge branch 'master' into named
mcabbott authored Aug 4, 2021
2 parents 894c32b + 5d2a955 commit 6d4180f
Showing 5 changed files with 63 additions and 11 deletions.
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,6 +1,7 @@
# Flux Release Notes

## v0.12.7
* Added support for [`GRUv3`](https://github.com/FluxML/Flux.jl/pull/1675).
* The layers within `Chain` and `Parallel` may now [have names](https://github.com/FluxML/Flux.jl/issues/1680).
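
A minimal usage sketch of the two entries above (illustrative only; the keyword syntax for naming layers is assumed from the discussion in #1680):

```julia
using Flux

# GRUv3 is constructed and called like the existing RNN/GRU/LSTM layers.
gru3 = GRUv3(3, 5)        # 3 input features, 5 hidden units
x = rand(Float32, 3)
y = gru3(x)               # advances the hidden state and returns a length-5 output
Flux.reset!(gru3)         # restore the initial hidden state

# Naming layers inside a Chain (keyword syntax assumed from #1680),
# so that individual layers can later be looked up by name, e.g. m[:rec].
m = Chain(rec = GRU(3, 5), dec = Dense(5, 2))
```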

## v0.12.5
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -11,7 +11,7 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient

export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
RNN, LSTM, GRU,
RNN, LSTM, GRU, GRUv3,
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
61 changes: 56 additions & 5 deletions src/layers/recurrent.jl
@@ -183,6 +183,15 @@ end

# GRU

function _gru_output(Wi, Wh, b, x, h)
o = size(h, 1)
gx, gh = Wi*x, Wh*h
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))

return gx, gh, r, z
end

struct GRUCell{A,V,S}
Wi::A
Wh::A
@@ -195,9 +204,7 @@ GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =

function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
b, o = m.b, size(h, 1)
gx, gh = m.Wi*x, m.Wh*h
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h)
h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3))
h′ = (1 .- z) .* h̃ .+ z .* h
sz = size(x)
@@ -212,8 +219,9 @@ Base.show(io::IO, l::GRUCell) =
"""
GRU(in::Integer, out::Integer)
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078) layer. Behaves like an
RNN but generally exhibits a longer memory span over sequences.
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. Behaves like an
RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v1 of the referenced paper.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
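
As a reference for what `GRUCell` computes, with W_*, U_*, b_* denoting the per-gate slices of `Wi`, `Wh`, `b` selected by `gate`:

```latex
\begin{aligned}
r  &= \sigma\left(W_r x + U_r h + b_r\right) \\
z  &= \sigma\left(W_z x + U_z h + b_z\right) \\
\tilde{h} &= \tanh\left(W_h x + r \odot (U_h h) + b_h\right) \\
h' &= (1 - z) \odot \tilde{h} + z \odot h
\end{aligned}
```

The reset gate multiplies the already-projected hidden state U_h h, which is the v1 formulation the docstring refers to.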
@@ -233,6 +241,49 @@ function Base.getproperty(m::GRUCell, sym::Symbol)
end
end


# GRU v3

struct GRUv3Cell{A,V,S}
Wi::A
Wh::A
b::V
Wh_h̃::A
state0::S
end

GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
init(out, out), init_state(out,1))

function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
b, o = m.b, size(h, 1)
gx, gh, r, z = _gru_output(m.Wi, m.Wh, b, x, h)
h̃ = tanh.(gate(gx, o, 3) .+ (m.Wh_h̃ * (r .* h)) .+ gate(b, o, 3))
h′ = (1 .- z) .* h̃ .+ z .* h
sz = size(x)
return h′, reshape(h′, :, sz[2:end]...)
end

@functor GRUv3Cell

Base.show(io::IO, l::GRUv3Cell) =
print(io, "GRUv3Cell(", size(l.Wi, 2), ", ", size(l.Wi, 1)÷3, ")")

"""
GRUv3(in::Integer, out::Integer)
[Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an
RNN but generally exhibits a longer memory span over sequences. This implements
the variant proposed in v3 of the referenced paper.
See [this article](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for a good overview of the internals.
"""
GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
Recur(m::GRUv3Cell) = Recur(m, m.state0)


@adjoint function Broadcast.broadcasted(f::Recur, args...)
Zygote.∇map(__context__, f, args...)
end
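
The difference from `GRUCell` is confined to the candidate state: `GRUv3Cell` keeps only two gates in `Wh` (hence `init(out * 2, out)` above) and applies the reset gate to `h` before the separate projection `Wh_h̃`, which is the v3 formulation:

```latex
\tilde{h}_{\mathrm{v1}} = \tanh\left(W_h x + r \odot (U_h h) + b_h\right)
\qquad \text{vs.} \qquad
\tilde{h}_{\mathrm{v3}} = \tanh\left(W_h x + U_{\tilde{h}} \, (r \odot h) + b_h\right)
```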
4 changes: 2 additions & 2 deletions test/cuda/curnn.jl
@@ -1,6 +1,6 @@
using Flux, CUDA, Test

@testset for R in [RNN, GRU, LSTM]
@testset for R in [RNN, GRU, LSTM, GRUv3]
m = R(10, 5) |> gpu
x = gpu(rand(10))
(m̄,) = gradient(m -> sum(m(x)), m)
@@ -12,7 +12,7 @@ using Flux, CUDA, Test
end

@testset "RNN" begin
@testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
@testset for R in [RNN, GRU, LSTM, GRUv3], batch_size in (1, 5)
rnn = R(10, 5)
curnn = fmap(gpu, rnn)

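
For reference, a standalone sketch of the pattern these CUDA tests exercise (illustrative only; assumes a working CUDA.jl setup):

```julia
using Flux, CUDA

m = GRUv3(10, 5) |> gpu              # parameters and initial hidden state on the GPU
x = gpu(rand(Float32, 10))           # one input vector on the GPU
(m̄,) = gradient(m -> sum(m(x)), m)   # gradient with respect to the layer's parameters
Flux.reset!(m)                       # restore the initial state before the next sequence
```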
6 changes: 3 additions & 3 deletions test/layers/recurrent.jl
@@ -43,7 +43,7 @@ end
end

@testset "RNN-shapes" begin
@testset for R in [RNN, GRU, LSTM]
@testset for R in [RNN, GRU, LSTM, GRUv3]
m1 = R(3, 5)
m2 = R(3, 5)
x1 = rand(Float32, 3)
@@ -58,10 +58,10 @@ end
end

@testset "RNN-input-state-eltypes" begin
@testset for R in [RNN, GRU, LSTM]
@testset for R in [RNN, GRU, LSTM, GRUv3]
m = R(3, 5)
x = rand(Float64, 3, 1)
Flux.reset!(m)
@test_throws MethodError m(x)
end
end
end
