diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index e4d483072a..2a9bfd18a3 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -82,7 +82,7 @@ end flip(f, xs) = reverse(f.(reverse(xs))) function (m::Recur)(x::AbstractArray{T, 3}) where T - h = [m(x[:, :, i]) for i in 1:size(x, 3)] + h = [m(view(x, :, :, i)) for i in 1:size(x, 3)] sze = size(h[1]) reshape(reduce(hcat, h), sze[1], sze[2], length(h)) end