
Backprop through time is truncated to only 1 time step #1209

Closed
AlexLewandowski opened this issue Jun 4, 2020 · 17 comments · Fixed by #1358

Comments

@AlexLewandowski

AlexLewandowski commented Jun 4, 2020

Unless my understanding of Backpropagation Through Time (BPTT) or of Flux/Zygote is off, it seems that BPTT isn't working as intended in Flux/Zygote.

Currently, the gradient calculated with respect to Wh does not look back in time. In other words, if we have a sequence of length 3,

rnn = Flux.RNN(2, 3)
seq = [rand(2) for i = 1:3]

then the following two gradients (grads_seq and grads_2) are the same.

Flux.reset!(rnn);
grads_seq = gradient(Flux.params(rnn)) do
    sum(rnn.(seq)[3])
end

and

Flux.reset!(rnn);
rnn(seq[1])
rnn(seq[2])
grads_2 = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] +
                                 Wh * rnn.state + rnn.cell.b)), rnn.cell.Wh)

The full BPTT gradient, by contrast, should be as follows.

Flux.reset!(rnn);
bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                                tanh.(rnn.cell.Wi * seq[2] + Wh *
                                      tanh.(rnn.cell.Wi * seq[1] +
                                            Wh * rnn.init + rnn.cell.b)
                                      + rnn.cell.b)
                                + rnn.cell.b)),
                rnn.cell.Wh)
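
For reference, the same hand-unrolled computation generalizes to any sequence length with a loop; a minimal sketch (the unrolled_loss helper here is illustrative only, not part of Flux):

# Illustrative helper (not part of Flux): unroll the cell by hand over the
# whole sequence, so the gradient flows through every application of Wh.
function unrolled_loss(Wh, rnn, seq)
    h = rnn.init                                      # initial hidden state
    for x in seq
        h = tanh.(rnn.cell.Wi * x + Wh * h + rnn.cell.b)
    end
    sum(h)                                            # scalar loss on the last state
end

bptt_T = gradient(Wh -> unrolled_loss(Wh, rnn, seq), rnn.cell.Wh)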

Issue #1168 is possibly related, and here is a gist summarizing this.

@AlexLewandowski AlexLewandowski changed the title Backprop through time is truncated to 1 Backprop through time is truncated to only 1 time step Jun 4, 2020
@AStupidBear
Contributor

Here is a workaround:

using Flux, Zygote  # Zygote is needed for Zygote.Buffer below

rnn = Flux.RNN(2, 3)
seq = [rand(2) for i = 1:3]

Flux.reset!(rnn);

bptt = gradient(Wh->@show(sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                                tanh.(rnn.cell.Wi * seq[2] + Wh *
                                      tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.init + rnn.cell.b)
                                      + rnn.cell.b)
                                + rnn.cell.b))),
                rnn.cell.Wh)

grads_seq = gradient(Flux.params([rnn.cell.Wh])) do
    hs = Zygote.Buffer([], 3)
    for i in 1:3
        hs[i] = rnn(seq[i])
    end
    sum(hs[3])
end

@assert grads_seq[rnn.cell.Wh] == bptt[1]

I think there's something wrong with broadcast and map for recurrent layers. @MikeInnes @dhairyagandhi96

@bhvieira
Contributor

bhvieira commented Jun 4, 2020

Using the Buffer could be too slow though (perhaps not here, because it's a broadcasted operation). There must be a @nograd or dropgrad somewhere killing the chain from previous timepoints, perhaps due to scalar indexing.

@AlexLewandowski
Author

The issue is precisely with the broadcasting, though. The buffer isn't necessary; this works too:

grads_seq = gradient(Flux.params([rnn.cell.Wh])) do
    h = 0f0
    for i in 1:3
        h = sum(rnn(seq[i]))
    end
    h
end

@AStupidBear
Contributor

@bhvieira If I want to collect outputs from all timesteps, what is the most efficient way?

@bhvieira
Contributor

bhvieira commented Jun 5, 2020

Use Tracker, I guess 😅
It's easy for me to say because I don't face Zygote in my projects yet; they started on Tracker and will remain that way for a while.

@MikeInnes
Member

Can you please try replacing rnn.(seq) with map(rnn, seq), and see if this branch fixes it? (i.e. add Zygote#stateful-map)

If that works we can add a similar patch for broadcast.

@AStupidBear
Contributor

Just as @bhvieira recommended, Tracker may still be more robust. But it's also possible to combine Tracker with the latest Flux:

using Flux, TrackerFlux

rnn = Flux.RNN(2, 3) |> TrackerFlux.track
seq = [rand(2) for i = 1:3]

bptt = Flux.gradient(Wh->@show(sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                                tanh.(rnn.cell.Wi * seq[2] + Wh *
                                      tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.init + rnn.cell.b)
                                      + rnn.cell.b)
                                + rnn.cell.b))),
                rnn.cell.Wh)

Flux.reset!(rnn);
grads_seq = Flux.gradient(Flux.params([rnn.cell.Wh])) do
    hs = map(rnn, seq)
    sum(hs[3])
end
@assert grads_seq[rnn.cell.Wh] == bptt[1]

We can also write it in a Zygote-compatible way:

  • With Tracker
using Flux, TrackerFlux

rnn = Flux.RNN(2, 3) |> TrackerFlux.track
seq = [rand(2) for i = 1:3]

bptt = Flux.gradient(Wh->@show(sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                                tanh.(rnn.cell.Wi * seq[2] + Wh *
                                      tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.init + rnn.cell.b)
                                      + rnn.cell.b)
                                + rnn.cell.b))),
                rnn.cell.Wh)

Flux.reset!(rnn);
grads_seq = Flux.gradient(Flux.params([rnn.cell.Wh])) do
    hs = Flux.Zygote.Buffer([], 3)
    for i in 1:3
        hs[i] = rnn(seq[i])
    end
    sum(hs[3])
end
@assert grads_seq[rnn.cell.Wh] == bptt[1]
  • Without Tracker
using Flux

rnn = Flux.RNN(2, 3)
seq = [rand(2) for i = 1:3]

bptt = Flux.gradient(Wh->@show(sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                                tanh.(rnn.cell.Wi * seq[2] + Wh *
                                      tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.init + rnn.cell.b)
                                      + rnn.cell.b)
                                + rnn.cell.b))),
                rnn.cell.Wh)

Flux.reset!(rnn);
grads_seq = Flux.gradient(Flux.params([rnn.cell.Wh])) do
    hs = Flux.Zygote.Buffer([], 3)
    for i in 1:3
        hs[i] = rnn(seq[i])
    end
    sum(hs[3])
end
@assert grads_seq[rnn.cell.Wh] == bptt[1]

@AStupidBear
Contributor

AStupidBear commented Jun 5, 2020

In this example, there's a slight overhead from using Zygote.Buffer with Zygote, and of course no overhead with Tracker:

julia> using BenchmarkTools, Flux, TrackerFlux

julia> rnn = Flux.RNN(2, 3) |> TrackerFlux.track
Recur(RNNCell(2, 3, tanh))

julia> seq = [rand(2) for i = 1:3]
3-element Array{Array{Float64,1},1}:
 [0.03294249342946465, 0.18808334558515472]
 [0.34226355023657296, 0.9750375230014132] 
 [0.9295184217629979, 0.26820876169878827] 

julia> ps = Flux.params([rnn.cell.Wh])
Params([Float32[-0.5715096 0.57270575 -0.8965924; -0.48481512 -0.46085548 0.90666986; 0.24707532 -0.52199364 0.03942895] (tracked)])

julia> @btime begin
           Flux.reset!($rnn);
           Flux.gradient($ps) do
               hs = $rnn.($seq)
               sum(hs[3])
           end
       end

  119.110 μs (634 allocations: 22.07 KiB)
Grads(...)


julia> @btime begin
           Flux.reset!($rnn);
           Flux.gradient($ps) do
               hs = Vector{Any}(undef, 3)
               for i in 1:3
                   hs[i] = $rnn($seq[i])
               end
               sum(hs[3])
           end
       end
  117.956 μs (630 allocations: 21.98 KiB)
Grads(...)


julia> @btime begin
           Flux.reset!($rnn);
           Flux.gradient($ps) do
               hs = Flux.Zygote.Buffer([], 3)
               for i in 1:3
                   hs[i] = $rnn($seq[i])
               end
               sum(hs[3])
           end
       end
  117.305 μs (631 allocations: 22.05 KiB)
Grads(...)

julia> using BenchmarkTools, Flux

julia> rnn = Flux.RNN(2, 3)
Recur(RNNCell(2, 3, tanh))

julia> seq = [rand(2) for i = 1:3]
3-element Array{Array{Float64,1},1}:
 [0.32629709362357584, 0.11605123770776848]
 [0.6291065815003436, 0.1655236202415329]  
 [0.42141717535016565, 0.7108787078307919] 

julia> Flux.reset!(rnn)
3-element Array{Float32,1}:
 0.0
 0.0
 0.0

julia> ps = Flux.params([rnn.cell.Wh])
Params([Float32[-0.07593036 -0.1756413 -0.32087517; 0.8284807 -0.11117959 -0.22851062; 0.55279875 -0.4334936 -0.36445403]])

julia> @btime Flux.gradient($ps) do
           hs = Flux.Zygote.Buffer([], 3)
           for i in 1:3
               hs[i] = $rnn(seq[i])
           end
           sum(hs[3])
       end
  85.040 μs (502 allocations: 18.33 KiB)
Grads(...)

julia> @btime Flux.gradient($ps) do
           h = 0f0
           for i in 1:3
               h = $rnn(seq[i])
           end
           sum(h)
       end

  73.407 μs (467 allocations: 17.08 KiB)
Grads(...)

@bhvieira
Contributor

bhvieira commented Jun 5, 2020

Yeah, as I mentioned previously, it makes sense for it to be that way. The buffer usually slows things down when you populate an array with an operation that could otherwise be done in one go with array functions, such as the products in #1009.

@AStupidBear
Contributor

@MikeInnes I can confirm the following script works.

using Pkg
pkg"add Zygote#stateful-map"
using Flux

rnn = Flux.RNN(2, 3)
seq = [rand(2) for i = 1:3]

Flux.reset!(rnn)
bptt = Flux.gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                                tanh.(rnn.cell.Wi * seq[2] + Wh *
                                      tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.init + rnn.cell.b)
                                      + rnn.cell.b)
                                + rnn.cell.b)),
                rnn.cell.Wh)

Flux.reset!(rnn)
grads_seq = Flux.gradient(Flux.params([rnn.cell.Wh])) do
    sum(map(rnn, seq)[3])
end

@assert grads_seq[rnn.cell.Wh] == bptt[1]

@MikeInnes
Member

Great, so that confirms that patch fixes this issue. If you try the latest version of the branch, the broadcast version should work too.

What's happening here is that when you write a loop, the pullbacks for each RNN application are applied in reverse order, and the order really matters because they accumulate shared state (the gradient of the hidden state). When you use map, we currently just apply an adjoint map that iterates over seq in the usual forward order, but we actually need to go in reverse to handle cases like this.
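
To see why the order matters, here is a hand-rolled scalar sketch (illustrative only, not Zygote's actual implementation): the forward pass records one "pullback" per step, and the backward pass replays them last-to-first, threading the hidden-state gradient backwards as shared state.

# Toy scalar RNN, h_t = tanh(w*h_{t-1} + x_t), with loss = h_T.
# Forward: record each step's local derivatives (the "pullbacks").
# Backward: apply them in REVERSE, accumulating dL/dh as shared state.
function bptt_grad(w, xs; h0 = 0.0)
    h = h0
    tape = Tuple{Float64,Float64}[]   # per step: (∂h_t/∂w, ∂h_t/∂h_{t-1})
    for x in xs
        hnext = tanh(w*h + x)
        d = 1 - hnext^2               # tanh'(w*h + x)
        push!(tape, (d*h, d*w))
        h = hnext
    end
    dh, dw = 1.0, 0.0                 # dL/dh_T = 1
    for (dw_local, dh_prev) in reverse(tape)
        dw += dh * dw_local           # this step's contribution to dL/dw
        dh *= dh_prev                 # propagate the gradient back to h_{t-1}
    end
    return dw
end

Replaying the tape in forward order instead would pair each step's local derivative with the wrong hidden-state gradient, losing the dependence on earlier steps.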

@MikeInnes
Member

In case it's not clear, others hitting this issue can also use that same branch to fix it: FluxML/Zygote.jl#676.

bors bot added a commit to FluxML/Zygote.jl that referenced this issue Jun 12, 2020
676: Stateful map r=dhairyagandhi96 a=MikeInnes

See FluxML/Flux.jl#1209.

It would be good if we could come up with a reasonable test case. It may be tricky to do that, and if so it may make sense to merge this to avoid the issues people are seeing.

Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
bors bot added a commit to FluxML/Zygote.jl that referenced this issue Jun 12, 2020
676: Stateful map r=MikeInnes a=MikeInnes

See FluxML/Flux.jl#1209.

It would be good if we could come up with a reasonable test case. It may be tricky to do that, and if so it may make sense to merge this to avoid the issues people are seeing.

Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
bors bot added a commit to FluxML/Zygote.jl that referenced this issue Jul 16, 2020
676: Stateful map r=DhairyaLGandhi a=MikeInnes

See FluxML/Flux.jl#1209.

It would be good if we could come up with a reasonable test case. It may be tricky to do that, and if so it may make sense to merge this to avoid the issues people are seeing.

Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
Co-authored-by: CarloLucibello <carlo.lucibello@gmail.com>
Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
@AzamatB
Contributor

AzamatB commented Jul 17, 2020

Looks like FluxML/Zygote.jl#676 fixed the issue for map, but not for broadcast. In particular, the example in the OP still doesn't work correctly.

@CarloLucibello
Member

Yeah, I had to comment out the broadcasting part of FluxML/Zygote.jl#676 to get the tests to pass. I was seeing some method ambiguity errors; maybe they can be worked around if someone is willing to try. I'm sorry, I don't have much time these days.

@bhvieira
Contributor

> Yeah, I had to comment out the broadcasting part of FluxML/Zygote.jl#676 to get the tests to pass. I was seeing some method ambiguity errors; maybe they can be worked around if someone is willing to try. I'm sorry, I don't have much time these days.

Perhaps it would be better, then, to override broadcasting for recurrent layers until that's done, since it's mentioned nowhere that broadcasting doesn't work as expected.

@AzamatB
Contributor

AzamatB commented Jul 17, 2020

I think the fact that RNNs currently produce silently wrong results that are quite hard to debug is critical enough that this should be prioritized.

@DhairyaLGandhi
Member

I can pick up the broadcasting issue; the basic code is already in there, sans some threading around the internals to get it working as expected. We will also need to add a minimised test case to catch this in the future; a sketch of one follows.
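
For instance, something along these lines might serve as a regression test; this is only a sketch comparing the broadcast gradient against the OP's hand-unrolled one, not the test that eventually landed:

using Flux, Test

rnn = Flux.RNN(2, 3)
seq = [rand(Float32, 2) for _ in 1:3]

# Hand-unrolled reference gradient (full BPTT), as in the OP.
Flux.reset!(rnn)
bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                                tanh.(rnn.cell.Wi * seq[2] + Wh *
                                      tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.init + rnn.cell.b)
                                      + rnn.cell.b)
                                + rnn.cell.b)),
                rnn.cell.Wh)

# Gradient through broadcasting, which this issue shows is truncated.
Flux.reset!(rnn)
grads = gradient(Flux.params(rnn)) do
    sum(rnn.(seq)[3])
end

@test grads[rnn.cell.Wh] ≈ bptt[1]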
