Skip to content

Commit

Permalink
fix bug in tf no_repeat_ngram_size
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed Mar 11, 2020
1 parent d997ac7 commit 1ba21f9
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,8 @@ def _generate_beam_search(
if no_repeat_ngram_size > 0:
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
banned_tokens = calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
num_batch_hypotheses = batch_size * num_beams
banned_tokens = calc_banned_tokens(input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len)
# create banned_tokens boolean mask
banned_tokens_indices_mask = []
for banned_tokens_slice in banned_tokens:
Expand Down

0 comments on commit 1ba21f9

Please sign in to comment.