
Backprop through time is truncated to only 1 time step #1209

Closed
@AlexLewandowski


Unless I am misunderstanding Backpropagation Through Time (BPTT) or Flux/Zygote, BPTT does not seem to work as intended in Flux/Zygote.

Currently, the gradient computed with respect to Wh does not look back in time: only the last time step contributes to it. In other words, if we have a sequence of length 3,

```julia
using Flux

rnn = Flux.RNN(2, 3)
seq = [rand(2) for i = 1:3]
```

then the following two gradients with respect to Wh, grads_seq[rnn.cell.Wh] and grads_2[1], are the same.

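```julia
# Gradient of sum(h_3) w.r.t. all RNN parameters, running the whole sequence inside `gradient`
Flux.reset!(rnn);
grads_seq = gradient(Flux.params(rnn)) do
    sum(rnn.(seq)[3])
end
```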

and

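```julia
# Run the first two steps outside `gradient`, then differentiate only the last step,
# treating the current hidden state rnn.state (h_2) as a constant
Flux.reset!(rnn);
rnn(seq[1])
rnn(seq[2])
grads_2 = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] +
                                 Wh * rnn.state + rnn.cell.b)), rnn.cell.Wh)
```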
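Written out, and assuming the default tanh activation with h_t denoting the hidden state after step t and h_0 = rnn.init, the full BPTT gradient of L = sum(h_3) is a sum over all three steps, while grads_2 keeps only its final-step term:

```math
\begin{aligned}
a_t &= W_i x_t + W_h h_{t-1} + b, \qquad h_t = \tanh(a_t), \qquad L = \mathbf{1}^\top h_3,\\
\delta_3 &= \mathbf{1} \odot (\mathbf{1} - h_3 \odot h_3), \qquad \delta_t = (W_h^\top \delta_{t+1}) \odot (\mathbf{1} - h_t \odot h_t) \quad (t = 2, 1),\\
\frac{\partial L}{\partial W_h} &= \sum_{t=1}^{3} \delta_t\, h_{t-1}^\top \ \text{(full BPTT)}, \qquad
\left.\frac{\partial L}{\partial W_h}\right|_{h_2\ \text{fixed}} = \delta_3\, h_2^\top \ \text{(grads\_2)}.
\end{aligned}
```

The two only coincide when the earlier terms \delta_2 h_1^\top and \delta_1 h_0^\top vanish, which is not the case in general.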

Full BPTT, by contrast, should unroll all three steps back to the initial state, giving the following gradient:

```julia
# Unroll all three steps from the initial state rnn.init, so Wh enters at every step
Flux.reset!(rnn);
bptt = gradient(Wh->sum(tanh.(rnn.cell.Wi * seq[3] + Wh *
                                tanh.(rnn.cell.Wi * seq[2] + Wh *
                                      tanh.(rnn.cell.Wi * seq[1] +
                                            Wh * rnn.init + rnn.cell.b)
                                      + rnn.cell.b)
                                + rnn.cell.b)),
                rnn.cell.Wh)
```
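To make the expected numbers checkable without Flux or Zygote, here is a minimal plain-Julia sketch that backpropagates through time by hand for the same recurrence (tanh activation, L = sum(h_3)) and compares it against central finite differences. The helpers rnn_forward, rnn_loss, grad_Wh and fd_grad_Wh are hypothetical illustrations, not Flux API; a zero initial state and random weights are assumed, and steps = 1 is meant as the analogue of grads_2.

```julia
using LinearAlgebra, Random

# Hypothetical helper (not Flux API): forward pass of a vanilla tanh RNN,
# h_t = tanh.(Wi * x_t .+ Wh * h_{t-1} .+ b). Returns [h_0, h_1, ..., h_T].
function rnn_forward(Wi, Wh, b, h0, xs)
    hs = [h0]
    for x in xs
        push!(hs, tanh.(Wi * x .+ Wh * hs[end] .+ b))
    end
    return hs
end

# Hypothetical helper: loss L = sum(h_T), as in the examples above.
rnn_loss(Wi, Wh, b, h0, xs) = sum(rnn_forward(Wi, Wh, b, h0, xs)[end])

# Hypothetical helper: manual gradient of L w.r.t. Wh, backpropagating through
# the last `steps` time steps (steps = length(xs) is full BPTT).
function grad_Wh(Wi, Wh, b, h0, xs; steps = length(xs))
    hs = rnn_forward(Wi, Wh, b, h0, xs)   # hs[t + 1] holds h_t
    T = length(xs)
    gWh = zero(Wh)
    dh = ones(length(hs[end]))            # dL/dh_T for L = sum(h_T)
    for t in T:-1:max(T - steps + 1, 1)
        δ = dh .* (1 .- hs[t + 1] .^ 2)   # back through tanh: dL/da_t
        gWh .+= δ * hs[t]'                # add δ_t * h_{t-1}'
        dh = Wh' * δ                      # dL/dh_{t-1}
    end
    return gWh
end

# Hypothetical helper: central finite differences on every entry of Wh.
function fd_grad_Wh(Wi, Wh, b, h0, xs; ε = 1e-6)
    g = zero(Wh)
    for i in eachindex(Wh)
        Wp = copy(Wh); Wp[i] += ε
        Wm = copy(Wh); Wm[i] -= ε
        g[i] = (rnn_loss(Wi, Wp, b, h0, xs) - rnn_loss(Wi, Wm, b, h0, xs)) / (2ε)
    end
    return g
end

Random.seed!(0)
Wi, Wh, b = randn(3, 2), randn(3, 3), randn(3)
h0 = zeros(3)                   # assumed zero initial state
xs = [rand(2) for _ in 1:3]     # sequence of length 3

g_full    = grad_Wh(Wi, Wh, b, h0, xs)             # full BPTT
g_onestep = grad_Wh(Wi, Wh, b, h0, xs; steps = 1)  # analogue of grads_2
g_fd      = fd_grad_Wh(Wi, Wh, b, h0, xs)

println(isapprox(g_full, g_fd; rtol = 1e-5))
println(isapprox(g_onestep, g_fd; rtol = 1e-5))
```

Under these assumptions the full-BPTT gradient should match the finite-difference estimate, while the one-step gradient generally should not, since it drops the t = 2 term (the t = 1 term is zero here only because h0 is zero).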

Issue #1168 is possibly related, and here is a gist summarizing this.
