Skip to content

Commit 4b9d59d

Browse files
committed
Update trainer.py
1 parent ede378a commit 4b9d59d

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/11_seq2seq/modules/trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,12 @@ def validate(engine, mini_batch):
163163
# y = (batch_size, length_m)
164164
x, y = mini_batch.src, mini_batch.tgt[0][:, 1:]
165165

166-
with autocast(not engine.config.off_autocast):
167-
y_hat = engine.model(x, mini_batch.tgt[0][:, :-1])
168-
loss = engine.crit(
169-
y_hat.contiguous().view(-1, y_hat.size(-1)),
170-
y.contiguous().view(-1),
171-
)
166+
#with autocast(not engine.config.off_autocast):
167+
y_hat = engine.model(x, mini_batch.tgt[0][:, :-1])
168+
loss = engine.crit(
169+
y_hat.contiguous().view(-1, y_hat.size(-1)),
170+
y.contiguous().view(-1),
171+
)
172172

173173
word_count = int(mini_batch.tgt[1].sum())
174174
loss = float(loss / word_count)

0 commit comments

Comments
 (0)