
Commit 822efd8

aaand remove tests after all green!!
1 parent 62cb274 commit 822efd8

File tree: 7 files changed, +24 −149 lines


src/transformers/configuration_utils.py

Lines changed: 3 additions & 0 deletions

@@ -1139,6 +1139,9 @@ def _get_global_generation_defaults() -> dict[str, Any]:
             "exponential_decay_length_penalty": None,
             "suppress_tokens": None,
             "begin_suppress_tokens": None,
+            # Deprecated arguments (moved to the Hub). TODO joao, manuel: remove in v4.62.0
+            "num_beam_groups": 1,
+            "diversity_penalty": 0.0,
         }
 
     def _get_non_default_generation_parameters(self) -> dict[str, Any]:
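For context, a minimal sketch of how these defaults are consumed (an illustrative call into the private helpers shown above; the exact return value may vary between transformers versions):

# Sketch: attributes equal to a global generation default should not be
# reported as user-set by _get_non_default_generation_parameters().
from transformers import PretrainedConfig

config = PretrainedConfig()
config.num_beam_groups = 1      # matches the (deprecated) global default
config.diversity_penalty = 0.0  # matches the (deprecated) global default

# Keeping the deprecated keys in _get_global_generation_defaults() means
# legacy configs that still carry these values are not flagged as non-default.
print(config._get_non_default_generation_parameters())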

tests/generation/test_utils.py

Lines changed: 21 additions & 115 deletions

@@ -209,18 +209,6 @@ def _get_beam_kwargs(self, num_return_sequences=1):
         }
         return beam_kwargs
 
-    def _get_diverse_beam_kwargs(self, num_return_sequences=1):
-        beam_kwargs = {
-            "early_stopping": False,
-            "length_penalty": 2.0,
-            "num_beams": 2,
-            "num_return_sequences": num_return_sequences,
-            "num_beam_groups": 2,  # one beam per group
-            "diversity_penalty": 2.0,
-            "trust_remote_code": True,
-        }
-        return beam_kwargs
-
     def _get_constrained_beam_kwargs(self, num_return_sequences=1):
         beam_kwargs = {
             "early_stopping": False,

@@ -352,36 +340,6 @@ def _beam_sample_generate(
 
         return output_generate
 
-    def _group_beam_search_generate(
-        self,
-        model,
-        inputs_dict,
-        beam_kwargs,
-        output_scores=False,
-        output_logits=False,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict_in_generate=False,
-        use_cache=True,
-    ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
-        output_generate = model.generate(
-            do_sample=False,
-            max_new_tokens=self.max_new_tokens,
-            min_new_tokens=self.max_new_tokens,
-            output_scores=output_scores,
-            output_logits=output_logits,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict_in_generate=return_dict_in_generate,
-            use_cache=use_cache,
-            **beam_kwargs,
-            **logits_processor_kwargs,
-            **inputs_dict,
-        )
-
-        return output_generate
-
     def _constrained_beam_search_generate(
         self,
         model,

@@ -748,77 +706,6 @@ def test_generate_without_input_ids(self):
         )
         self.assertIsNotNone(output_ids_generate)
 
-    @pytest.mark.generate
-    def test_group_beam_search_generate(self):
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-
-            model = model_class(config).to(torch_device).eval()
-            # check `generate()` and `group_beam_search()` are equal
-            beam_kwargs = self._get_diverse_beam_kwargs()
-            output_generate = self._group_beam_search_generate(
-                model=model,
-                inputs_dict=inputs_dict,
-                beam_kwargs=beam_kwargs,
-            )
-            if model.config.get_text_config(decoder=True).is_encoder_decoder:
-                self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
-            else:
-                self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
-
-            # check `group_beam_search` for higher than 1 `num_return_sequences`
-            num_return_sequences = 2
-            beam_kwargs = self._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences)
-            output_generate = self._group_beam_search_generate(
-                model=model,
-                inputs_dict=inputs_dict,
-                beam_kwargs=beam_kwargs,
-            )
-            if model.config.get_text_config(decoder=True).is_encoder_decoder:
-                self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
-            else:
-                self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
-
-    @pytest.mark.generate
-    def test_group_beam_search_generate_dict_output(self):
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            if self.has_attentions:
-                config._attn_implementation = "eager"  # can't output attentions otherwise
-
-            model = model_class(config).to(torch_device).eval()
-            beam_kwargs = self._get_diverse_beam_kwargs()
-            output_generate = self._group_beam_search_generate(
-                model=model,
-                inputs_dict=inputs_dict,
-                beam_kwargs=beam_kwargs,
-                output_scores=True,
-                output_logits=True,
-                output_hidden_states=True,
-                output_attentions=self.has_attentions,
-                return_dict_in_generate=True,
-                use_cache=False,
-            )
-            if model.config.get_text_config(decoder=True).is_encoder_decoder:
-                self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
-                self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
-                # Retrocompatibility check
-                self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
-            else:
-                self.assertTrue(
-                    output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
-                )
-                self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
-                # Retrocompatibility check
-                self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
-
-            self._check_generate_outputs(
-                output_generate,
-                model.config,
-                num_return_sequences=beam_kwargs["num_return_sequences"],
-                num_beams=beam_kwargs["num_beams"],
-            )
-
     @is_flaky()  # Some models have position-specific tokens, this test may try to force them in an invalid position
     @pytest.mark.generate
     def test_constrained_beam_search_generate(self):

@@ -2672,6 +2559,7 @@ def test_diverse_beam_search(self):
             diversity_penalty=2.0,
             remove_invalid_values=True,
             trust_remote_code=True,
+            custom_generate="transformers-community/group-beam-search",
         )
 
         generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)

@@ -2831,6 +2719,7 @@ def test_generate_input_values_as_encoder_kwarg(self):
         self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
         self.assertEqual(output_sequences.shape, (2, 5))
 
+    # TODO joao, manuel: remove in v4.62.0
     def test_transition_scores_group_beam_search_encoder_decoder(self):
         articles = [
             "Justin Timberlake and Jessica Biel, welcome to parenthood.",

@@ -2839,20 +2728,27 @@ def test_transition_scores_group_beam_search_encoder_decoder(self):
         tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
         model = BartForConditionalGeneration.from_pretrained(
             "hf-internal-testing/tiny-random-bart",
+            eos_token_id=None,
+        )
+        generation_config = GenerationConfig(
             max_length=10,
             num_beams=2,
             num_beam_groups=2,
             num_return_sequences=2,
             diversity_penalty=1.0,
-            eos_token_id=None,
             return_dict_in_generate=True,
             output_scores=True,
             length_penalty=0.0,
         )
         model = model.to(torch_device)
 
         input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
-        outputs = model.generate(input_ids=input_ids, trust_remote_code=True)
+        outputs = model.generate(
+            input_ids=input_ids,
+            generation_config=generation_config,
+            trust_remote_code=True,
+            custom_generate="transformers-community/group-beam-search",
+        )
 
         transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
         transition_scores_sum = transition_scores.sum(-1)

@@ -4823,6 +4719,16 @@ def test_generate_custom_cache_position(self):
         [
             ("transformers-community/dola", {"dola_layers": "low"}),
             ("transformers-community/contrastive-search", {"penalty_alpha": 0.6, "top_k": 4}),
+            (
+                "transformers-community/group-beam-search",
+                {
+                    "do_sample": False,
+                    "num_beams": 2,
+                    "num_beam_groups": 2,
+                    "diversity_penalty": 2.0,
+                    "length_penalty": 2.0,
+                },
+            ),
         ]
     )
     def test_hub_gen_strategies(self, custom_generate, extra_kwargs):
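Taken together, these test changes show the migration path: group (diverse) beam search is no longer a core decoding loop and is instead fetched from the Hub via `custom_generate`. A minimal sketch of the new call pattern, mirroring the updated `test_diverse_beam_search` (the checkpoint and prompt below are placeholders, not from this commit):

from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
inputs = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")

# The decoding loop is downloaded from the Hub repo named in `custom_generate`,
# which is why `trust_remote_code=True` is required.
outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    num_beams=4,
    num_beam_groups=2,     # two beams per group
    diversity_penalty=2.0,
    remove_invalid_values=True,
    trust_remote_code=True,
    custom_generate="transformers-community/group-beam-search",
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))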

tests/models/csm/test_modeling_csm.py

Lines changed: 0 additions & 10 deletions

@@ -272,16 +272,6 @@ def test_beam_search_generate_dict_outputs_use_cache(self):
     def test_beam_sample_generate_dict_output(self):
         pass
 
-    @pytest.mark.generate
-    @unittest.skip(reason="CSM does not support group beam search.")
-    def test_group_beam_search_generate(self):
-        pass
-
-    @pytest.mark.generate
-    @unittest.skip(reason="CSM does not support group beam search.")
-    def test_group_beam_search_generate_dict_output(self):
-        pass
-
     @pytest.mark.generate
     @unittest.skip(reason="CSM does not support constrained beam search.")
     def test_constrained_beam_search_generate(self):

tests/models/dia/test_modeling_dia.py

Lines changed: 0 additions & 1 deletion

@@ -237,7 +237,6 @@ def skip_non_greedy_generate(self):
         skippable_tests = [
             "test_sample_generate_dict_output",  # return sequences > 1
             "test_beam",
-            "test_group_beam",
             "test_constrained_beam",
             "test_contrastive",
             "test_assisted",

tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py

Lines changed: 0 additions & 10 deletions

@@ -138,16 +138,6 @@ def test_constrained_beam_search_generate_dict_output(self):
     def test_generate_without_input_ids(self):
         pass
 
-    @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
-    @pytest.mark.generate
-    def test_group_beam_search_generate(self):
-        pass
-
-    @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
-    @pytest.mark.generate
-    def test_group_beam_search_generate_dict_output(self):
-        pass
-
     @unittest.skip(reason="RecurrentGemma is unusual and fails a lot of generation tests")
     @pytest.mark.generate
     def test_constrained_beam_search_generate(self):

tests/models/rwkv/test_modeling_rwkv.py

Lines changed: 0 additions & 7 deletions

@@ -401,13 +401,6 @@ def test_greedy_generate_dict_outputs(self):
         super().test_greedy_generate_dict_outputs()
         self.has_attentions = old_has_attentions
 
-    def test_group_beam_search_generate_dict_output(self):
-        # This model has a custom attention output shape AND config flags, let's skip those checks
-        old_has_attentions = self.has_attentions
-        self.has_attentions = False
-        super().test_group_beam_search_generate_dict_output()
-        self.has_attentions = old_has_attentions
-
     def test_sample_generate_dict_output(self):
         # This model has a custom attention output shape AND config flags, let's skip those checks
         old_has_attentions = self.has_attentions

tests/models/whisper/test_modeling_whisper.py

Lines changed: 0 additions & 6 deletions

@@ -403,12 +403,6 @@ def _get_beam_kwargs(self, num_return_sequences=1):
         beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
         return beam_kwargs
 
-    def _get_diverse_beam_kwargs(self, num_return_sequences=1):
-        # Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
-        beam_kwargs = super()._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences)
-        beam_kwargs["num_return_sequences"] = beam_kwargs["num_beams"]
-        return beam_kwargs
-
     def _get_constrained_beam_kwargs(self, num_return_sequences=1):
         # Overwritten from `GenerationTesterMixin`, Whisper's `num_return_sequences` differs from the core `generate`
         beam_kwargs = super()._get_constrained_beam_kwargs(num_return_sequences=num_return_sequences)
