Skip to content

Use view for RNN gate slice extraction #1761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 41 additions & 33 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@

"""
    gate(h, n)
    gate(x::AbstractVector, h, n)
    gate(x::AbstractMatrix, h, n)

Return the index range `(1:h) .+ h*(n-1)` selecting the `n`-th gate of width
`h`, or a non-allocating `view` of that slice of `x` (rows, for a matrix).
"""
gate(h, n) = (1:h) .+ h*(n-1)
gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
# NOTE: only one matrix method — the old copying `x[gate(h,n),:]` definition
# was a stale duplicate of this view-based one and has been removed.
gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)

# AD-friendly helper for dividing monolithic RNN params into equally sized gates
multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)

@adjoint function multigate(x::AbstractArray, h, c)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we strictly need this adjoint?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the difference between being slower on certain configurations and being strictly faster across all configurations, cf. before and after. The more calls to gate, the more pronounced the effect: note how GRU cells called gate 6-8 times and also regressed the most (on smaller input sizes) without multigate.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The better answer would be to see what part of gate regressed and fixing that instead.

Copy link
Member Author

@ToucheSir ToucheSir Nov 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change to gate itself was one line: https://github.com/FluxML/Flux.jl/pull/1761/files#diff-54816e9a4978b8c02648fdb29ebfd6d794452dbac8a28d0e84a5e2cc646a3fbfR4. Since the view and getindex use the same adjoint, there's no reason backwards pass performance should be slower (note forwards pass was consistently faster). Thus the only explanations seem to be a benchmarking artifact (note how this shows up only at smaller input sizes) and/or Zygote's compiler being unhappy for some reason (from profiling, almost all of the non-BLAS, non activation self time is spent in the generated Pullback for both cases).

However, what it did expose is that calling gate multiple times regardless of whether it uses view or slicing was inefficient, as the adjoint would allocate a full-sized buffer for the original array on every call. multigate resolves this by only allocating once, thus reducing both memory and (accumulation) compute by a factor of the number of gates. Even if gate wasn't using view, this would be a worthwhile optimization.

# NOTE(review): this is the tail of the `@adjoint multigate(x, h, c)` definition
# whose header appears above; `x`, `h`, and `c` are captured from that enclosing
# scope. Allocates ONE zero buffer shaped like `x` and accumulates every gate's
# cotangent into its own slice, instead of one full-size buffer per gate call.
function multigate_pullback(dy)
dx = Zygote._zero(x, eltype(x))
# The slices produced by `multigate(dx, h, c)` alias `dx`, so the in-place
# broadcast below writes directly into the shared buffer.
map(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
# Skip gates whose cotangent is `nothing` (output not used downstream).
dyᵢ !== nothing && (dxᵢ.= Zygote.accum.(dxᵢ, dyᵢ));
end
# No gradients for the gate width `h` or the gate-count Val `c`.
return (dx, nothing, nothing)
end
return multigate(x, h, c), multigate_pullback
end

# Reshape a cell's output `h` so its trailing dimensions match those of the
# input `x` (e.g. restoring the batch dimension), inferring the first one.
function reshape_cell_output(h, x)
    trailing = Base.tail(size(x))
    return reshape(h, :, trailing...)
end

# Stateful recurrence

Expand Down Expand Up @@ -97,14 +113,13 @@ struct RNNCell{F,A,V,S}
state0::S
end

# Convenience constructor: input weights via `init`, recurrent weights via
# `init`, bias via `initb`, and an `out×1` initial state via `init_state`.
# (The duplicated signature line — a merge/diff artifact — has been removed.)
RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
  RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))

# Vanilla RNN step: h′ = σ(Wi*x + Wh*h + b).
# Returns (new_state, output); output is reshaped to match `x`'s trailing dims.
# (Removed the stale `sz = size(x)` / duplicate return left over from the diff.)
function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
  σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
  # Single fused broadcast over input and recurrent projections plus bias.
  h = σ.(Wi*x .+ Wh*h .+ b)
  return h, reshape_cell_output(h, x)
end

# Register RNNCell with Functors.jl so Flux can traverse its fields
# (parameter collection, device moves) — presumably all fields; confirm trainables.
@functor RNNCell
Expand Down Expand Up @@ -157,14 +172,10 @@ end
# LSTM step. One fused GEMM pair produces all four gate pre-activations, which
# `multigate` splits into non-copying slices (input, forget, cell, output).
# Returns ((h′, c′), output). The interleaved pre-multigate body and its stale
# early return — diff-merge leftovers — have been removed.
function (m::LSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
  b, o = m.b, size(h, 1)
  g = m.Wi*x .+ m.Wh*h .+ b
  input, forget, cell, output = multigate(g, o, Val(4))
  # Activations (σ/tanh) are fused into the cell/hidden updates via @. broadcasts.
  c′ = @. σ(forget) * c + σ(input) * tanh(cell)
  h′ = @. σ(output) * tanh(c′)
  return (h′, c′), reshape_cell_output(h′, x)
end

# Register LSTMCell with Functors.jl so Flux can traverse its fields
# (parameter collection, device moves).
@functor LSTMCell
Expand Down Expand Up @@ -194,7 +205,7 @@ function Base.getproperty(m::LSTMCell, sym::Symbol)
elseif sym === :c
Zygote.ignore() do
@warn "LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead."
end
end
return getfield(m, :state0)[2]
else
return getfield(m, sym)
Expand All @@ -203,13 +214,10 @@ end

# GRU

# Shared GRU/GRUv3 helper: compute the reset (`r`) and update (`z`) gates from
# pre-split input projections `gxs`, hidden projections `ghs`, and biases `bs`
# (each an indexable collection of gate slices, e.g. from `multigate`).
# The old 5-argument version was left interleaved here without its closing
# `end` (a diff-merge artifact, syntactically invalid); only the new form stays.
function _gru_output(gxs, ghs, bs)
  r = @. σ(gxs[1] + ghs[1] + bs[1])
  z = @. σ(gxs[2] + ghs[2] + bs[2])
  return r, z
end

struct GRUCell{A,V,S}
Expand All @@ -223,12 +231,12 @@ GRUCell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
GRUCell(init(out * 3, in), init(out * 3, out), initb(out * 3), init_state(out,1))

# GRU step (v1, reset gate applied to the hidden projection).
# Each projection is split into its three gate slices without copying.
# Removed the stale pre-multigate body and restored the `h̃` assignment whose
# left-hand side was lost in the page extraction.
function (m::GRUCell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
  Wi, Wh, b, o = m.Wi, m.Wh, m.b, size(h, 1)
  gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(3)), multigate(b, o, Val(3))
  r, z = _gru_output(gxs, ghs, bs)
  # Candidate hidden state; slice 3 of each projection feeds the candidate.
  h̃ = @. tanh(gxs[3] + r * ghs[3] + bs[3])
  h′ = @. (1 - z) * h̃ + z * h
  return h′, reshape_cell_output(h′, x)
end

# Register GRUCell with Functors.jl so Flux can traverse its fields
# (parameter collection, device moves).
@functor GRUCell
Expand Down Expand Up @@ -273,16 +281,16 @@ struct GRUv3Cell{A,V,S}
end

# GRUv3 constructor: `Wi` and `b` carry 3 gate blocks (r, z, candidate input
# part) while `Wh` carries only 2 (r, z) — the candidate's hidden part uses the
# separate `out×out` matrix (see the forward pass's `Wh_h̃`). The duplicated
# first line (a diff artifact) has been removed.
GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
  GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
            init(out, out), init_state(out,1))

# GRUv3 step: like GRU but the candidate's hidden contribution is a separate
# matmul `Wh_h̃ * (r .* h)` (reset applied to the state, not the projection),
# so `Wh` holds only the r/z blocks (Val(2)). Removed the stale pre-multigate
# body and restored the `h̃` assignment whose left-hand side was lost in the
# page extraction.
function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
  Wi, Wh, b, Wh_h̃, o = m.Wi, m.Wh, m.b, m.Wh_h̃, size(h, 1)
  gxs, ghs, bs = multigate(Wi*x, o, Val(3)), multigate(Wh*h, o, Val(2)), multigate(b, o, Val(3))
  r, z = _gru_output(gxs, ghs, bs)
  h̃ = tanh.(gxs[3] .+ (Wh_h̃ * (r .* h)) .+ bs[3])
  h′ = @. (1 - z) * h̃ + z * h
  return h′, reshape_cell_output(h′, x)
end

# Register GRUv3Cell with Functors.jl so Flux can traverse its fields
# (parameter collection, device moves).
@functor GRUv3Cell
Expand Down
10 changes: 10 additions & 0 deletions test/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,13 @@ end
@test_throws MethodError m(x)
end
end

# Check both the forward split (rows 1-2, 3-4, 5-6 of a 6×5 matrix) and the
# custom adjoint: unused middle gate gets zero gradient, the doubled third
# gate gets gradient 2, the first gets gradient 1.
@testset "multigate" begin
    input = rand(6, 5)
    val, (grad,) = Flux.withgradient(input) do a
        top, _, bottom = Flux.multigate(a, 2, Val(3))
        sum(top) + sum(bottom .* 2)
    end
    @test val == sum(input[1:2, :]) + 2sum(input[5:6, :])
    expected = vcat(ones(2, 5), zeros(2, 5), fill(2, 2, 5))
    @test grad == expected
end