@@ -209,18 +209,6 @@ def _get_beam_kwargs(self, num_return_sequences=1):
         }
         return beam_kwargs
 
-    def _get_diverse_beam_kwargs(self, num_return_sequences=1):
-        beam_kwargs = {
-            "early_stopping": False,
-            "length_penalty": 2.0,
-            "num_beams": 2,
-            "num_return_sequences": num_return_sequences,
-            "num_beam_groups": 2,  # one beam per group
-            "diversity_penalty": 2.0,
-            "trust_remote_code": True,
-        }
-        return beam_kwargs
-
     def _get_constrained_beam_kwargs(self, num_return_sequences=1):
         beam_kwargs = {
             "early_stopping": False,
@@ -352,36 +340,6 @@ def _beam_sample_generate(
 
         return output_generate
 
-    def _group_beam_search_generate(
-        self,
-        model,
-        inputs_dict,
-        beam_kwargs,
-        output_scores=False,
-        output_logits=False,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict_in_generate=False,
-        use_cache=True,
-    ):
-        logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
-        output_generate = model.generate(
-            do_sample=False,
-            max_new_tokens=self.max_new_tokens,
-            min_new_tokens=self.max_new_tokens,
-            output_scores=output_scores,
-            output_logits=output_logits,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict_in_generate=return_dict_in_generate,
-            use_cache=use_cache,
-            **beam_kwargs,
-            **logits_processor_kwargs,
-            **inputs_dict,
-        )
-
-        return output_generate
-
     def _constrained_beam_search_generate(
         self,
         model,
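
Note: with the in-tree helper removed, group (diverse) beam search is exercised through the hub strategy instead. A minimal sketch of the replacement call path, assuming the transformers-community/group-beam-search repo accepts the same kwargs the deleted helper forwarded to `generate` (the checkpoint below is illustrative):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-bart")
inputs = tokenizer(["Justin Timberlake and Jessica Biel, welcome to parenthood."], return_tensors="pt")

# The decoding loop now lives in a hub repo, pulled in via `custom_generate`;
# `trust_remote_code=True` is required to execute hub-hosted code.
outputs = model.generate(
    **inputs,
    do_sample=False,
    num_beams=2,
    num_beam_groups=2,  # one beam per group
    diversity_penalty=2.0,
    custom_generate="transformers-community/group-beam-search",
    trust_remote_code=True,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
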
@@ -748,77 +706,6 @@ def test_generate_without_input_ids(self):
             )
             self.assertIsNotNone(output_ids_generate)
 
-    @pytest.mark.generate
-    def test_group_beam_search_generate(self):
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-
-            model = model_class(config).to(torch_device).eval()
-            # check `generate()` and `group_beam_search()` are equal
-            beam_kwargs = self._get_diverse_beam_kwargs()
-            output_generate = self._group_beam_search_generate(
-                model=model,
-                inputs_dict=inputs_dict,
-                beam_kwargs=beam_kwargs,
-            )
-            if model.config.get_text_config(decoder=True).is_encoder_decoder:
-                self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
-            else:
-                self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
-
-            # check `group_beam_search` for higher than 1 `num_return_sequences`
-            num_return_sequences = 2
-            beam_kwargs = self._get_diverse_beam_kwargs(num_return_sequences=num_return_sequences)
-            output_generate = self._group_beam_search_generate(
-                model=model,
-                inputs_dict=inputs_dict,
-                beam_kwargs=beam_kwargs,
-            )
-            if model.config.get_text_config(decoder=True).is_encoder_decoder:
-                self.assertTrue(output_generate.shape[1] == self.max_new_tokens + 1)
-            else:
-                self.assertTrue(output_generate.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1])
-
-    @pytest.mark.generate
-    def test_group_beam_search_generate_dict_output(self):
-        for model_class in self.all_generative_model_classes:
-            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            if self.has_attentions:
-                config._attn_implementation = "eager"  # can't output attentions otherwise
-
-            model = model_class(config).to(torch_device).eval()
-            beam_kwargs = self._get_diverse_beam_kwargs()
-            output_generate = self._group_beam_search_generate(
-                model=model,
-                inputs_dict=inputs_dict,
-                beam_kwargs=beam_kwargs,
-                output_scores=True,
-                output_logits=True,
-                output_hidden_states=True,
-                output_attentions=self.has_attentions,
-                return_dict_in_generate=True,
-                use_cache=False,
-            )
-            if model.config.get_text_config(decoder=True).is_encoder_decoder:
-                self.assertTrue(output_generate.sequences.shape[1] == self.max_new_tokens + 1)
-                self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
-                # Retrocompatibility check
-                self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
-            else:
-                self.assertTrue(
-                    output_generate.sequences.shape[1] == self.max_new_tokens + inputs_dict["input_ids"].shape[1]
-                )
-                self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
-                # Retrocompatibility check
-                self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
-
-            self._check_generate_outputs(
-                output_generate,
-                model.config,
-                num_return_sequences=beam_kwargs["num_return_sequences"],
-                num_beams=beam_kwargs["num_beams"],
-            )
-
     @is_flaky()  # Some models have position-specific tokens, this test may try to force them in an invalid position
     @pytest.mark.generate
     def test_constrained_beam_search_generate(self):
@@ -2672,6 +2559,7 @@ def test_diverse_beam_search(self):
             diversity_penalty=2.0,
             remove_invalid_values=True,
             trust_remote_code=True,
+            custom_generate="transformers-community/group-beam-search",
         )
 
         generated_text = bart_tokenizer.batch_decode(outputs, skip_special_tokens=True)
@@ -2831,6 +2719,7 @@ def test_generate_input_values_as_encoder_kwarg(self):
         self.assertListEqual(output_sequences.tolist(), output_sequences_kwargs.tolist())
         self.assertEqual(output_sequences.shape, (2, 5))
 
+    # TODO joao, manuel: remove in v4.62.0
     def test_transition_scores_group_beam_search_encoder_decoder(self):
         articles = [
             "Justin Timberlake and Jessica Biel, welcome to parenthood.",
@@ -2839,20 +2728,27 @@ def test_transition_scores_group_beam_search_encoder_decoder(self):
         tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
         model = BartForConditionalGeneration.from_pretrained(
             "hf-internal-testing/tiny-random-bart",
+            eos_token_id=None,
+        )
+        generation_config = GenerationConfig(
             max_length=10,
             num_beams=2,
             num_beam_groups=2,
             num_return_sequences=2,
             diversity_penalty=1.0,
-            eos_token_id=None,
             return_dict_in_generate=True,
             output_scores=True,
             length_penalty=0.0,
         )
         model = model.to(torch_device)
 
         input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
-        outputs = model.generate(input_ids=input_ids, trust_remote_code=True)
+        outputs = model.generate(
+            input_ids=input_ids,
+            generation_config=generation_config,
+            trust_remote_code=True,
+            custom_generate="transformers-community/group-beam-search",
+        )
 
         transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
         transition_scores_sum = transition_scores.sum(-1)
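
For readers following the rewrite above: the generation knobs move from `from_pretrained` into a `GenerationConfig`, and the hub strategy is expected to populate `scores` and `beam_indices` just as the in-tree implementation did, so `compute_transition_scores` keeps working. A condensed sketch, reusing `model`, `tokenizer`, and `articles` from the test; the closing assertion mirrors the usual transition-scores check and is an assumption here, not the hunk's actual tail:

import torch
from transformers import GenerationConfig

generation_config = GenerationConfig(
    max_length=10,
    num_beams=2,
    num_beam_groups=2,
    num_return_sequences=2,
    diversity_penalty=1.0,
    return_dict_in_generate=True,
    output_scores=True,
    length_penalty=0.0,
)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids
outputs = model.generate(
    input_ids=input_ids,
    generation_config=generation_config,
    trust_remote_code=True,
    custom_generate="transformers-community/group-beam-search",
)

# With `length_penalty=0.0` and `eos_token_id=None`, summing the per-step
# transition scores should reconstruct each returned sequence's score.
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
torch.testing.assert_close(transition_scores.sum(-1), outputs.sequences_scores, atol=1e-3, rtol=0)
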
@@ -4823,6 +4719,16 @@ def test_generate_custom_cache_position(self):
         [
             ("transformers-community/dola", {"dola_layers": "low"}),
             ("transformers-community/contrastive-search", {"penalty_alpha": 0.6, "top_k": 4}),
+            (
+                "transformers-community/group-beam-search",
+                {
+                    "do_sample": False,
+                    "num_beams": 2,
+                    "num_beam_groups": 2,
+                    "diversity_penalty": 2.0,
+                    "length_penalty": 2.0,
+                },
+            ),
         ]
     )
     def test_hub_gen_strategies(self, custom_generate, extra_kwargs):
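
The new parametrized case reuses the kwargs from the deleted `_get_diverse_beam_kwargs` fixture. The body of `test_hub_gen_strategies` sits outside this hunk; a hedged sketch of how such a harness plausibly consumes each `(custom_generate, extra_kwargs)` pair — the helper name and prompt are hypothetical:

def _run_hub_strategy(model, tokenizer, custom_generate, extra_kwargs):
    # Hypothetical harness: forward the parametrized kwargs straight into
    # `generate` and let the hub repo named by `custom_generate` drive decoding.
    inputs = tokenizer("Hello world", return_tensors="pt")
    return model.generate(
        **inputs,
        max_new_tokens=5,
        trust_remote_code=True,
        custom_generate=custom_generate,
        **extra_kwargs,
    )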