
Commit 4655dfa

wrong merge
1 parent 4647814 commit 4655dfa

1 file changed: 16 additions, 13 deletions

src/transformers/generation/utils.py

Lines changed: 16 additions & 13 deletions
```diff
@@ -2304,7 +2304,10 @@ def generate(

         generation_mode = generation_config.get_generation_mode(**gen_mode_kwargs)

-        # 2. Set Hub repo for deprecated strategies. (TODO joao, manuel: remove this in v4.62.0)
+        # Deprecation-related step: set Hub repo for deprecated strategies.
+        # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode.
+        # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps.
+        # TODO joao, manuel: remove this in v4.62.0
         if deprecate_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate):
             return GenerationMixin.generate(
                 self,
```
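For context on the branch this hunk touches: when a deprecated decoding strategy is detected, `generate` re-dispatches through `GenerationMixin.generate` with the strategy's Hub repo as `custom_generate`. A minimal sketch of the equivalent user-facing call, assuming a Hub repo id of `transformers-community/dola` (illustrative; substitute the repo that hosts the strategy you need):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("The capital of France is", return_tensors="pt")

out = model.generate(
    **inputs,
    custom_generate="transformers-community/dola",  # illustrative Hub repo id
    trust_remote_code=True,  # the decoding loop is downloaded from the Hub
)
print(tok.decode(out[0], skip_special_tokens=True))
```

The later sketches reuse `tok`, `model`, and `inputs` from this one.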
```diff
@@ -2322,7 +2325,7 @@ def generate(
                 **gen_mode_kwargs,
             )

-        # 3. Set generation parameters if not already defined
+        # 2. Set generation parameters if not already defined
         if synced_gpus is None:
             synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1

```
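The `synced_gpus` default matters because, under DeepSpeed ZeRO-3 or FSDP, every rank participates in each forward pass's collectives, so a rank that finishes generating early would hang the others. The flag can also be forced explicitly; a one-line sketch reusing `model` and `inputs` from above:

```python
# Keep all ranks stepping through the decoding loop until everyone is done,
# as required when model weights are sharded across GPUs.
out = model.generate(**inputs, max_new_tokens=32, synced_gpus=True)
```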

```diff
@@ -2333,7 +2336,7 @@
         requires_attention_mask = "encoder_outputs" not in model_kwargs
         kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

-        # 4. Define model inputs
+        # 3. Define model inputs
         inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
             inputs, generation_config.bos_token_id, model_kwargs
         )
```
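Step 3 normalizes whichever input the caller provided into a single tensor plus its name. A sketch of two entry points it handles, reusing `model` and `inputs` from the first example (passing `inputs_embeds` is the supported way to start decoder-only generation from embeddings):

```python
# Standard path: token ids in, full sequence out.
out_ids = model.generate(input_ids=inputs["input_ids"], max_new_tokens=8)

# Embedding path: generate() then returns only the newly generated tokens,
# since there are no prompt ids to prepend.
embeds = model.get_input_embeddings()(inputs["input_ids"])
out_new = model.generate(
    inputs_embeds=embeds,
    attention_mask=inputs["attention_mask"],
    max_new_tokens=8,
)
```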
```diff
@@ -2357,7 +2360,7 @@
                 "generation results, please set `padding_side='left'` when initializing the tokenizer."
             )

-        # 5. Define other model kwargs
+        # 4. Define other model kwargs
         # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
         # generating the first new token or not, and we only want to use the embeddings for the first new token)
         if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
```
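The warning quoted in this hunk's context fires for batched decoder-only generation with right padding, where pad tokens would sit between the prompt and the first generated token. The fix it suggests, sketched with the objects from the first example:

```python
tok.padding_side = "left"
tok.pad_token = tok.eos_token  # gpt2 ships without a pad token
batch = tok(["Hello", "A much longer prompt"], return_tensors="pt", padding=True)
out = model.generate(**batch, max_new_tokens=8)
```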
```diff
@@ -2378,7 +2381,7 @@
             inputs_tensor, model_kwargs, model_input_name, generation_config
         )

-        # 6. Prepare `input_ids` which will be used for auto-regressive generation
+        # 5. Prepare `input_ids` which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
             input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
                 batch_size=batch_size,
```
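For encoder-decoder models, step 5 builds `decoder_input_ids` from `decoder_start_token_id` rather than from the prompt, which goes through the encoder instead. A self-contained sketch with `t5-small`:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

t5_tok = AutoTokenizer.from_pretrained("t5-small")
t5 = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

enc = t5_tok("translate English to German: Hello", return_tensors="pt")
out = t5.generate(**enc, max_new_tokens=16)  # decoder starts from its start token
print(t5_tok.decode(out[0], skip_special_tokens=True))
```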
```diff
@@ -2404,7 +2407,7 @@
         if streamer is not None:
             streamer.put(input_ids.cpu())

-        # 7. Prepare `max_length` depending on other stopping criteria.
+        # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_length = input_ids.shape[1]
         has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
         has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
```
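The `streamer.put(input_ids.cpu())` context line is where the prompt is echoed to a streamer before decoding begins. Minimal streaming usage with the built-in `TextStreamer`, reusing `tok`, `model`, and `inputs`:

```python
from transformers import TextStreamer

streamer = TextStreamer(tok)  # prints decoded tokens to stdout as they arrive
model.generate(**inputs, streamer=streamer, max_new_tokens=16)
```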
```diff
@@ -2425,7 +2428,7 @@

         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

-        # 8. Prepare the cache.
+        # 7. Prepare the cache.
         # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
         # - different models have a different cache name expected by the model (default = "past_key_values")
         # - `max_length`, prepared above, is used to determine the maximum cache length
```
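Step 7's cache resolution is driven by `generation_config`; the user-facing knob is `cache_implementation`, with the growable dynamic cache as the default. A one-line sketch reusing `model` and `inputs` (a static cache pre-allocates up to the prepared `max_length`, which is why step 6 must run first):

```python
out = model.generate(**inputs, max_new_tokens=16, cache_implementation="static")
```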
```diff
@@ -2456,7 +2459,7 @@
                 UserWarning,
             )

-        # 10. prepare logits processors and stopping criteria
+        # 8. prepare logits processors and stopping criteria
         prepared_logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
```
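Step 8 merges the config-derived processors with any user-supplied ones. A sketch of supplying a custom list through the public `logits_processor` argument, reusing `model`, `tok`, and `inputs`:

```python
from transformers import LogitsProcessorList, MinLengthLogitsProcessor

# Suppress EOS until at least 10 tokens exist, on top of the config's processors.
processors = LogitsProcessorList([MinLengthLogitsProcessor(10, tok.eos_token_id)])
out = model.generate(**inputs, logits_processor=processors, max_new_tokens=16)
```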
```diff
@@ -2476,7 +2479,7 @@
         # Set model_kwargs `use_cache` so we can use it later in forward runs
         model_kwargs["use_cache"] = generation_config.use_cache

-        # 11. go into different generation modes
+        # 9. go into different generation modes
         if isinstance(custom_generate, Callable):
             result = custom_generate(
                 self,
```
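As the `isinstance(custom_generate, Callable)` branch shows, `custom_generate` may also be a plain callable that receives the model as its first argument plus the prepared generation kwargs. A hedged sketch of such a loop; the exact keyword set forwarded at this call site is version-dependent, so the signature below (absorbing extras via `**kwargs`) is an assumption:

```python
import torch

def my_greedy_generate(model, input_ids=None, generation_config=None, **kwargs):
    # Assumption: generate() forwards the prepared input_ids and the generation
    # config; fall back to a fixed token budget if max_new_tokens is unset.
    budget = getattr(generation_config, "max_new_tokens", None) or 8
    for _ in range(budget):
        logits = model(input_ids).logits[:, -1, :]       # next-token logits
        next_token = logits.argmax(dim=-1, keepdim=True) # greedy pick
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids

out = model.generate(**inputs, custom_generate=my_greedy_generate)
```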
```diff
@@ -2507,7 +2510,7 @@
                     f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
                 )

-            # 11. Get the candidate generator, given the parameterization
+            # 10. Get the candidate generator, given the parameterization
             gen_mode_kwargs["target_tokenizer"] = gen_mode_kwargs.pop("tokenizer", None)
             candidate_generator = self._get_candidate_generator(
                 generation_config=generation_config,
```
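The candidate generator built in step 10 wraps whatever proposes draft tokens for assisted generation, most commonly a smaller model from the same tokenizer family. Sketch reusing `model` and `inputs`:

```python
from transformers import AutoModelForCausalLM

assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")  # small draft model
out = model.generate(**inputs, assistant_model=assistant, max_new_tokens=16)
```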
```diff
@@ -2518,7 +2521,7 @@
                 **gen_mode_kwargs,
             )

-            # 13. run assisted generate
+            # 11. run assisted generate
             result = self._assisted_decoding(
                 input_ids,
                 candidate_generator=candidate_generator,
```
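The same `_assisted_decoding` path also serves prompt lookup decoding, where candidate tokens come from n-gram matches against the prompt instead of a draft model:

```python
out = model.generate(**inputs, prompt_lookup_num_tokens=3, max_new_tokens=16)
```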
```diff
@@ -2531,7 +2534,7 @@
             )

         elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
+            # 10. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
```
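The renumbered comment is worth keeping in mind: `_sample` covers greedy search too, and the difference reduces to how one token is picked per step. A standalone torch sketch of that one line:

```python
import torch

logits = torch.randn(1, 50257)                     # one step of next-token logits
greedy = logits.argmax(dim=-1)                     # do_sample=False: take the mode
probs = torch.softmax(logits / 0.7, dim=-1)        # temperature-scaled distribution
sampled = torch.multinomial(probs, num_samples=1)  # do_sample=True: draw from it
```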
```diff
@@ -2543,7 +2546,7 @@
             )

         elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
-            # 12. run beam sample
+            # 10. run beam sample
             result = self._beam_search(
                 input_ids,
                 logits_processor=prepared_logits_processor,
```
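Likewise, beam search and beam sample share `_beam_search`, with `do_sample` toggling how continuations are drawn inside each beam. Usage sketch reusing `model` and `inputs`:

```python
out_beam = model.generate(**inputs, num_beams=4, do_sample=False, max_new_tokens=16)
out_beam_sample = model.generate(**inputs, num_beams=4, do_sample=True, max_new_tokens=16)
```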

0 commit comments