
Commit acd8205

Align assisted generate for unified signature in decoding methods (#40657)
* Squashed previous branch
* unify assisted generate to common decoding method signature
* move checks to validate steps where possible
* fix csm and other models that override _sample
* ops dia you again
* opsie
* joao review
1 parent 16b821c commit acd8205
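Note on the change: every built-in decoding method (and any `custom_generate` callable) is now expected to accept the same leading arguments, so `generate()` can pick one method per generation mode and call it through a single code path. A rough sketch of that shared calling convention, with simplified annotations rather than the exact transformers type hints:

```python
# Sketch of the calling convention generate() now relies on for every decoding method
# and for custom_generate callables. Simplified; not the exact transformers annotations.
from typing import Any, Protocol

import torch


class DecodingMethod(Protocol):
    def __call__(
        self,
        model: Any,                   # the GenerationMixin instance, passed explicitly (unbound lookup)
        input_ids: torch.LongTensor,
        logits_processor: Any,        # LogitsProcessorList
        stopping_criteria: Any,       # StoppingCriteriaList
        generation_config: Any,       # GenerationConfig
        **kwargs: Any,                # generation_mode_kwargs (streamer, assistant_model, ...) plus model_kwargs
    ) -> Any: ...                     # a GenerateOutput subclass or a torch.LongTensor
```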

File tree

2 files changed: +90 -95 lines changed

src/transformers/generation/utils.py
src/transformers/models/dia/generation_dia.py


src/transformers/generation/utils.py

Lines changed: 89 additions & 94 deletions
```diff
@@ -129,6 +129,19 @@
     "past_buckets_states",  # reformer
 ]

+GENERATION_MODES_MAPPING = {
+    GenerationMode.SAMPLE: "_sample",
+    GenerationMode.GREEDY_SEARCH: "_sample",
+    GenerationMode.BEAM_SEARCH: "_beam_search",
+    GenerationMode.BEAM_SAMPLE: "_beam_search",
+    GenerationMode.ASSISTED_GENERATION: "_assisted_decoding",
+    # Deprecated methods
+    GenerationMode.DOLA_GENERATION: "transformers-community/dola",
+    GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search",
+    GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search",
+    GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search",
+}
+

 @dataclass
 class GenerateDecoderOnlyOutput(ModelOutput):
```
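The new module-level mapping doubles as a routing table: values without a slash are method names on `GenerationMixin`, while values containing `/` are Hub repo ids for deprecated modes (consumed further down by `_get_deprecated_gen_repo`). A toy sketch of that convention, with stand-in keys rather than the real enum:

```python
# Toy illustration of the "/" convention used by GENERATION_MODES_MAPPING (keys here are stand-ins).
TOY_MAPPING = {
    "sample": "_sample",                        # in-library method name
    "dola": "transformers-community/dola",      # deprecated mode, moved to a Hub repo
}


def resolve(mode: str):
    target = TOY_MAPPING[mode]
    if "/" in target:
        # Deprecated strategy: the caller should re-dispatch through custom_generate=<repo>
        return ("hub_repo", target)
    return ("method_name", target)


assert resolve("sample") == ("method_name", "_sample")
assert resolve("dola") == ("hub_repo", "transformers-community/dola")
```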
```diff
@@ -1492,12 +1505,25 @@ def compute_transition_scores(

         return transition_scores

-    def _validate_generation_mode(self, generation_mode, generation_mode_kwargs):
+    def _validate_generation_mode(self, generation_mode, generation_config, generation_mode_kwargs):
         if generation_mode == GenerationMode.BEAM_SEARCH and "streamer" in generation_mode_kwargs:
             raise ValueError(
                 "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
             )

+        if generation_mode == GenerationMode.ASSISTED_GENERATION:
+            if generation_config.num_return_sequences > 1:
+                raise ValueError(
+                    "num_return_sequences has to be 1 when doing assisted generate, "
+                    f"but is {generation_config.num_return_sequences}."
+                )
+            if self._is_stateful:
+                # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
+                # which is not possible with stateful models (they can't reset to a previous subset of generated text)
+                raise ValueError(
+                    f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
+                )
+
         if (assistant_model := generation_mode_kwargs.get("assistant_model")) is not None:
             if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
                 attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
```
```diff
@@ -2136,16 +2162,9 @@ def _get_deprecated_gen_repo(
         """
         Returns the Hub repo for a deprecated generation mode, if any.
         """
-        moved_to_hub_modes = {
-            GenerationMode.DOLA_GENERATION: "transformers-community/dola",
-            GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search",
-            GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search",
-            GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search",
-        }
-        if custom_generate is not None or generation_mode not in moved_to_hub_modes:
+        if custom_generate is not None or "/" not in (repo := GENERATION_MODES_MAPPING[generation_mode]):
             return None

-        repo = moved_to_hub_modes[generation_mode]
         logger.warning_once(
             f"{generation_mode.name.replace('_', ' ').title()} was moved to a `custom_generate` repo: https://hf.co/{repo}. "
             f"To prevent loss of backward compatibility, add `custom_generate='{repo}'` "
```
```diff
@@ -2175,10 +2194,11 @@ def _extract_generation_mode_kwargs(
             "assistant_model": assistant_model,
             "streamer": streamer,
         }
-        if synced_gpus is not None:
-            generation_mode_kwargs["synced_gpus"] = (
-                is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
-            ) and dist.get_world_size() > 1
+        generation_mode_kwargs["synced_gpus"] = (
+            (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
+            if synced_gpus is None
+            else synced_gpus
+        )
         generation_mode_kwargs = {k: v for k, v in generation_mode_kwargs.items() if v is not None}
         # Custom_generate callables can have their own set of arguments
         # To extract them, we compare the signature with the standard _sample method
```
```diff
@@ -2338,15 +2358,20 @@ def generate(
             generation_config, use_model_defaults, **kwargs
         )
         generation_mode = generation_config.get_generation_mode(assistant_model)
+        if isinstance(custom_generate, Callable):
+            decoding_method = custom_generate
+        else:
+            # type() required to access the unbound class-level method
+            decoding_method = getattr(type(self), GENERATION_MODES_MAPPING[generation_mode])

         self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_generation_mode(generation_mode, generation_mode_kwargs)
+        self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)

         # Deprecation-related step: set Hub repo for deprecated strategies.
         # NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode.
         # It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps.
         # TODO joao, manuel: remove this in v4.62.0
-        if deprecate_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate):
+        if deprecated_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate):
             return GenerationMixin.generate(
                 self,
                 inputs=inputs,
```
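The `type(self)` lookup added above is worth a note: fetching the method from the class yields a plain (unbound) function, so built-in methods and user-supplied `custom_generate` callables can share one call shape, with the model passed explicitly as the first argument. A small self-contained illustration using a toy class, not transformers code:

```python
class Toy:
    def _sample(self, x):
        return x + 1


toy = Toy()

bound = getattr(toy, "_sample")          # bound method: called as bound(x)
unbound = getattr(type(toy), "_sample")  # plain function: called as unbound(toy, x)


def custom_generate(model, x):           # external callables already take the model explicitly
    return x * 2


# The unbound lookup lets both share one call site: method(model, args...)
assert unbound(toy, 1) == 2
assert custom_generate(toy, 1) == 2
assert bound(1) == 2
```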
```diff
@@ -2358,7 +2383,7 @@ def generate(
                 negative_prompt_ids=negative_prompt_ids,
                 negative_prompt_attention_mask=negative_prompt_attention_mask,
                 use_model_defaults=use_model_defaults,
-                custom_generate=deprecate_mode_repo,
+                custom_generate=deprecated_mode_repo,
                 trust_remote_code=trust_remote_code,
                 **generation_mode_kwargs,
                 **kwargs,
```
```diff
@@ -2376,6 +2401,9 @@ def generate(
         inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
             inputs, generation_config.bos_token_id, model_kwargs
         )
+        # Some generation modes (e.g. assisted) need `inputs_tensor` to rerun encoder.forward()
+        if "inputs_tensor" in inspect.signature(decoding_method).parameters.keys():
+            generation_mode_kwargs["inputs_tensor"] = inputs_tensor
         batch_size = inputs_tensor.shape[0]

         device = inputs_tensor.device
```
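The signature inspection above is what lets `generate()` forward `inputs_tensor` only to decoding methods that explicitly declare it (currently `_assisted_decoding`). The same pattern in a standalone, hedged sketch:

```python
import inspect


def _assisted_like(input_ids, inputs_tensor=None, **model_kwargs):
    ...


def _sample_like(input_ids, **model_kwargs):
    ...


def wants_inputs_tensor(fn) -> bool:
    # Mirrors the check in generate(): only pass the kwarg if the method's signature names it.
    return "inputs_tensor" in inspect.signature(fn).parameters


assert wants_inputs_tensor(_assisted_like) is True
assert wants_inputs_tensor(_sample_like) is False  # a bare **kwargs does not count
```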
```diff
@@ -2511,80 +2539,16 @@ def generate(
         # Set model_kwargs `use_cache` so we can use it later in forward runs
         model_kwargs["use_cache"] = generation_config.use_cache

-        # 9. go into different generation modes
-        if isinstance(custom_generate, Callable):
-            result = custom_generate(
-                self,
-                input_ids,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                **generation_mode_kwargs,
-                **model_kwargs,
-            )
-        elif generation_mode == GenerationMode.ASSISTED_GENERATION:
-            if generation_config.num_return_sequences > 1:
-                raise ValueError(
-                    "num_return_sequences has to be 1 when doing assisted generate, "
-                    f"but is {generation_config.num_return_sequences}."
-                )
-            if batch_size > 1:
-                raise ValueError("assisted generate is only supported for batch_size = 1")
-            if not model_kwargs["use_cache"]:
-                raise ValueError("assisted generate requires `use_cache=True`")
-            if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
-                raise ValueError("assisted generate is not supported with Static cache classes`")
-            if self._is_stateful:
-                # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
-                # which is not possible with stateful models (they can't reset to a previous subset of generated text)
-                raise ValueError(
-                    f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
-                )
-
-            # 10. Get the candidate generator, given the parameterization
-            candidate_generator = self._get_candidate_generator(
-                generation_config=generation_config,
-                input_ids=input_ids,
-                inputs_tensor=inputs_tensor,
-                assistant_model=generation_mode_kwargs.pop("assistant_model", None),
-                logits_processor=logits_processor,
-                target_tokenizer=generation_mode_kwargs.pop("tokenizer", None),
-                assistant_tokenizer=generation_mode_kwargs.pop("assistant_tokenizer", None),
-                model_kwargs=model_kwargs,
-            )
-
-            # 11. run assisted generate
-            result = self._assisted_decoding(
-                input_ids,
-                candidate_generator=candidate_generator,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                **generation_mode_kwargs,
-                **model_kwargs,
-            )
-
-        elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 10. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-            result = self._sample(
-                input_ids,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                **generation_mode_kwargs,
-                **model_kwargs,
-            )
-
-        elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
-            # 10. run beam sample
-            result = self._beam_search(
-                input_ids,
-                logits_processor=prepared_logits_processor,
-                stopping_criteria=prepared_stopping_criteria,
-                generation_config=generation_config,
-                **generation_mode_kwargs,
-                **model_kwargs,
-            )
+        # 9. Call generation mode
+        result = decoding_method(
+            self,
+            input_ids,
+            logits_processor=prepared_logits_processor,
+            stopping_criteria=prepared_stopping_criteria,
+            generation_config=generation_config,
+            **generation_mode_kwargs,
+            **model_kwargs,
+        )

         # Convert to legacy cache format if requested
         if (
```
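Because the unified call above passes every decoding method the same positional and keyword arguments, a `custom_generate` callable has to accept that shape as well (any extra arguments it declares are extracted earlier by comparing its signature against `_sample`). A minimal sketch of a compatible callable; the body is a placeholder and the final `generate` call is illustrative only:

```python
import torch


def my_custom_decoding(
    model,
    input_ids: torch.LongTensor,
    logits_processor=None,
    stopping_criteria=None,
    generation_config=None,
    **kwargs,  # receives generation_mode_kwargs (e.g. streamer, synced_gpus) and model_kwargs
):
    # Placeholder body: a real implementation would run a decoding loop here.
    return input_ids


# Hypothetical usage (model and inputs are assumed to exist):
# output = model.generate(**inputs, custom_generate=my_custom_decoding)
```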
```diff
@@ -3466,12 +3430,15 @@ def _beam_search(
     def _assisted_decoding(
         self,
         input_ids: torch.LongTensor,
-        candidate_generator: CandidateGenerator,
         logits_processor: LogitsProcessorList,
         stopping_criteria: StoppingCriteriaList,
         generation_config: GenerationConfig,
         synced_gpus: bool = False,
         streamer: Optional["BaseStreamer"] = None,
+        inputs_tensor: torch.FloatTensor = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+        tokenizer: Optional["PreTrainedTokenizerBase"] = None,
         **model_kwargs,
     ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
         r"""
```
```diff
@@ -3483,9 +3450,6 @@ def _assisted_decoding(
         Parameters:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                 The sequence used as a prompt for the generation.
-            candidate_generator (`CandidateGenerator`):
-                A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
-                more information, the documentation of [`CandidateGenerator`] should be read.
             logits_processor (`LogitsProcessorList`):
                 An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                 used to modify the prediction scores of the language modeling head applied at each generation step.
```
```diff
@@ -3500,6 +3464,15 @@ def _assisted_decoding(
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            inputs_tensor (`torch.FloatTensor`, *optional*):
+                The input tensor for generation. For decoder models, usually `input_ids`. For encoder-decoder models,
+                the tensor that produced `model_kwargs["encoder_outputs"]`.
+            assistant_model (`PreTrainedModel`, *optional*):
+                The model used to assist the generation process. If not provided, the main model will be used.
+            assistant_tokenizer (`PreTrainedTokenizerBase`, *optional*):
+                The tokenizer used for the assistant model. If not provided, the token space is assumed to be the same.
+            tokenizer (`PreTrainedTokenizerBase`, *optional*):
+                The tokenizer used for the main model. If not provided, the token space is assumed to be the same.
             model_kwargs:
                 Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                 If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
```
```diff
@@ -3511,6 +3484,26 @@ def _assisted_decoding(
             `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
             `model.config.is_encoder_decoder=True`.
         """
+        # The cache must be dynamic for assisted generation, and the check must happen AFTER preparing cache
+        if not model_kwargs["use_cache"]:
+            raise ValueError("assisted generate requires `use_cache=True`")
+        if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] or (
+            "past_key_values" in model_kwargs
+            and hasattr(model_kwargs["past_key_values"], "layers")
+            and any(getattr(l, "is_compileable", False) for l in model_kwargs["past_key_values"].layers)
+        ):
+            raise ValueError("assisted generate is not supported with Static cache classes`")
+        # Get the candidate generator, given the parameterization
+        candidate_generator = self._get_candidate_generator(
+            generation_config=generation_config,
+            input_ids=input_ids,
+            inputs_tensor=inputs_tensor,
+            assistant_model=assistant_model,
+            logits_processor=logits_processor,
+            target_tokenizer=tokenizer,
+            assistant_tokenizer=assistant_tokenizer,
+            model_kwargs=model_kwargs,
+        )
         # init values
         do_sample = generation_config.do_sample
         output_attentions = generation_config.output_attentions
```
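The widened cache check above rejects not only a static `cache_implementation` string but also an already-prepared `past_key_values` whose layers advertise `is_compileable`, since assisted decoding must be able to roll the cache back to a previous subset of generated text after rejected candidate tokens. A toy sketch of that layer scan, with no transformers classes involved:

```python
from types import SimpleNamespace


def cache_is_compileable(past_key_values) -> bool:
    # Mirrors the new layer scan: any compile-friendly layer disqualifies the cache for assisted decoding.
    layers = getattr(past_key_values, "layers", None)
    if layers is None:
        return False
    return any(getattr(layer, "is_compileable", False) for layer in layers)


dynamic_cache = SimpleNamespace(layers=[SimpleNamespace(is_compileable=False)])
static_cache = SimpleNamespace(layers=[SimpleNamespace(is_compileable=True)])

assert cache_is_compileable(dynamic_cache) is False
assert cache_is_compileable(static_cache) is True
```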
```diff
@@ -3535,6 +3528,8 @@ def _assisted_decoding(

         # keep track of which sequences are already finished
         batch_size, cur_len = input_ids.shape[:2]
+        if batch_size > 1:
+            raise ValueError("assisted generate is only supported for batch_size = 1")
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

```
src/transformers/models/dia/generation_dia.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -278,7 +278,7 @@ def _main_generate_loop(
         generation_mode = generation_config.get_generation_mode(assistant_model)

         self._validate_model_kwargs(model_kwargs.copy())
-        self._validate_generation_mode(generation_mode, generation_mode_kwargs)
+        self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)

         # 2. Set generation parameters if not already defined
         if synced_gpus is None:
```
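Taken together, the relocated checks leave the user-facing requirements for assisted generation unchanged: batch size 1, `use_cache=True`, `num_return_sequences=1`, and a dynamic (non-static) cache; they are simply raised from `_validate_generation_mode` and `_assisted_decoding` instead of from `generate()` itself. A hedged usage sketch; the checkpoint names are placeholders, and a target/assistant pair with different tokenizers would additionally pass the new `tokenizer`/`assistant_tokenizer` arguments:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoints: any target/assistant pair sharing a tokenizer works the same way.
tokenizer = AutoTokenizer.from_pretrained("big-target-model")
model = AutoModelForCausalLM.from_pretrained("big-target-model")
assistant = AutoModelForCausalLM.from_pretrained("small-draft-model")

inputs = tokenizer("The quick brown fox", return_tensors="pt")  # single sequence: batch_size must be 1

output = model.generate(
    **inputs,
    assistant_model=assistant,   # selects GenerationMode.ASSISTED_GENERATION
    max_new_tokens=32,
    use_cache=True,              # required; a static cache_implementation would raise
    num_return_sequences=1,      # >1 raises in _validate_generation_mode
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```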
