Skip to content

Commit

Permalink
Use view for RNN gate slice extraction
Browse files Browse the repository at this point in the history
This was originally passed over in #907, but I don't find that argument particularly compelling as the return value is only ever used once. Any negative impact on caching is going to happen anyhow during the slice materialization, so we might as well just let the subsequent fused broadcasts handle said materialization for us while reducing allocations.
  • Loading branch information
ToucheSir committed Nov 6, 2021
1 parent ea26f45 commit c9627c5
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

gate(h, n) = (1:h) .+ h*(n-1)
gate(x::AbstractVector, h, n) = @view x[gate(h,n)]
gate(x::AbstractMatrix, h, n) = x[gate(h,n),:]
gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)

# Stateful recurrence

Expand Down Expand Up @@ -97,7 +97,7 @@ struct RNNCell{F,A,V,S}
state0::S
end

RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
RNNCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros32, init_state=zeros32) =
RNNCell(σ, init(out, in), init(out, out), initb(out), init_state(out,1))

function (m::RNNCell{F,A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {F,A,V,T}
Expand Down Expand Up @@ -194,7 +194,7 @@ function Base.getproperty(m::LSTMCell, sym::Symbol)
elseif sym === :c
Zygote.ignore() do
@warn "LSTMCell field :c has been deprecated. Use m::LSTMCell.state0[2] instead."
end
end
return getfield(m, :state0)[2]
else
return getfield(m, sym)
Expand Down Expand Up @@ -273,7 +273,7 @@ struct GRUv3Cell{A,V,S}
end

GRUv3Cell(in, out; init = glorot_uniform, initb = zeros32, init_state = zeros32) =
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
GRUv3Cell(init(out * 3, in), init(out * 2, out), initb(out * 3),
init(out, out), init_state(out,1))

function (m::GRUv3Cell{A,V,<:AbstractMatrix{T}})(h, x::Union{AbstractVecOrMat{T},OneHotArray}) where {A,V,T}
Expand Down

0 comments on commit c9627c5

Please sign in to comment.