Better handle Cuda OOM with overflow batches (OpenNMT#1385)
vince62s authored Apr 8, 2019
1 parent 19b52ec commit d4edfc4
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions onmt/trainer.py
@@ -12,6 +12,7 @@
 from copy import deepcopy
 import itertools
 import torch
+import traceback
 
 import onmt.utils
 from onmt.utils.logging import logger
@@ -308,7 +309,7 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
         if self.accum_count > 1:
             self.optim.zero_grad()
 
-        for batch in true_batches:
+        for k, batch in enumerate(true_batches):
             target_size = batch.tgt.size(0)
             # Truncated BPTT: reminder not compatible with accum > 1
             if self.trunc_size:
@@ -335,20 +336,26 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
                 bptt = True
 
                 # 3. Compute loss.
-                loss, batch_stats = self.train_loss(
-                    batch,
-                    outputs,
-                    attns,
-                    normalization=normalization,
-                    shard_size=self.shard_size,
-                    trunc_start=j,
-                    trunc_size=trunc_size)
-
-                if loss is not None:
-                    self.optim.backward(loss)
-
-                total_stats.update(batch_stats)
-                report_stats.update(batch_stats)
+                try:
+                    loss, batch_stats = self.train_loss(
+                        batch,
+                        outputs,
+                        attns,
+                        normalization=normalization,
+                        shard_size=self.shard_size,
+                        trunc_start=j,
+                        trunc_size=trunc_size)
+
+                    if loss is not None:
+                        self.optim.backward(loss)
+
+                    total_stats.update(batch_stats)
+                    report_stats.update(batch_stats)
+
+                except Exception:
+                    traceback.print_exc()
+                    logger.info("At step %d, we removed a batch - accum %d",
+                                self.optim.training_step, k)
 
             # 4. Update the parameters and statistics.
             if self.accum_count == 1:
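
For readers skimming the diff, here is a minimal, self-contained sketch of the pattern this commit introduces: wrapping the per-batch loss computation and backward pass in a try/except so that a batch that overflows GPU memory (or fails for any other reason) is logged and skipped instead of killing training. The names accumulate_gradients, model, criterion, optimizer and batches below are illustrative placeholders, not OpenNMT-py APIs.

import logging
import traceback

import torch

logger = logging.getLogger(__name__)


def accumulate_gradients(model, criterion, optimizer, batches, step):
    """Backpropagate each batch; skip any batch whose forward/backward fails.

    A batch that is too large for GPU memory raises a RuntimeError
    ("CUDA out of memory") during the forward or backward pass; instead
    of crashing the whole run, the batch is logged and dropped.
    """
    for k, (src, tgt) in enumerate(batches):
        try:
            output = model(src)
            loss = criterion(output, tgt)
            loss.backward()
        except Exception:
            traceback.print_exc()
            logger.info("At step %d, we removed a batch - accum %d", step, k)
            # Optional (not part of this commit): release memory held by
            # the failed forward pass before trying the next batch.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    optimizer.step()
    optimizer.zero_grad()

Catching a broad Exception, as the commit does, means any failure inside the loss/backward is handled the same way: the statistics for the dropped batch are simply never accumulated, and the remaining batches in the accumulation window still contribute to the parameter update.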
