Closed
Description
Unless my understanding of Backpropagation Through Time (BPTT) and Flux/Zygote is off, it seems like BPTT isn't working as intended with Flux/Zygote.
Currently, the gradient computed with respect to `Wh` does not look back in time. In other words, given a sequence of length 3,
```julia
using Flux

rnn = Flux.RNN(2, 3)          # input size 2, hidden size 3
seq = [rand(2) for i = 1:3]   # sequence of length 3
```
the following two gradients, `grads_seq` and `grads_2`, are the same:
```julia
Flux.reset!(rnn);
grads_seq = gradient(Flux.params(rnn)) do
    sum(rnn.(seq)[3])   # loss: sum of the hidden state after the 3rd step
end
```
and
```julia
Flux.reset!(rnn);
rnn(seq[1])
rnn(seq[2])
# rnn.state now holds h_2 and enters the closure below as a constant
grads_2 = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] +
                                   Wh * rnn.state + rnn.cell.b)),
                   rnn.cell.Wh)
```
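For reference, unrolling the cell makes the missing terms explicit. Writing the cell as $h_t = \tanh(W_i x_t + W_h h_{t-1} + b)$, with $h_0$ the initial state (`rnn.init`) and the loss $L = \sum_j (h_3)_j$, the backward pass through the three steps is:

```math
\begin{aligned}
g_3 &= \mathbf{1} \odot \left(1 - h_3 \odot h_3\right), \\
g_{t-1} &= \left(W_h^\top g_t\right) \odot \left(1 - h_{t-1} \odot h_{t-1}\right), \qquad t = 3, 2, \\
\frac{\partial L}{\partial W_h} &= \sum_{t=1}^{3} g_t \, h_{t-1}^\top = g_3 h_2^\top + g_2 h_1^\top + g_1 h_0^\top .
\end{aligned}
```

`grads_2` keeps only the $t = 3$ term $g_3 h_2^\top$, because `rnn.state` ($h_2$) is a constant inside its closure. Full BPTT should also include the $t = 2$ and $t = 1$ terms.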
By contrast, the BPTT gradient should be:
```julia
Flux.reset!(rnn);
# Unroll the three steps so that h_1 and h_2 are themselves functions of Wh
bptt = gradient(Wh -> sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                  tanh.(rnn.cell.Wi * seq[2] + Wh *
                      tanh.(rnn.cell.Wi * seq[1] + Wh * rnn.init + rnn.cell.b) +
                  rnn.cell.b) +
              rnn.cell.b)),
            rnn.cell.Wh)
```
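To check which of the two gradients is correct without relying on Flux/Zygote, here is a minimal pure-Julia sketch. It assumes a vanilla cell `h_t = tanh.(Wi * x_t + Wh * h_{t-1} + b)`, a zero initial state (assumed to mirror the default `rnn.init`), and the same loss `sum(h_3)`. The helpers `rnn_states`, `bptt_grad_Wh`, `one_step_grad_Wh` and `fd_grad_Wh` are hypothetical names for illustration, not Flux functions. The script compares a hand-written BPTT backward pass and a one-step gradient (the analogue of `grads_2`) against central finite differences.

```julia
using LinearAlgebra
using Random

# Hypothetical helper: run a tanh RNN over xs, returning states h_0, h_1, ..., h_T.
function rnn_states(Wi, Wh, b, h0, xs)
    hs = [h0]
    for x in xs
        push!(hs, tanh.(Wi * x .+ Wh * hs[end] .+ b))
    end
    return hs
end

# Loss used in the issue: sum of the final hidden state.
loss(Wi, Wh, b, h0, xs) = sum(rnn_states(Wi, Wh, b, h0, xs)[end])

# Full BPTT: accumulate g_t * h_{t-1}' while propagating dL/dh backwards in time.
function bptt_grad_Wh(Wi, Wh, b, h0, xs)
    hs = rnn_states(Wi, Wh, b, h0, xs)
    dWh = zero(Wh)
    dh = ones(length(h0))                 # dL/dh_T for L = sum(h_T)
    for t in length(xs):-1:1
        g = dh .* (1 .- hs[t+1] .^ 2)     # dL/da_t, a_t = pre-activation at step t
        dWh .+= g * hs[t]'                # hs[t] is h_{t-1}
        dh = Wh' * g                      # dL/dh_{t-1}
    end
    return dWh
end

# One-step gradient: h_{T-1} treated as a constant (only the last term).
function one_step_grad_Wh(Wi, Wh, b, h0, xs)
    hs = rnn_states(Wi, Wh, b, h0, xs)
    g = 1 .- hs[end] .^ 2
    return g * hs[end-1]'
end

# Central finite differences on each entry of Wh, used as ground truth.
function fd_grad_Wh(Wi, Wh, b, h0, xs; ε = 1e-6)
    G = zero(Wh)
    for i in eachindex(Wh)
        Wp = copy(Wh); Wp[i] += ε
        Wm = copy(Wh); Wm[i] -= ε
        G[i] = (loss(Wi, Wp, b, h0, xs) - loss(Wi, Wm, b, h0, xs)) / (2ε)
    end
    return G
end

Random.seed!(1234)
Wi, Wh, b = randn(3, 2), randn(3, 3), randn(3)
h0 = zeros(3)                   # assumed zero initial state
xs = [rand(2) for _ in 1:3]     # sequence of length 3, as in the issue

g_bptt = bptt_grad_Wh(Wi, Wh, b, h0, xs)
g_step = one_step_grad_Wh(Wi, Wh, b, h0, xs)
g_fd   = fd_grad_Wh(Wi, Wh, b, h0, xs)

println("BPTT matches finite differences:     ", isapprox(g_bptt, g_fd; atol = 1e-6))
println("one-step matches finite differences: ", isapprox(g_step, g_fd; atol = 1e-6))
```

If the `Wh` entry of `grads_seq` agrees with the one-step value rather than with the finite differences, that would point to the truncation described above.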
Issue #1168 is possibly related, and here is a gist summarizing this.