Backprop through time is truncated to only 1 time step #1209
This is a workaround:

```julia
using Flux, Zygote   # Zygote provides `gradient` and `Zygote.Buffer`

rnn = Flux.RNN(2, 3)
seq = [rand(2) for i = 1:3]

Flux.reset!(rnn);
bptt = gradient(Wh -> @show(sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                            tanh.(rnn.cell.Wi * seq[2] + Wh *
                            tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.init + rnn.cell.b)
                            + rnn.cell.b)
                            + rnn.cell.b))),
               rnn.cell.Wh)

grads_seq = gradient(Flux.params([rnn.cell.Wh])) do
    hs = Zygote.Buffer([], 3)
    for i in 1:3
        hs[i] = rnn(seq[i])
    end
    sum(hs[3])
end

@assert grads_seq[rnn.cell.Wh] == bptt[1]
```
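To see concretely why a loop over time steps and the fully unrolled expression can disagree when gradients are truncated, here is a minimal, self-contained sketch with a scalar RNN. It does not use Flux or Zygote; every name (`run_rnn`, `bptt_grad`, `truncated_grad`) is illustrative, not Flux's implementation. It compares the full BPTT gradient with a one-step-truncated gradient against a finite-difference check:

```julia
# Scalar RNN: h_t = tanh(wi*x_t + wh*h_{t-1} + b); loss is the last hidden state.
dtanh(h) = 1 - h^2                 # derivative of tanh, expressed via its output

function run_rnn(wi, wh, b, xs; h0 = 0.0)
    hs = Float64[]
    h = h0
    for x in xs
        h = tanh(wi * x + wh * h + b)
        push!(hs, h)
    end
    return hs
end

# Full BPTT: dh_t/dwh = (1 - h_t^2) * (h_{t-1} + wh * dh_{t-1}/dwh),
# accumulated through every time step.
function bptt_grad(wi, wh, b, xs; h0 = 0.0)
    hs = run_rnn(wi, wh, b, xs; h0 = h0)
    dh_dwh = 0.0
    hprev = h0
    for h in hs
        dh_dwh = dtanh(h) * (hprev + wh * dh_dwh)
        hprev = h
    end
    return dh_dwh
end

# Truncated to one step: treats h_{T-1} as a constant, exactly the bug reported here.
function truncated_grad(wi, wh, b, xs; h0 = 0.0)
    hs = run_rnn(wi, wh, b, xs; h0 = h0)
    return dtanh(hs[end]) * hs[end-1]
end

wi, wh, b = 0.7, 0.5, 0.1
xs = [0.3, -0.2, 0.8]

# Central finite difference of the loss with respect to wh.
eps_fd = 1e-6
fd = (run_rnn(wi, wh + eps_fd, b, xs)[end] -
      run_rnn(wi, wh - eps_fd, b, xs)[end]) / (2 * eps_fd)

@assert isapprox(bptt_grad(wi, wh, b, xs), fd; atol = 1e-6)       # full BPTT matches
@assert !isapprox(truncated_grad(wi, wh, b, xs), fd; atol = 1e-6) # truncation does not
```

The full gradient agrees with the finite-difference check while the one-step-truncated gradient does not; that mismatch is the signature of the bug this issue describes.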
I think there's something wrong with …

Using the …

The issue is precisely with the broadcasting though. The buffer isn't necessary; this works too:

…

@bhvieira If I want to collect the outputs from all timesteps, what is the most efficient way?

Use …

Can you please try replacing …? If that works, we can add a similar patch for broadcast.
Just as @bhvieira recommended:

```julia
using Flux, TrackerFlux

rnn = Flux.RNN(2, 3) |> TrackerFlux.track
seq = [rand(2) for i = 1:3]

bptt = Flux.gradient(Wh -> @show(sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                                 tanh.(rnn.cell.Wi * seq[2] + Wh *
                                 tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.init + rnn.cell.b)
                                 + rnn.cell.b)
                                 + rnn.cell.b))),
                     rnn.cell.Wh)

Flux.reset!(rnn);
grads_seq = Flux.gradient(Flux.params([rnn.cell.Wh])) do
    hs = map(rnn, seq)
    sum(hs[3])
end

@assert grads_seq[rnn.cell.Wh] == bptt[1]
```

We can also write it in a Zygote-compatible way:
```julia
using Flux, TrackerFlux

rnn = Flux.RNN(2, 3) |> TrackerFlux.track
seq = [rand(2) for i = 1:3]

bptt = Flux.gradient(Wh -> @show(sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                                 tanh.(rnn.cell.Wi * seq[2] + Wh *
                                 tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.init + rnn.cell.b)
                                 + rnn.cell.b)
                                 + rnn.cell.b))),
                     rnn.cell.Wh)

Flux.reset!(rnn);
grads_seq = Flux.gradient(Flux.params([rnn.cell.Wh])) do
    hs = Flux.Zygote.Buffer([], 3)
    for i in 1:3
        hs[i] = rnn(seq[i])
    end
    sum(hs[3])
end

@assert grads_seq[rnn.cell.Wh] == bptt[1]
```
```julia
using Flux

rnn = Flux.RNN(2, 3)
seq = [rand(2) for i = 1:3]

bptt = Flux.gradient(Wh -> @show(sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                                 tanh.(rnn.cell.Wi * seq[2] + Wh *
                                 tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.init + rnn.cell.b)
                                 + rnn.cell.b)
                                 + rnn.cell.b))),
                     rnn.cell.Wh)

Flux.reset!(rnn);
grads_seq = Flux.gradient(Flux.params([rnn.cell.Wh])) do
    hs = Flux.Zygote.Buffer([], 3)
    for i in 1:3
        hs[i] = rnn(seq[i])
    end
    sum(hs[3])
end

@assert grads_seq[rnn.cell.Wh] == bptt[1]
```
In this example, there's a slight overhead of using … (the `@btime` calls below assume `using BenchmarkTools`):

```julia
julia> using Flux, TrackerFlux

julia> rnn = Flux.RNN(2, 3) |> TrackerFlux.track
Recur(RNNCell(2, 3, tanh))

julia> seq = [rand(2) for i = 1:3]
3-element Array{Array{Float64,1},1}:
 [0.03294249342946465, 0.18808334558515472]
 [0.34226355023657296, 0.9750375230014132]
 [0.9295184217629979, 0.26820876169878827]

julia> ps = Flux.params([rnn.cell.Wh])
Params([Float32[-0.5715096 0.57270575 -0.8965924; -0.48481512 -0.46085548 0.90666986; 0.24707532 -0.52199364 0.03942895] (tracked)])

julia> @btime begin
           Flux.reset!($rnn);
           Flux.gradient($ps) do
               hs = $rnn.($seq)
               sum(hs[3])
           end
       end
  119.110 μs (634 allocations: 22.07 KiB)
Grads(...)

julia> @btime begin
           Flux.reset!($rnn);
           Flux.gradient($ps) do
               hs = Vector{Any}(undef, 3)
               for i in 1:3
                   hs[i] = $rnn($seq[i])
               end
               sum(hs[3])
           end
       end
  117.956 μs (630 allocations: 21.98 KiB)
Grads(...)

julia> @btime begin
           Flux.reset!($rnn);
           Flux.gradient($ps) do
               hs = Flux.Zygote.Buffer([], 3)
               for i in 1:3
                   hs[i] = $rnn($seq[i])
               end
               sum(hs[3])
           end
       end
  117.305 μs (631 allocations: 22.05 KiB)
Grads(...)
```

```julia
julia> using Flux

julia> rnn = Flux.RNN(2, 3)
Recur(RNNCell(2, 3, tanh))

julia> seq = [rand(2) for i = 1:3]
3-element Array{Array{Float64,1},1}:
 [0.32629709362357584, 0.11605123770776848]
 [0.6291065815003436, 0.1655236202415329]
 [0.42141717535016565, 0.7108787078307919]

julia> Flux.reset!(rnn)
3-element Array{Float32,1}:
 0.0
 0.0
 0.0

julia> ps = Flux.params([rnn.cell.Wh])
Params([Float32[-0.07593036 -0.1756413 -0.32087517; 0.8284807 -0.11117959 -0.22851062; 0.55279875 -0.4334936 -0.36445403]])

julia> @btime Flux.gradient($ps) do
           hs = Flux.Zygote.Buffer([], 3)
           for i in 1:3
               hs[i] = $rnn(seq[i])
           end
           sum(hs[3])
       end
  85.040 μs (502 allocations: 18.33 KiB)
Grads(...)

julia> @btime Flux.gradient($ps) do
           h = 0f0
           for i in 1:3
               h = $rnn(seq[i])
           end
           sum(h)
       end
  73.407 μs (467 allocations: 17.08 KiB)
Grads(...)
```
Yeah, as I mentioned previously, it makes sense for it to be that way. Usually the buffer tends to slow things down when you need to populate an array with an operation that could be done in one go with functions, such as the products in #1009.
@MikeInnes I can confirm the following script works:

```julia
using Pkg
pkg"add Zygote#stateful-map"

using Flux

rnn = Flux.RNN(2, 3)
seq = [rand(2) for i = 1:3]

Flux.reset!(rnn)
bptt = Flux.gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                           tanh.(rnn.cell.Wi * seq[2] + Wh *
                           tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.init + rnn.cell.b)
                           + rnn.cell.b)
                           + rnn.cell.b)),
                     rnn.cell.Wh)

Flux.reset!(rnn)
grads_seq = Flux.gradient(Flux.params([rnn.cell.Wh])) do
    sum(map(rnn, seq)[3])
end

@assert grads_seq[rnn.cell.Wh] == bptt[1]
```
Great, so that confirms the patch fixes this issue. If you try the latest version of the branch, the broadcast version should work too. What's happening here is that when you write a loop, the pullbacks for the RNN applications during the loop are applied in reverse order, and the order really matters because they accumulate shared state (the gradient of the hidden state). When you use …
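The ordering point can be illustrated without Zygote at all. Below is a sketch (every name here is invented for illustration, not Zygote's API) of a stateful chain `h_t = w * h_{t-1}` where each step records a pullback that reads and writes a shared hidden-state gradient, much like the accumulation described above. Applying the pullbacks in reverse order recovers the true derivative `d h_T / d w = T * w^(T-1) * h0`; applying them in forward order does not:

```julia
# Forward pass that records one pullback closure per time step. The pullbacks
# share mutable state: the running hidden-state gradient `dh` and the
# accumulated parameter gradient `dw`.
function forward_with_pullbacks(w, h0, T)
    pullbacks = Function[]
    state = Dict(:dh => 0.0, :dw => 0.0)
    h = h0
    for t in 1:T
        hprev = h              # captured by this step's pullback
        h = w * hprev
        push!(pullbacks, function ()
            d = state[:dh]             # gradient flowing into this step's output
            state[:dw] += d * hprev    # ∂(w * hprev)/∂w
            state[:dh]  = d * w        # gradient passed to the previous state
        end)
    end
    return h, pullbacks, state
end

w, h0, T = 0.9, 1.0, 4         # h_T = w^T * h0, so d h_T / d w = T * w^(T-1) * h0
expected = T * w^(T - 1) * h0

# Correct: seed dh with dL/dh_T = 1 and apply pullbacks in REVERSE order.
_, pbs, st = forward_with_pullbacks(w, h0, T)
st[:dh] = 1.0
foreach(pb -> pb(), reverse(pbs))
@assert isapprox(st[:dw], expected; atol = 1e-9)

# Wrong: forward order feeds each step a hidden-state gradient meant for
# a different step, so the shared-state accumulation comes out wrong.
_, pbs2, st2 = forward_with_pullbacks(w, h0, T)
st2[:dh] = 1.0
foreach(pb -> pb(), pbs2)
@assert !isapprox(st2[:dw], expected; atol = 1e-9)
```

This is why a `map` whose pullbacks run in an unspecified or forward order can silently produce the truncated-looking gradients reported in this issue.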
In case it's not clear: others hitting this issue can also use that same branch to fix it: FluxML/Zygote.jl#676.
676: Stateful map r=dhairyagandhi96 a=MikeInnes See FluxML/Flux.jl#1209. It would be good if we could come up with a reasonable test case. It may be tricky to do that, and if so it may make sense to merge this to avoid the issues people are seeing. Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
Looks like FluxML/Zygote.jl#676 fixed the issue for …
Yeah, I had to comment out the broadcasting part of FluxML/Zygote.jl#676 to get the tests to pass. I was seeing some method-ambiguity errors; maybe they can be worked around if someone is willing to try. I'm sorry, I don't have much time these days.
Perhaps it would be better, then, to override broadcasting for recurrent layers until this is done, because it's mentioned nowhere that it doesn't work as expected.
I think the fact that RNNs currently, and silently, produce wrong results that are quite hard to debug is critical enough that this should be prioritized.
I can pick up the broadcasting issue; the basic code is already in there, sans some threading around the internals to get it working as expected. We will need to add a minimised test case as well to catch this in the future.
Unless my understanding of Backpropagation Through Time (BPTT) and Flux/Zygote is off, it seems like BPTT isn't working as intended with Flux/Zygote.
Currently, the gradient being calculated with respect to Wh does not look back in time. In other words, if we have a sequence of length 3,
then the following two gradients (`grads_seq` and `grads_2`) are the same:

…

and

…

whereas the gradient for BPTT should be as follows:

…
Issue #1168 is possibly related, and here is a gist summarizing this.
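For reference, here is a sketch of the standard BPTT chain rule implied by the nested-tanh expressions used throughout this thread, with `h_t = tanh(W_i x_t + W_h h_{t-1} + b)` and loss `L = sum(h_3)` (the local partial `∂h_t/∂W_h` treats `h_{t-1}` as constant):

```latex
\frac{\partial L}{\partial W_h}
  = \sum_{t=1}^{3}
    \frac{\partial L}{\partial h_3}
    \left( \prod_{k=t+1}^{3} \frac{\partial h_k}{\partial h_{k-1}} \right)
    \frac{\partial h_t}{\partial W_h},
\qquad
\frac{\partial h_k}{\partial h_{k-1}} = \operatorname{diag}\!\left(1 - h_k \odot h_k\right) W_h
```

Truncating backpropagation to one time step keeps only the `t = 3` term of the sum, which is exactly the behavior reported in this issue.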