
test regression with recurrent neural networks #1245

Closed
CarloLucibello opened this issue Jun 22, 2020 · 1 comment

Comments

@CarloLucibello
Member

CarloLucibello commented Jun 22, 2020

The latest Zygote version makes the following test fail (it lives in test/cuda/curnn.jl, although the problem is not CUDA-specific):

using Flux, Test
m = RNN(10, 5)
x = rand(10)
dm = gradient(m -> sum(m(x)), m)[1]
Flux.reset!(m)
θ = gradient(() -> sum(m(x)), params(m))
@test collect(dm[].cell[].Wi) == collect(θ[m.cell.Wi])

The problem is that the structural gradient dm no longer contains gradients for all of its fields (cell and init are nothing), while the Params gradient looks fine:

julia> dm
Base.RefValue{Any}((cell = nothing, init = nothing, state = [0.4698064753664045, 0.6717961137253184, 0.5671022834378837, -1.0362687447377206, 0.09452514544108462]))

julia> θ.grads
IdDict{Any,Any} with 6 entries:
  Float32[0.277591 0.382739 … -0.470177 0.156052; -0.133075 0.3 => [0.58861 0.190333 … 0.478153 0.0493986; 0.7885 0.25497 … 0.640533 0.0661743; … ; 0.32628 0.105506 … 0.265052 0.0273828; 0
  Recur(RNNCell(10, 5, tanh))                                    => RefValue{Any}((cell = nothing, init = nothing, state = nothing))
  Float32[-0.833346, -0.42216, 0.256897, -0.611366, 0.277756]    => [0.726462, 0.973166, 0.615703, 0.402695, 0.999971]
  Float32[0.582231 0.256822 … -0.537418 -0.527425; -0.0186815 0 => [0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
  Float32[0.0, 0.0, 0.0, 0.0, 0.0]                               => [0.469806, 0.671796, 0.567102, -1.03627, 0.0945251]
  :(Main.x)                                                      => [0.520888, 0.81092, 1.09326, 0.450779, 0.191352, -0.195497, -0.526366, 0.00277491, -0.398409, -0.253632]
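To make the distinction between the two gradient styles concrete, here is a minimal sketch using Zygote alone, assuming a recent Zygote release that still supports implicit `Params`. `Affine` is a hypothetical stand-in, not Flux's RNN: because it is an immutable struct with no hidden state, its explicit (structural) gradient should be a plain NamedTuple whose fields match the entries of the `Params` gradient.

```julia
using Zygote

# Hypothetical stand-in for a layer (not Flux's RNN). It is an immutable struct,
# so the explicit gradient is a plain NamedTuple rather than a Ref-wrapped one.
struct Affine
    W::Matrix{Float64}
    b::Vector{Float64}
end
(a::Affine)(x) = a.W * x .+ a.b

a = Affine(rand(2, 3), rand(2))
x = rand(3)

# Explicit (structural) gradient: a NamedTuple mirroring Affine's fields (W, b).
ga = gradient(m -> sum(m(x)), a)[1]

# Implicit gradient: a Grads object keyed by the parameter arrays themselves.
ps = Zygote.Params([a.W, a.b])
gs = gradient(() -> sum(a(x)), ps)

# The two styles should agree field by field.
@assert ga.W ≈ gs[a.W]
@assert ga.b ≈ gs[a.b]
```

In the failing test above, the analogous check is that `dm[].cell[].Wi` matches `θ[m.cell.Wi]`; the regression shows up as `cell = nothing` in the structural gradient.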

@DhairyaLGandhi This is causing the CI failures observed in #1204

@ToucheSir
Member

The explicit structural gradient now includes cell as expected (and no longer requires a Ref).
