Skip to content

Commit d13af35

Browse files
committed
Major state and loss resetting improvement
Suppose there are two videos, "a" and "b". So far we were trying to estimate: 1) a[-1] ~= net(a[-2]); 2) b[0] ~= net(a[-1]) <-- this has to be removed; 3) b[1] ~= net(b[0]). Therefore, we now force "net(a[-1]) = b[0]", so there is no error or gradient associated with it. Moreover, we do not want a gradient from 2) to 1) through the state forwarding, so we kill the state 1) -> 2) and therefore the gradient 2) -> 1). Furthermore, we reset the state 2) -> 3) because that state refers to "a", not to "b". Finally, we also force "net(a[0]) = a[1]" and "net(b[0]) = b[1]", given that there was either no state or the state has just been reset.
1 parent 92f21d9 commit d13af35

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

main.py

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -168,14 +168,25 @@ def selective_zero(s, new):
168168
s[layer] = s[layer].index_fill(0, V(b), 0) # mask state, zero selected indices
169169

170170

171+
def selective_match(x_hat, x, new):
172+
if new.any(): # if at least one video changed
173+
b = new.nonzero().squeeze(1) # get the list of indices
174+
for bb in b: x_hat[bb].copy_(x[bb]) # force the output to be the expected output
175+
176+
171177
def train(train_loader, model, loss_fun, optimiser, epoch):
172178
print('Training epoch', epoch + 1)
173179
model.train() # set model in train mode
174180
total_loss = {'mse': 0, 'ce': 0, 'rpl': 0}
175181
mse, nll = loss_fun
176182

177183
def compute_loss(x_, next_x, y_, state_):
184+
if not hasattr(compute_loss, 'mismatch'): compute_loss.mismatch = y_.byte().fill_(1) # ignore first prediction
185+
selective_zero(state, mismatch) # no state from the past
178186
(x_hat, state_), (_, idx_) = model(V(x_), state_)
187+
selective_zero(state, mismatch) # no state to the future
188+
selective_match(x_hat.data, next_x, mismatch + compute_loss.mismatch) # last frame or first frame
189+
compute_loss.mismatch = mismatch # last frame <- first frame
179190
mse_loss_ = mse(x_hat, V(next_x))
180191
ce_loss_ = nll(idx_, V(y_))
181192
total_loss['mse'] += mse_loss_.data[0]
@@ -198,11 +209,11 @@ def compute_loss(x_, next_x, y_, state_):
198209
loss = 0
199210
# BTT loop
200211
if from_past:
201-
selective_zero(state, y[0] != from_past[1])
212+
mismatch = y[0] != from_past[1]
202213
ce_loss, mse_loss, state, _ = compute_loss(from_past[0], x[0], from_past[1], state)
203214
loss += mse_loss + ce_loss * args.lambda_
204215
for t in range(0, min(args.big_t, x.size(0)) - 1): # first batch we go only T - 1 steps forward / backward
205-
selective_zero(state, y[t + 1] != y[t])
216+
mismatch = y[t + 1] != y[t]
206217
ce_loss, mse_loss, state, x_hat_data = compute_loss(x[t], x[t + 1], y[t], state)
207218
loss += mse_loss + ce_loss * args.lambda_
208219

@@ -248,13 +259,18 @@ def validate(val_loader, model, loss_fun):
248259
x = x.cuda(async=True)
249260
y = y.cuda(async=True)
250261
if not hasattr(validate, 'state'): validate.state = None # init state attribute
262+
if not hasattr(validate, 'mismatch'): validate.mismatch = y.byte().fill_(1) # ignore first prediction ever
251263
state = validate.state
252264
for (next_x, next_y) in batches:
253265
if args.cuda:
254266
next_x = next_x.cuda(async=True)
255267
next_y = next_y.cuda(async=True)
256-
selective_zero(state, next_y[0] != y[0])
268+
mismatch = next_y[0] != y[0]
269+
selective_zero(state, mismatch) # no state from the past
257270
(x_hat, state), (_, idx) = model(V(x[0], volatile=True), state) # do not compute graph (volatile)
271+
selective_zero(state, mismatch) # no state to the future
272+
selective_match(x_hat.data, next_x[0], mismatch + validate.mismatch) # last frame or first frame
273+
validate.mismatch = mismatch # last frame <- first frame
258274
mse_loss = mse(x_hat, V(next_x[0]))
259275
ce_loss = nll(idx, V(y[0])) * args.lambda_
260276
total_loss['mse'] += mse_loss.data[0]

0 commit comments

Comments (0)