
Batch_size dependent beam search #994

Closed
@pschwllr

Description

I observed that the top-n predictions differed for different batch sizes using the translate.py script, e.g.:

Batch size 64, top-5 prediction:

[-49.4120] ['▁Seh', 'r', '▁interessant', '▁ist', '▁auch', '▁die', '▁Tatsache', ',', '▁dass', '▁das', '▁Parlament', '▁in', '▁seine', 'm', '▁Vor', 'schlag', '▁berücksichtigt', '▁hat', ',', '▁dass', '▁es', '▁sich', '▁hier', '▁um', '▁eine', '▁Sache', '▁handelt', ',', '▁die', '▁in', '▁erst', 'er', '▁Linie', '▁auf', '▁die', '▁Belang', 'e', '▁der', '▁Bürger', '▁zugeschnitten', '▁ist', '.']

Batch size 5, top-5 prediction:

[-49.7143] ['▁Seh', 'r', '▁interessant', '▁ist', '▁auch', '▁die', '▁Tatsache', ',', '▁dass', '▁das', '▁Parlament', '▁in', '▁seine', 'm', '▁Vor', 'schlag', '▁berücksichtigt', '▁hat', ',', '▁dass', '▁es', '▁sich', '▁hier', '▁um', '▁eine', '▁Sache', '▁handelt', ',', '▁die', '▁in', '▁erst', 'er', '▁Linie', '▁um', '▁die', '▁Belang', 'e', '▁der', '▁Bürger', '▁geht', '.']

Please correct me if I am mistaken, but I think the problem is that in the _translate_batch function, beams that are already done are still advanced.

So if a sentence in the batch is not done yet, it keeps the whole batch “alive”.

This can lead to longer but slightly more probable beam hypotheses. Those slightly more probable hypotheses then replace the ones that would have been predicted with batch_size 1, for example, where the beam in question would not have been kept alive for as long.

The larger the batch_size, the more likely it is that all the beams are kept alive for longer.
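
To make the effect concrete, here is a minimal, self-contained sketch (toy code, not OpenNMT code; only the two scores are taken from the outputs above) of how a beam that keeps being advanced after it is already done can have its top hypothesis replaced by a longer, marginally better-scoring one:

# Toy illustration only -- not OpenNMT code. "finished" collects completed
# hypotheses as (score, text) pairs, best score first, the way a beam keeps
# its n-best list.
finished = []

def add_finished(score, text):
    finished.append((score, text))
    finished.sort(key=lambda x: -x[0])

# Step t: the beam hits EOS and records the hypothesis that a batch_size-1
# run would return (score taken from the batch_size-5 output above).
add_finished(-49.7143, "... in erster Linie um die Belange der Buerger geht .")

# Steps t+1 .. t+k only happen because other sentences in the batch are not
# done yet; the beam is advanced anyway and later finds a longer hypothesis
# with a marginally better score (the batch_size-64 output above), which now
# ranks first.
add_finished(-49.4120, "... in erster Linie auf die Belange der Buerger zugeschnitten ist .")

print(finished[0][0])  # -49.412: the top-1 depends on how long the beam stayed alive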

The bug can be fixed by adding one line to onmt/translate/translator.py (line 635):

# (c) Advance each beam.
for j, b in enumerate(beam):
    if not b.done():   # <----- only advance beams that are not done yet
        b.advance(out[:, j],
                  beam_attn.data[:, j, :memory_lengths[j]])
        dec_states.beam_update(j, b.get_current_origin(), beam_size)

With this change, I get consistent predictions across different batch_sizes.

As a side note, it is a bit unfortunate that I was getting slightly better top-n (n > 1) results with the bug.


How to reproduce:
(Python 3.5, PyTorch 0.4.1, OpenNMT 0.4.1)

I downloaded the pre-trained English-2-German model found on:
http://opennmt.net/Models-py/
and placed it into available_models.

Then, I took the first 100 entries of the default src-test.txt.

head -n 100 data/src-test.txt > data/src-100-test.txt

Predict:

python translate.py -model available_models/averaged-10-epoch.pt -src data/src-100-test.txt -output data/pred_100_bs5.txt -batch_size 5 -replace_unk -max_length 200 -verbose -n_best 5

python translate.py -model available_models/averaged-10-epoch.pt -src data/src-100-test.txt -output data/pred_100_bs64.txt -batch_size 64 -replace_unk -max_length 200 -verbose -n_best 5

Compare:

diff  data/pred_100_bs64.txt data/pred_100_bs5.txt   
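
If you prefer a quick count of how many hypotheses actually differ, a small Python check works too (assuming both runs wrote the files above with one hypothesis per line):

# Count differing lines between the two prediction files
# (n_best hypotheses per source sentence, one per line).
with open("data/pred_100_bs64.txt") as f64, open("data/pred_100_bs5.txt") as f5:
    lines64 = f64.read().splitlines()
    lines5 = f5.read().splitlines()

diffs = sum(a != b for a, b in zip(lines64, lines5))
print("%d of %d hypotheses differ" % (diffs, len(lines64)))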
