[generate] beam search -- fix output cropping #37080

Merged
merged 6 commits on Mar 28, 2025

11 changes: 8 additions & 3 deletions src/transformers/generation/utils.py
@@ -3931,9 +3931,14 @@ def _beam_search(
beam_scores = self._flatten_beam_dim(beam_scores[:, :num_return_sequences])
beam_indices = self._flatten_beam_dim(beam_indices[:, :num_return_sequences, :])

# Crop the static-shaped tensors to the actual size
sequences = sequences[:, :cur_len]
beam_indices = beam_indices[:, : cur_len - decoder_prompt_len]
# Crop the static-shaped tensors to the actual size.
# `beam_indices` is initialized with -1s, and is updated with the beam index of the generated token at each
# step. We can use it to detect the generated length, which may be != `cur_len` (e.g. selected beam is from a
# previous decoding iteration)
max_generated_length = ((beam_indices + 1).bool()).sum(dim=1).max()
output_length = decoder_prompt_len + max_generated_length
sequences = sequences[:, :output_length]
beam_indices = beam_indices[:, :max_generated_length]

if return_dict_in_generate:
if not output_scores:
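A minimal sketch of the new cropping arithmetic (toy values, not the library's internal state): positions of `beam_indices` that were never generated keep their initial -1, so adding 1 and casting to bool marks exactly the generated steps, and the per-sequence count gives the true generated length.

```python
import torch

decoder_prompt_len = 2  # hypothetical prompt length

# Two returned sequences: the first generated 4 tokens, the second 3.
# -1 marks decoding steps that were never filled in.
beam_indices = torch.tensor(
    [
        [0, 1, 1, 0, -1, -1],
        [2, 2, 0, -1, -1, -1],
    ]
)

# -1 + 1 == 0 -> False; any real beam index + 1 is nonzero -> True.
max_generated_length = (beam_indices + 1).bool().sum(dim=1).max()  # tensor(4)
output_length = decoder_prompt_len + max_generated_length          # tensor(6)

# The static-shaped outputs (padded to length 8 here) are cropped to the true length.
sequences = torch.zeros(2, 8, dtype=torch.long)
sequences = sequences[:, :output_length]               # shape (2, 6)
beam_indices = beam_indices[:, :max_generated_length]  # shape (2, 4)
```

With the previous `cur_len`-based crop, a selected beam that finished in an earlier decoding iteration could leave trailing padding in `sequences`, which is what the updated test expectations below check against.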
47 changes: 28 additions & 19 deletions tests/models/bart/test_modeling_bart.py
@@ -599,13 +599,15 @@ def test_xsum_1_1_generation(self):
" 2002 to prosecute genocide, crimes against humanity and war crimes."
)
EXPECTED = (
"</s>"
" The International Criminal Court (ICC) has announced that it has been announced by the International"
" Criminal court."
"</s>"
)

dct = tok(ARTICLE, return_tensors="pt")
generated_ids = hf.generate(**dct, num_beams=4)
result = tok.batch_decode(generated_ids, skip_special_tokens=True)[0]
result = tok.batch_decode(generated_ids)[0]
@gante (Member, Author) commented on Mar 28, 2025:


Tests: update beam search tests to also print special tokens.

For example, this updated test fails on `main` because it returns extra pad tokens, caused by the incorrect crop.
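To see why the comparison now keeps special tokens, here is a hedged illustration (the checkpoint and text are placeholders, not taken from the test): with `skip_special_tokens=True`, trailing `<pad>` tokens are dropped during decoding, so an over-long output decodes to the same string as a correctly cropped one and the old assertion cannot tell them apart.

```python
from transformers import AutoTokenizer

# Assumption: any BART-style tokenizer makes the point; this checkpoint is illustrative.
tok = AutoTokenizer.from_pretrained("facebook/bart-large-xsum")

text = " The International Criminal Court (ICC) has announced ..."
ids = tok(text).input_ids                  # <s> ... </s>
padded_ids = ids + [tok.pad_token_id] * 5  # simulate an output that was cropped too loosely

# Identical once special tokens are skipped, so the old assertion cannot catch the extra padding.
assert tok.decode(ids, skip_special_tokens=True) == tok.decode(padded_ids, skip_special_tokens=True)

# Without skipping, the trailing <pad> tokens become visible in the decoded string.
print(tok.decode(padded_ids))  # "<s> ... </s><pad><pad><pad><pad><pad>"
```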

assert EXPECTED == result

def test_xsum_1_1_batch_generation(self):
@@ -729,16 +731,18 @@ def test_xsum_1_1_batch_generation(self):
truncation=True,
)
generated_ids = self.xsum_1_1_model.generate(**batch, num_beams=4)
result = self.tok.batch_decode(generated_ids, skip_special_tokens=True)
assert (
result[0]
== " The International Criminal Court (ICC) has announced that it has been announced by the International"
result = self.tok.batch_decode(generated_ids)
assert result[0] == (
"</s>"
" The International Criminal Court (ICC) has announced that it has been announced by the International"
" Criminal court."
"</s><pad><pad><pad><pad><pad>"
)
assert (
result[1]
== " An investigation into the crash that killed at least 10 people in the French capital has been"
assert result[1] == (
"</s>"
" An investigation into the crash that killed at least 10 people in the French capital has been"
" released by the French police investigating the crash."
"</s>"
)

def test_encoder_equiv(self):
@@ -939,8 +943,10 @@ def test_xsum_summarization_same_as_fairseq(self):
PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""

EXPECTED_SUMMARY = (
"</s>"
"California's largest power company has begun shutting off electricity to thousands of customers in the"
" state."
"</s>"
)
dct = tok.batch_encode_plus(
[PGE_ARTICLE],
@@ -962,10 +968,7 @@ def test_xsum_summarization_same_as_fairseq(self):
decoder_start_token_id=model.config.eos_token_id,
)

decoded = tok.batch_decode(
hypotheses_batch,
skip_special_tokens=True,
)
decoded = tok.batch_decode(hypotheses_batch)
self.assertEqual(EXPECTED_SUMMARY, decoded[0])

def test_xsum_config_generation_params(self):
@@ -1189,26 +1192,32 @@ def test_cnn_summarization_same_as_fairseq(self):
assert hypotheses_batch[:, 1].eq(0).all().item()

EXPECTED = [
"</s><s>"
"A French prosecutor says he is not aware of any video footage from on board the plane. Two German "
"magazines claim to have found a cell phone video showing the crash. The publications say they watched "
"the video, which was found by a source close to the investigation. All 150 on board Germanwings Flight "
"9525 were killed.",
"9525 were killed."
"</s>",
"</s><s>"
"Palestinian Authority becomes 123rd member of the International Criminal Court. The move gives the court "
"jurisdiction over alleged crimes in Palestinian territories. Israel and the United States opposed the "
"Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki said it was a "
"move toward greater justice.",
"move toward greater justice."
"</s><pad><pad><pad><pad>",
"</s><s>"
"U.S. and its negotiating partners reached a strong framework agreement with Iran. Peter Bergen: The "
"debate that has already begun will likely result in more heat than light. He says critics have made "
"dubious assumptions and doubtful assertions. Bergen says the goal was to block Iran from building a "
"nuclear weapon.",
"nuclear weapon."
"</s><pad><pad><pad>",
"</s><s>"
"Liana Barrientos, 39, has been married 10 times, sometimes within two weeks of each other. Prosecutors "
"say the marriages were part of an immigration scam. She pleaded not guilty at State Supreme Court in the "
"Bronx on Friday. If convicted, she faces up to four years in prison.",
"Bronx on Friday. If convicted, she faces up to four years in prison."
"</s><pad><pad><pad><pad><pad>",
]

generated_summaries = tok.batch_decode(
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
)
generated_summaries = tok.batch_decode(hypotheses_batch.tolist())
assert generated_summaries == EXPECTED

@slow
8 changes: 5 additions & 3 deletions tests/models/biogpt/test_modeling_biogpt.py
@@ -434,7 +434,7 @@ def test_inference_lm_head_model(self):
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)

@slow
def test_biogpt_generation(self):
def test_biogpt_generation_beam_search(self):
tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
model.to(torch_device)
@@ -448,13 +448,15 @@ def test_biogpt_generation(self):
num_beams=5,
early_stopping=True,
)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
output_str = tokenizer.decode(output_ids[0])

EXPECTED_OUTPUT_STR = (
"</s>"
"COVID-19 is a global pandemic caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), the"
" causative agent of coronavirus disease 2019 (COVID-19), which has spread to more than 200 countries and"
" territories, including the United States (US), Canada, Australia, New Zealand, the United Kingdom (UK),"
" and the United States of America (USA), as of March 11, 2020, with more than 800,000 confirmed cases and"
" more than 800,000 deaths."
" more than 800,000 deaths. "
"</s>"
)
self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
16 changes: 10 additions & 6 deletions tests/models/m2m_100/test_modeling_m2m_100.py
@@ -415,16 +415,20 @@ def test_seq_to_seq_generation(self):
)

expected_en = [
"The NSA case highlights the total absence of intelligence debate",
"I think there are two levels of response from the French government.",
"</s> __en__ "
"The NSA case highlights the total absence of intelligence debate"
"</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>",
"</s> __en__ "
"I think there are two levels of response from the French government."
"</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>",
"</s> __en__ "
"When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S."
" Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all"
" communications in France.",
" communications in France."
"</s>",
]

generated = tokenizer.batch_decode(
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
)
generated = tokenizer.batch_decode(hypotheses_batch)
assert generated == expected_en

@require_flash_attn
37 changes: 23 additions & 14 deletions tests/models/t5/test_modeling_t5.py
@@ -1475,19 +1475,27 @@ def test_summarization(self):
)

expected_summaries = [
"<pad> "
'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a'
" cell phone video of the final seconds . \"one can hear cries of 'My God' in several languages,\" one"
" magazine says .",
" magazine says ."
"</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
"<pad> "
"the formal accession was marked by a ceremony at The Hague, in the Netherlands . the ICC opened a"
" preliminary examination into the situation in the occupied Palestinian territory . as members of the"
" court, Palestinians may be subject to counter-charges as well .",
" court, Palestinians may be subject to counter-charges as well ."
"</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
"<pad> "
"the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller:"
" the debate that has already begun since the announcement of the new framework will likely result in more"
" heat than light . the deal would reduce Iran's low-enriched uranium stockpile, cut centrifuges and"
" implement a rigorous inspection regime .",
" implement a rigorous inspection regime ."
"</s>",
"<pad> "
"prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two"
' criminal counts of "offering a false instrument for filing in the first degree" she has been married 10'
" times, with nine of her marriages occurring between 1999 and 2002 .",
" times, with nine of her marriages occurring between 1999 and 2002 ."
"</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>",
]

use_task_specific_params(model, "summarization")
@@ -1512,11 +1520,8 @@ def test_summarization(self):
early_stopping=True,
)

decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertListEqual(
expected_summaries,
decoded,
)
decoded = tok.batch_decode(hypotheses_batch)
self.assertListEqual(expected_summaries, decoded)

@slow
def test_translation_en_to_de(self):
@@ -1526,13 +1531,13 @@

en_text = '"Luigi often said to me that he never wanted the brothers to end up in court", she wrote.'
expected_translation = (
'"Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.'
'<pad> "Luigi sagte mir oft, dass er nie wollte, dass die Brüder am Gericht sitzen", schrieb sie.</s>'
)

input_ids = tok.encode(model.config.prefix + en_text, return_tensors="pt")
input_ids = input_ids.to(torch_device)
output = model.generate(input_ids)
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
translation = tok.decode(output[0])
self.assertEqual(translation, expected_translation)

@slow
@@ -1558,13 +1563,15 @@ def test_translation_en_to_fr(self):
do_sample=False,
early_stopping=True,
)
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
translation = tok.decode(output[0])
new_truncated_translation = (
"<pad> "
"Cette section d'images provenant de l'enregistrement infrarouge effectué par le télescope Spitzer montre "
"un "
"« portrait familial » de générations innombrables d’étoiles : les plus anciennes sont observées "
"sous forme "
"de points bleus."
"</s>"
)

self.assertEqual(translation, new_truncated_translation)
@@ -1575,11 +1582,13 @@ def test_translation_en_to_ro(self):
tok = self.tokenizer
use_task_specific_params(model, "translation_en_to_ro")
en_text = "Taco Bell said it plans to add 2,000 locations in the US by 2022."
expected_translation = "Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022."
expected_translation = (
"<pad> Taco Bell a declarat că intenţionează să adauge 2 000 de locaţii în SUA până în 2022.</s>"
)

inputs = tok(model.config.prefix + en_text, return_tensors="pt").to(torch_device)
output = model.generate(**inputs)
translation = tok.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
translation = tok.decode(output[0])
self.assertEqual(translation, expected_translation)

@slow