@@ -2304,7 +2304,10 @@ def generate(
 
         generation_mode = generation_config.get_generation_mode(**gen_mode_kwargs)
 
-        # 2. Set Hub repo for deprecated strategies. (TODO joao, manuel: remove this in v4.62.0)
+        # Deprecation-related step: set Hub repo for deprecated strategies.
+        # NOTE: this must come after `generation_config` is initialized, since we need it to detect a
+        # deprecated mode, and before any preparation steps, since Hub repos expect to be loaded first.
+        # TODO joao, manuel: remove this in v4.62.0
         if deprecate_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate):
             return GenerationMixin.generate(
                 self,
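For context, the pattern the added comment describes looks roughly like the sketch below. `resolve_deprecated_repo` and the mode/repo names are hypothetical stand-ins, not the real private helpers; only the walrus-based dispatch shape mirrors the diff:

```python
# Minimal sketch of "redirect deprecated strategies to a Hub repo".
# All names here are illustrative, not transformers internals.
def resolve_deprecated_repo(generation_mode: str) -> str | None:
    deprecated = {"example_deprecated_mode": "example-org/example-custom-generate"}
    return deprecated.get(generation_mode)

def generate(generation_mode: str):
    if repo := resolve_deprecated_repo(generation_mode):  # bind and test in one expression
        return f"load custom generate code from {repo} and run it"
    return "run the built-in generation path"
```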
@@ -2322,7 +2325,7 @@ def generate(
                 **gen_mode_kwargs,
             )
 
-        # 3. Set generation parameters if not already defined
+        # 2. Set generation parameters if not already defined
         if synced_gpus is None:
             synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
 
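Why `synced_gpus` matters: with ZeRO-3 or FSDP the weights are sharded, so every rank must participate in every forward pass, and a rank that finishes generating early has to keep issuing dummy forward calls until all ranks are done. A minimal sketch of the default above, with the sharding check collapsed into a boolean for illustration:

```python
import torch.distributed as dist

# Hedged sketch: synced_gpus defaults to True only when the model is sharded
# across more than one process (DeepSpeed ZeRO-3 or FSDP in the real code).
def default_synced_gpus(model_is_sharded: bool) -> bool:
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    return model_is_sharded and world_size > 1
```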
@@ -2333,7 +2336,7 @@ def generate(
         requires_attention_mask = "encoder_outputs" not in model_kwargs
         kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
 
-        # 4. Define model inputs
+        # 3. Define model inputs
         inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
             inputs, generation_config.bos_token_id, model_kwargs
         )
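`_prepare_model_inputs` picks the tensor that drives generation (`input_ids`, `inputs_embeds`, or pre-computed `encoder_outputs`). When no `attention_mask` is passed and one is required, it is inferred from pad tokens later, which is ambiguous if `pad_token_id == eos_token_id`; passing the mask explicitly avoids that. A usage sketch (checkpoint name is just an example):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
enc = tok(["Hello world"], return_tensors="pt")
# Passing attention_mask explicitly instead of letting generate() infer it.
out = model.generate(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], max_new_tokens=8)
```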
@@ -2357,7 +2360,7 @@ def generate(
                 "generation results, please set `padding_side='left'` when initializing the tokenizer."
             )
 
-        # 5. Define other model kwargs
+        # 4. Define other model kwargs
         # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
         # generating the first new token or not, and we only want to use the embeddings for the first new token)
         if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
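The left-padding warning in the context lines above is worth a concrete illustration: decoder-only models append new tokens at the right edge of the prompt, so batched prompts must be padded on the left. A sketch:

```python
from transformers import AutoTokenizer

# Left padding keeps every prompt flush against the position where new
# tokens are appended; right padding would put pad tokens there instead.
tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tok.pad_token = tok.eos_token  # gpt2 ships without a pad token
batch = tok(["a short prompt", "a somewhat longer prompt"], padding=True, return_tensors="pt")
```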
@@ -2378,7 +2381,7 @@ def generate(
             inputs_tensor, model_kwargs, model_input_name, generation_config
         )
 
-        # 6. Prepare `input_ids` which will be used for auto-regressive generation
+        # 5. Prepare `input_ids` which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
             input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
                 batch_size=batch_size,
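For encoder-decoder models, the prompt is consumed by the encoder and the `input_ids` used in the auto-regressive loop is a fresh decoder sequence seeded with `decoder_start_token_id`, which is what `_prepare_decoder_input_ids_for_generation` builds. Usage sketch (checkpoint is an example):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")  # example checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
enc = tok("translate English to German: Hello", return_tensors="pt")
# No decoder_input_ids needed: generate() seeds the decoder side itself.
out = model.generate(**enc, max_new_tokens=16)
print(tok.decode(out[0], skip_special_tokens=True))
```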
@@ -2404,7 +2407,7 @@ def generate(
         if streamer is not None:
             streamer.put(input_ids.cpu())
 
-        # 7. Prepare `max_length` depending on other stopping criteria.
+        # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_length = input_ids.shape[1]
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
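The `has_default_*` flags record whether the user relied on config defaults, because `max_new_tokens` and `max_length` interact: when `max_new_tokens` is set, the effective `max_length` becomes the prompt length plus the new-token budget. A simplified sketch of that arithmetic:

```python
# Simplified sketch of how the effective max_length is derived when the
# user passes max_new_tokens (mirrors the logic behind this step).
input_ids_length = 12                                     # prompt tokens
max_new_tokens = 20                                       # requested budget
effective_max_length = input_ids_length + max_new_tokens  # -> 32
```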
@@ -2425,7 +2428,7 @@ def generate(
 
         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
 
-        # 8. Prepare the cache.
+        # 7. Prepare the cache.
         # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
         # - different models have a different cache name expected by the model (default = "past_key_values")
         # - `max_length`, prepared above, is used to determine the maximum cache length
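From user code, the cache prepared here can also be supplied directly; recent transformers versions accept a `Cache` instance through the model's expected kwarg (default `past_key_values`). Sketch, assuming a recent version and an example checkpoint:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
enc = tok("The cache is", return_tensors="pt")
# Passing an explicit cache object; otherwise generate() creates one under
# the model's expected cache name (default: "past_key_values").
out = model.generate(**enc, past_key_values=DynamicCache(), max_new_tokens=8)
```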
@@ -2456,7 +2459,7 @@ def generate(
                 UserWarning,
             )
 
-        # 10. prepare logits processors and stopping criteria
+        # 8. prepare logits processors and stopping criteria
         prepared_logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
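User-supplied processors are merged with the ones prepared in this step via the public `logits_processor` argument. Example with a built-in processor (checkpoint is an example):

```python
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          LogitsProcessorList, MinLengthLogitsProcessor)

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
enc = tok("Logits processors", return_tensors="pt")
# This list is merged with the processors generate() builds from generation_config.
custom = LogitsProcessorList([MinLengthLogitsProcessor(20, eos_token_id=tok.eos_token_id)])
out = model.generate(**enc, logits_processor=custom, max_new_tokens=32)
```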
@@ -2476,7 +2479,7 @@ def generate(
         # Set model_kwargs `use_cache` so we can use it later in forward runs
         model_kwargs["use_cache"] = generation_config.use_cache
 
-        # 11. go into different generation modes
+        # 9. go into different generation modes
         if isinstance(custom_generate, Callable):
             result = custom_generate(
                 self,
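When `custom_generate` is a callable it replaces the whole decoding loop, receiving the model as first argument plus the prepared inputs. Only `self` is visible at the call site in this hunk; the remaining parameters in the sketch below are inferred and should be treated as version-dependent:

```python
# Hedged sketch of a user-supplied decoding loop. Beyond `model` and
# `input_ids`, the parameter set is an inference, not confirmed by this diff.
def my_custom_generate(model, input_ids, logits_processor=None, stopping_criteria=None,
                       generation_config=None, synced_gpus=False, streamer=None, **model_kwargs):
    return input_ids  # trivial strategy: echo the prompt back

# out = model.generate(**enc, custom_generate=my_custom_generate)
```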
@@ -2507,7 +2510,7 @@ def generate(
                     f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
                 )
 
-            # 11. Get the candidate generator, given the parameterization
+            # 10. Get the candidate generator, given the parameterization
             gen_mode_kwargs["target_tokenizer"] = gen_mode_kwargs.pop("tokenizer", None)
             candidate_generator = self._get_candidate_generator(
                 generation_config=generation_config,
@@ -2518,7 +2521,7 @@ def generate(
                 **gen_mode_kwargs,
             )
 
-            # 13. run assisted generate
+            # 11. run assisted generate
             result = self._assisted_decoding(
                 input_ids,
                 candidate_generator=candidate_generator,
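From user code, assisted (speculative) decoding is triggered by passing a smaller draft model as `assistant_model`; the candidate generator built in step 10 wraps it. Usage sketch (checkpoints are examples):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2-large")          # example target
model = AutoModelForCausalLM.from_pretrained("gpt2-large")
assistant = AutoModelForCausalLM.from_pretrained("gpt2")   # smaller draft model
enc = tok("Speculative decoding", return_tensors="pt")
# The assistant drafts candidate tokens; the target model verifies them.
out = model.generate(**enc, assistant_model=assistant, max_new_tokens=32)
```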
@@ -2531,7 +2534,7 @@ def generate(
             )
 
         elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
+            # 10. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
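The renumbered comment makes the key point: `_sample` serves both modes, collapsing to argmax when `do_sample=False`. Contrast from user code (example checkpoint):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
enc = tok("Two decoding modes:", return_tensors="pt")
greedy = model.generate(**enc, do_sample=False, max_new_tokens=16)  # argmax at each step
sampled = model.generate(**enc, do_sample=True, temperature=0.8, top_p=0.9, max_new_tokens=16)
```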
@@ -2543,7 +2546,7 @@ def generate(
             )
 
         elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
-            # 12. run beam sample
+            # 10. run beam sample
             result = self._beam_search(
                 input_ids,
                 logits_processor=prepared_logits_processor,
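Likewise, `_beam_search` covers both beam search and beam sampling; the mode is chosen from `num_beams` together with `do_sample`. Usage sketch (example checkpoint):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
enc = tok("Beam search keeps", return_tensors="pt")
beam = model.generate(**enc, num_beams=4, do_sample=False, max_new_tokens=16)        # beam search
beam_sample = model.generate(**enc, num_beams=4, do_sample=True, max_new_tokens=16)  # beam sampling
```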