129129 "past_buckets_states" , # reformer
130130]
131131
132+ GENERATION_MODES_MAPPING = {
133+ GenerationMode .SAMPLE : "_sample" ,
134+ GenerationMode .GREEDY_SEARCH : "_sample" ,
135+ GenerationMode .BEAM_SEARCH : "_beam_search" ,
136+ GenerationMode .BEAM_SAMPLE : "_beam_search" ,
137+ GenerationMode .ASSISTED_GENERATION : "_assisted_decoding" ,
138+ # Deprecated methods
139+ GenerationMode .DOLA_GENERATION : "transformers-community/dola" ,
140+ GenerationMode .CONTRASTIVE_SEARCH : "transformers-community/contrastive-search" ,
141+ GenerationMode .GROUP_BEAM_SEARCH : "transformers-community/group-beam-search" ,
142+ GenerationMode .CONSTRAINED_BEAM_SEARCH : "transformers-community/constrained-beam-search" ,
143+ }
144+
132145
133146@dataclass
134147class GenerateDecoderOnlyOutput (ModelOutput ):
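Note: the new table does double duty that the rest of the diff relies on. Values beginning with an underscore name in-tree decoding methods, while values containing `/` are `custom_generate` Hub repo ids. Since `generate()` later fetches the methods with `getattr(type(self), name)`, it gets the unbound function and must pass `self` explicitly. A self-contained sketch of that dispatch pattern (the enum, mapping, and class below are illustrative stand-ins, not the real `GenerationMode`/`GenerationMixin`):

from enum import Enum


class Mode(str, Enum):  # stand-in for GenerationMode
    SAMPLE = "sample"
    GREEDY_SEARCH = "greedy_search"


MODES = {Mode.SAMPLE: "_sample", Mode.GREEDY_SEARCH: "_sample"}  # two modes, one method


class Generator:  # stand-in for GenerationMixin
    def _sample(self, prompt):
        return f"_sample({prompt!r})"

    def generate(self, prompt, mode):
        # type() is needed because the mapping stores names of class-level (unbound) methods
        decoding_method = getattr(type(self), MODES[mode])
        return decoding_method(self, prompt)  # `self` is passed explicitly


print(Generator().generate("hello", Mode.GREEDY_SEARCH))  # _sample('hello')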
@@ -1492,12 +1505,25 @@ def compute_transition_scores(
 
         return transition_scores
 
-    def _validate_generation_mode(self, generation_mode, generation_mode_kwargs):
+    def _validate_generation_mode(self, generation_mode, generation_config, generation_mode_kwargs):
         if generation_mode == GenerationMode.BEAM_SEARCH and "streamer" in generation_mode_kwargs:
             raise ValueError(
                 "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
             )
 
+        if generation_mode == GenerationMode.ASSISTED_GENERATION:
+            if generation_config.num_return_sequences > 1:
+                raise ValueError(
+                    "num_return_sequences has to be 1 when doing assisted generate, "
+                    f"but is {generation_config.num_return_sequences}."
+                )
+            if self._is_stateful:
+                # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
+                # which is not possible with stateful models (they can't reset to a previous subset of generated text)
+                raise ValueError(
+                    f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
+                )
+
         if (assistant_model := generation_mode_kwargs.get("assistant_model")) is not None:
             if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
                 attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
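Moving these checks into `_validate_generation_mode` means they run before any expensive preparation work. A minimal, self-contained mimic of the fail-fast behavior (plain stand-in classes, not the real mixin):

class Config:  # stand-in for GenerationConfig
    num_return_sequences = 2


def validate_generation_mode(mode, config):
    if mode == "assisted_generation" and config.num_return_sequences > 1:
        raise ValueError(
            "num_return_sequences has to be 1 when doing assisted generate, "
            f"but is {config.num_return_sequences}."
        )


try:
    validate_generation_mode("assisted_generation", Config())
except ValueError as err:
    print(err)  # raised before any weights are touched or caches prepared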
@@ -2136,16 +2162,9 @@ def _get_deprecated_gen_repo(
         """
         Returns the Hub repo for a deprecated generation mode, if any.
         """
-        moved_to_hub_modes = {
-            GenerationMode.DOLA_GENERATION: "transformers-community/dola",
-            GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search",
-            GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search",
-            GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search",
-        }
-        if custom_generate is not None or generation_mode not in moved_to_hub_modes:
+        if custom_generate is not None or "/" not in (repo := GENERATION_MODES_MAPPING[generation_mode]):
             return None
 
-        repo = moved_to_hub_modes[generation_mode]
         logger.warning_once(
             f"{generation_mode.name.replace('_', ' ').title()} was moved to a `custom_generate` repo: https://hf.co/{repo}. "
             f"To prevent loss of backward compatibility, add `custom_generate='{repo}'` "
@@ -2175,10 +2194,11 @@ def _extract_generation_mode_kwargs(
             "assistant_model": assistant_model,
             "streamer": streamer,
         }
-        if synced_gpus is not None:
-            generation_mode_kwargs["synced_gpus"] = (
-                is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
-            ) and dist.get_world_size() > 1
+        generation_mode_kwargs["synced_gpus"] = (
+            (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
+            if synced_gpus is None
+            else synced_gpus
+        )
         generation_mode_kwargs = {k: v for k, v in generation_mode_kwargs.items() if v is not None}
         # Custom_generate callables can have their own set of arguments
         # To extract them, we compare the signature with the standard _sample method
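Beyond the restructuring, this hunk changes behavior: the removed branch only wrote `synced_gpus` when the caller passed a value, and then overwrote that value with the auto-detected one; the conditional expression respects an explicit value and auto-detects only on `None`. The equivalent long-form logic, with the distributed checks stubbed out as booleans:

def resolve_synced_gpus(synced_gpus, zero3_enabled, fsdp_managed, world_size):
    # stands in for: (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
    auto_detected = (zero3_enabled or fsdp_managed) and world_size > 1
    return auto_detected if synced_gpus is None else synced_gpus


print(resolve_synced_gpus(None, True, False, 2))   # True  (auto-detected)
print(resolve_synced_gpus(False, True, False, 2))  # False (explicit value now wins)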
@@ -2338,15 +2358,20 @@ def generate(
             generation_config, use_model_defaults, **kwargs
         )
         generation_mode = generation_config.get_generation_mode(assistant_model)
+        if isinstance(custom_generate, Callable):
+            decoding_method = custom_generate
+        else:
+            # type() required to access the unbound class-level method
+            decoding_method = getattr(type(self), GENERATION_MODES_MAPPING[generation_mode])
 
         self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_generation_mode(generation_mode, generation_mode_kwargs)
+        self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)
 
         # Deprecation-related step: set Hub repo for deprecated strategies.
         # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode.
         # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps.
         # TODO joao, manuel: remove this in v4.62.0
-        if deprecate_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate):
+        if deprecated_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate):
             return GenerationMixin.generate(
                 self,
                 inputs=inputs,
@@ -2358,7 +2383,7 @@ def generate(
                 negative_prompt_ids=negative_prompt_ids,
                 negative_prompt_attention_mask=negative_prompt_attention_mask,
                 use_model_defaults=use_model_defaults,
-                custom_generate=deprecate_mode_repo,
+                custom_generate=deprecated_mode_repo,
                 trust_remote_code=trust_remote_code,
                 **generation_mode_kwargs,
                 **kwargs,
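After this shim, a deprecated mode is served by re-entering `generate()` with `custom_generate` set to the Hub repo, so existing call sites keep working during the deprecation window. A hedged usage sketch (assumes a loaded `model` and tokenized `inputs`, plus network access to the Hub):

# num_beams + num_beam_groups selects GROUP_BEAM_SEARCH, a deprecated mode, so
# generate() re-enters itself with custom_generate="transformers-community/group-beam-search"
out = model.generate(
    **inputs,
    num_beams=4,
    num_beam_groups=2,
    diversity_penalty=1.0,
    trust_remote_code=True,  # required to execute code fetched from the Hub repo
)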
@@ -2376,6 +2401,9 @@ def generate(
         inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
             inputs, generation_config.bos_token_id, model_kwargs
         )
+        # Some generation modes (e.g. assisted) need `inputs_tensor` to rerun encoder.forward()
+        if "inputs_tensor" in inspect.signature(decoding_method).parameters.keys():
+            generation_mode_kwargs["inputs_tensor"] = inputs_tensor
         batch_size = inputs_tensor.shape[0]
 
         device = inputs_tensor.device
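The `inspect.signature` probe added above is plain feature detection: the extra kwarg is forwarded only to decoding methods that declare it, so `_sample` never sees `inputs_tensor` while `_assisted_decoding` does. A self-contained sketch of the pattern:

import inspect


def _sample(x):
    return x


def _assisted_decoding(x, inputs_tensor=None):
    return (x, inputs_tensor)


def dispatch(method, x, inputs_tensor):
    kwargs = {}
    # forward `inputs_tensor` only if the target method declares it
    if "inputs_tensor" in inspect.signature(method).parameters:
        kwargs["inputs_tensor"] = inputs_tensor
    return method(x, **kwargs)


print(dispatch(_sample, 1, "enc"))             # 1
print(dispatch(_assisted_decoding, 1, "enc"))  # (1, 'enc')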
@@ -2511,80 +2539,16 @@ def generate(
         # Set model_kwargs `use_cache` so we can use it later in forward runs
         model_kwargs["use_cache"] = generation_config.use_cache
 
-        # 9. go into different generation modes
-        if isinstance(custom_generate, Callable):
-            result = custom_generate(
-                self,
-                input_ids,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                **generation_mode_kwargs,
-                **model_kwargs,
-            )
-        elif generation_mode == GenerationMode.ASSISTED_GENERATION:
-            if generation_config.num_return_sequences > 1:
-                raise ValueError(
-                    "num_return_sequences has to be 1 when doing assisted generate, "
-                    f"but is {generation_config.num_return_sequences}."
-                )
-            if batch_size > 1:
-                raise ValueError("assisted generate is only supported for batch_size = 1")
-            if not model_kwargs["use_cache"]:
-                raise ValueError("assisted generate requires `use_cache=True`")
-            if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
-                raise ValueError("assisted generate is not supported with Static cache classes`")
-            if self._is_stateful:
-                # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
-                # which is not possible with stateful models (they can't reset to a previous subset of generated text)
-                raise ValueError(
-                    f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
-                )
-
-            # 10. Get the candidate generator, given the parameterization
-            candidate_generator = self._get_candidate_generator(
-                generation_config=generation_config,
-                input_ids=input_ids,
-                inputs_tensor=inputs_tensor,
-                assistant_model=generation_mode_kwargs.pop("assistant_model", None),
-                logits_processor=logits_processor,
-                target_tokenizer=generation_mode_kwargs.pop("tokenizer", None),
-                assistant_tokenizer=generation_mode_kwargs.pop("assistant_tokenizer", None),
-                model_kwargs=model_kwargs,
-            )
-
-            # 11. run assisted generate
-            result = self._assisted_decoding(
-                input_ids,
-                candidate_generator=candidate_generator,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                **generation_mode_kwargs,
-                **model_kwargs,
-            )
-
-        elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 10. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-            result = self._sample(
-                input_ids,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                **generation_mode_kwargs,
-                **model_kwargs,
-            )
-
-        elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
-            # 10. run beam sample
-            result = self._beam_search(
-                input_ids,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                **generation_mode_kwargs,
-                **model_kwargs,
-            )
+        # 9. Call generation mode
+        result = decoding_method(
+            self,
+            input_ids,
+            logits_processor=prepared_logits_processor,
+            stopping_criteria=prepared_stopping_criteria,
+            generation_config=generation_config,
+            **generation_mode_kwargs,
+            **model_kwargs,
+        )
 
         # Convert to legacy cache format if requested
         if (
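With every branch collapsed into this single call, a `custom_generate` callable must accept the same calling convention as the in-tree methods: the model and `input_ids` positionally, then `logits_processor`, `stopping_criteria`, `generation_config`, and forwarded kwargs. A hedged skeleton of a compatible callable (the decoding loop itself is elided):

def my_custom_generate(model, input_ids, logits_processor, stopping_criteria, generation_config, **kwargs):
    # `model` receives what would be `self`, because generate() calls the method unbound.
    # A real implementation would loop: forward pass -> logits_processor -> pick token
    # -> append -> consult stopping_criteria. This stub just echoes the prompt.
    return input_ids


# usage sketch (assumes `model` and tokenized `inputs` exist):
# out = model.generate(**inputs, custom_generate=my_custom_generate)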
@@ -3466,12 +3430,15 @@ def _beam_search(
     def _assisted_decoding(
         self,
         input_ids: torch.LongTensor,
-        candidate_generator: CandidateGenerator,
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool = False,
         streamer: Optional["BaseStreamer"] = None,
+        inputs_tensor: Optional[torch.FloatTensor] = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
@@ -3483,9 +3450,6 @@ def _assisted_decoding(
         Parameters:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                 The sequence used as a prompt for the generation.
-            candidate_generator (`CandidateGenerator`):
-                A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
-                more information, the documentation of [`CandidateGenerator`] should be read.
             logits_processor (`LogitsProcessorList`):
                 An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                 used to modify the prediction scores of the language modeling head applied at each generation step.
@@ -3500,6 +3464,15 @@ def _assisted_decoding(
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            inputs_tensor (`torch.FloatTensor`, *optional*):
+                The input tensor for generation. For decoder-only models, usually `input_ids`. For encoder-decoder
+                models, the tensor that produced `model_kwargs["encoder_outputs"]`.
+            assistant_model (`PreTrainedModel`, *optional*):
+                The model used to assist the generation process. If not provided, the main model will be used.
+            assistant_tokenizer (`PreTrainedTokenizerBase`, *optional*):
+                The tokenizer used for the assistant model. If not provided, the token space is assumed to be the same.
+            tokenizer (`PreTrainedTokenizerBase`, *optional*):
+                The tokenizer used for the main model. If not provided, the token space is assumed to be the same.
             model_kwargs:
                 Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                 If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -3511,6 +3484,26 @@ def _assisted_decoding(
             `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
             `model.config.is_encoder_decoder=True`.
         """
+        # The cache must be dynamic for assisted generation, and the check must happen AFTER preparing the cache
+        if not model_kwargs["use_cache"]:
+            raise ValueError("assisted generate requires `use_cache=True`")
+        if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] or (
+            "past_key_values" in model_kwargs
+            and hasattr(model_kwargs["past_key_values"], "layers")
+            and any(getattr(layer, "is_compileable", False) for layer in model_kwargs["past_key_values"].layers)
+        ):
+            raise ValueError("assisted generate is not supported with static cache classes")
+        # Get the candidate generator, given the parameterization
+        candidate_generator = self._get_candidate_generator(
+            generation_config=generation_config,
+            input_ids=input_ids,
+            inputs_tensor=inputs_tensor,
+            assistant_model=assistant_model,
+            logits_processor=logits_processor,
+            target_tokenizer=tokenizer,
+            assistant_tokenizer=assistant_tokenizer,
+            model_kwargs=model_kwargs,
+        )
         # init values
         do_sample = generation_config.do_sample
         output_attentions = generation_config.output_attentions
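The stricter cache check exists because assisted decoding must be able to roll back: draft tokens rejected by the target model are dropped, so `past_key_values` has to shrink, which compile-friendly static layers cannot do. A toy, list-backed illustration of why croppability matters (not the real `Cache` API):

class ToyDynamicCache:
    """List-backed stand-in: one entry per cached token position."""

    def __init__(self):
        self.keys = []

    def append(self, token_key):
        self.keys.append(token_key)

    def crop(self, max_length):
        # rejected draft tokens are discarded by shrinking the cache
        self.keys = self.keys[:max_length]


cache = ToyDynamicCache()
for k in ["a", "b", "c", "d"]:  # target verifies 4 draft tokens...
    cache.append(k)
cache.crop(2)                   # ...but accepts only the first 2
print(cache.keys)               # ['a', 'b']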
@@ -3535,6 +3528,8 @@ def _assisted_decoding(
 
         # keep track of which sequences are already finished
         batch_size, cur_len = input_ids.shape[:2]
+        if batch_size > 1:
+            raise ValueError("assisted generate is only supported for batch_size = 1")
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
 
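None of this changes the public entry point: assisted generation is still requested by passing `assistant_model` to `generate()`. End-to-end usage for reference (the checkpoints are only examples; any target/draft pair sharing a tokenizer works):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
target = AutoModelForCausalLM.from_pretrained("gpt2-large")
draft = AutoModelForCausalLM.from_pretrained("gpt2")  # small model proposes tokens

inputs = tokenizer("The quick brown fox", return_tensors="pt")
out = target.generate(**inputs, assistant_model=draft, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))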