From c302378d084135fa6cc4560e1c5f018d876eb5ff Mon Sep 17 00:00:00 2001 From: ratishsp Date: Mon, 2 Apr 2018 13:30:44 +0100 Subject: [PATCH] Fixes issue of copy_loss_by_seqlength results in overflow over value of 255 because of sum over boolean tensor --- onmt/modules/CopyGenerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onmt/modules/CopyGenerator.py b/onmt/modules/CopyGenerator.py index ece879aeb3..222b7a1c2d 100644 --- a/onmt/modules/CopyGenerator.py +++ b/onmt/modules/CopyGenerator.py @@ -207,7 +207,7 @@ def _compute_loss(self, batch, output, target, copy_attn, align): # Compute Loss as NLL divided by seq length # Compute Sequence Lengths pad_ix = batch.dataset.fields['tgt'].vocab.stoi[onmt.io.PAD_WORD] - tgt_lens = batch.tgt.ne(pad_ix).sum(0).float() + tgt_lens = batch.tgt.ne(pad_ix).float().sum(0) # Compute Total Loss per sequence in batch loss = loss.view(-1, batch.batch_size).sum(0) # Divide by length of each sequence and sum