Closed
Description
The latest Zygote release makes the following test (which lives in test/cuda/curnn.jl, although the problem is not CUDA-specific) fail:
using Flux, Test
m = RNN(10, 5)
x = rand(10)
dm = gradient(m -> sum(m(x)), m)[1]
Flux.reset!(m)
θ = gradient(() -> sum(m(x)), params(m))
@test collect(dm[].cell[].Wi) == collect(θ[m.cell.Wi])
The problem is that the structural gradient dm no longer contains the gradients of all of its fields (cell and init come back as nothing), while the Params-based gradient seems fine:
julia> dm
Base.RefValue{Any}((cell = nothing, init = nothing, state = [0.4698064753664045, 0.6717961137253184, 0.5671022834378837, -1.0362687447377206, 0.09452514544108462]))
julia> θ.grads
IdDict{Any,Any} with 6 entries:
Float32[0.277591 0.382739 … -0.470177 0.156052; -0.133075 0.3… => [0.58861 0.190333 … 0.478153 0.0493986; 0.7885 0.25497 … 0.640533 0.0661743; … ; 0.32628 0.105506 … 0.265052 0.0273828; 0…
Recur(RNNCell(10, 5, tanh)) => RefValue{Any}((cell = nothing, init = nothing, state = nothing))
Float32[-0.833346, -0.42216, 0.256897, -0.611366, 0.277756] => [0.726462, 0.973166, 0.615703, 0.402695, 0.999971]
Float32[0.582231 0.256822 … -0.537418 -0.527425; -0.0186815 0… => [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
Float32[0.0, 0.0, 0.0, 0.0, 0.0] => [0.469806, 0.671796, 0.567102, -1.03627, 0.0945251]
:(Main.x) => [0.520888, 0.81092, 1.09326, 0.450779, 0.191352, -0.195497, -0.526366, 0.00277491, -0.398409, -0.253632]
@DhairyaLGandhi This is causing the CI failures observed in #1204
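For reference, here is a minimal sketch of the agreement the test above expects, using only Zygote and a plain immutable struct. The `Affine` type is hypothetical, made up for illustration and not part of Flux. It does not reproduce the bug, which involves the mutable Recur wrapper; it only shows what a structural gradient and a Params gradient look like when they agree.

```julia
using Zygote, Test

# Hypothetical layer for illustration only (not a Flux type).
struct Affine
    W
    b
end
(a::Affine)(x) = a.W * x .+ a.b

a = Affine(rand(5, 10), rand(5))
x = rand(10)

# Structural (explicit) gradient: a NamedTuple with one entry per field.
da = gradient(a -> sum(a(x)), a)[1]

# Implicit gradient: looked up by the parameter arrays themselves.
θ = gradient(() -> sum(a(x)), Zygote.Params([a.W, a.b]))

# Both styles should yield the same gradient for each field.
@test da.W ≈ θ[a.W]
@test da.b ≈ θ[a.b]
```

In the failing RNN case, the equivalent of `da.W` (that is, `dm[].cell`) is `nothing`, so the comparison cannot even be evaluated.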