183 changes: 89 additions & 94 deletions src/transformers/generation/utils.py
@@ -129,6 +129,19 @@
"past_buckets_states", # reformer
]

GENERATION_MODES_MAPPING = {
GenerationMode.SAMPLE: "_sample",
GenerationMode.GREEDY_SEARCH: "_sample",
GenerationMode.BEAM_SEARCH: "_beam_search",
GenerationMode.BEAM_SAMPLE: "_beam_search",
GenerationMode.ASSISTED_GENERATION: "_assisted_decoding",
# Deprecated methods
GenerationMode.DOLA_GENERATION: "transformers-community/dola",
GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search",
GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search",
GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search",
}
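The mapping doubles as a dispatch table: values without a "/" are method names on `GenerationMixin`, while values containing a "/" point to the Hub repos that now host the deprecated modes. A minimal sketch (not part of the diff) of how a value could be told apart, assuming the two imports below resolve as named:

# Sketch only: distinguishing built-in modes from deprecated, Hub-hosted ones.
from transformers.generation.configuration_utils import GenerationMode
from transformers.generation.utils import GENERATION_MODES_MAPPING

for mode in (GenerationMode.BEAM_SEARCH, GenerationMode.DOLA_GENERATION):
    target = GENERATION_MODES_MAPPING[mode]
    if "/" in target:
        print(f"{mode}: deprecated, served from the Hub repo {target}")
    else:
        print(f"{mode}: handled by GenerationMixin.{target}")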


@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
@@ -1492,12 +1505,25 @@ def compute_transition_scores(

return transition_scores

def _validate_generation_mode(self, generation_mode, generation_mode_kwargs):
def _validate_generation_mode(self, generation_mode, generation_config, generation_mode_kwargs):
if generation_mode == GenerationMode.BEAM_SEARCH and "streamer" in generation_mode_kwargs:
raise ValueError(
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
)

if generation_mode == GenerationMode.ASSISTED_GENERATION:
if generation_config.num_return_sequences > 1:
raise ValueError(
"num_return_sequences has to be 1 when doing assisted generate, "
f"but is {generation_config.num_return_sequences}."
)
if self._is_stateful:
# In assisted generation we need the ability to confirm whether the model would pick certain tokens,
# which is not possible with stateful models (they can't reset to a previous subset of generated text)
raise ValueError(
f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
)

if (assistant_model := generation_mode_kwargs.get("assistant_model")) is not None:
if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
@@ -2136,16 +2162,9 @@ def _get_deprecated_gen_repo(
"""
Returns the Hub repo for a deprecated generation mode, if any.
"""
moved_to_hub_modes = {
GenerationMode.DOLA_GENERATION: "transformers-community/dola",
GenerationMode.CONTRASTIVE_SEARCH: "transformers-community/contrastive-search",
GenerationMode.GROUP_BEAM_SEARCH: "transformers-community/group-beam-search",
GenerationMode.CONSTRAINED_BEAM_SEARCH: "transformers-community/constrained-beam-search",
}
if custom_generate is not None or generation_mode not in moved_to_hub_modes:
if custom_generate is not None or "/" not in (repo := GENERATION_MODES_MAPPING[generation_mode]):
return None

repo = moved_to_hub_modes[generation_mode]
logger.warning_once(
f"{generation_mode.name.replace('_', ' ').title()} was moved to a `custom_generate` repo: https://hf.co/{repo}. "
f"To prevent loss of backward compatibility, add `custom_generate='{repo}'` "
@@ -2175,10 +2194,11 @@ def _extract_generation_mode_kwargs(
"assistant_model": assistant_model,
"streamer": streamer,
}
if synced_gpus is not None:
generation_mode_kwargs["synced_gpus"] = (
is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
) and dist.get_world_size() > 1
generation_mode_kwargs["synced_gpus"] = (
(is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
if synced_gpus is None
else synced_gpus
)
generation_mode_kwargs = {k: v for k, v in generation_mode_kwargs.items() if v is not None}
# `custom_generate` callables can have their own set of arguments
# To extract them, we compare their signature with that of the standard `_sample` method
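The two comments above describe extraction by signature comparison. A standalone sketch of the idea; the function names are hypothetical and this is not the actual implementation:

# Any parameter a custom decoding callable declares beyond the standard _sample
# signature is treated as one of its own generation-mode kwargs.
import inspect

def _sample_like(self, input_ids, logits_processor, stopping_criteria, generation_config, **model_kwargs):
    ...

def custom_decoding(self, input_ids, logits_processor, stopping_criteria, generation_config, draft_temperature=1.0, **model_kwargs):
    ...

extra = set(inspect.signature(custom_decoding).parameters) - set(inspect.signature(_sample_like).parameters)
print(extra)  # {'draft_temperature'}, pulled from the user kwargs and forwarded to the callable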
@@ -2338,15 +2358,20 @@ def generate(
generation_config, use_model_defaults, **kwargs
)
generation_mode = generation_config.get_generation_mode(assistant_model)
if isinstance(custom_generate, Callable):
decoding_method = custom_generate
else:
# type() required to access the unbound class-level method
decoding_method = getattr(type(self), GENERATION_MODES_MAPPING[generation_mode])

self._validate_model_kwargs(model_kwargs.copy())
self._validate_generation_mode(generation_mode, generation_mode_kwargs)
self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)

# Deprecation-related step: set Hub repo for deprecated strategies.
# NOTE: This must come after initializing generation_config, since we need it to determine if this is a deprecated mode.
# It must also be before any preparation steps, since Hub repos expect to be loaded before preparation steps.
# TODO joao, manuel: remove this in v4.62.0
if deprecate_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate):
if deprecated_mode_repo := self._get_deprecated_gen_repo(generation_mode, trust_remote_code, custom_generate):
return GenerationMixin.generate(
self,
inputs=inputs,
@@ -2358,7 +2383,7 @@ def generate(
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
use_model_defaults=use_model_defaults,
custom_generate=deprecate_mode_repo,
custom_generate=deprecated_mode_repo,
trust_remote_code=trust_remote_code,
**generation_mode_kwargs,
**kwargs,
@@ -2376,6 +2401,9 @@ def generate(
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
# Some generation modes (e.g. assisted) need `inputs_tensor` to rerun encoder.forward()
if "inputs_tensor" in inspect.signature(decoding_method).parameters.keys():
generation_mode_kwargs["inputs_tensor"] = inputs_tensor
batch_size = inputs_tensor.shape[0]

device = inputs_tensor.device
@@ -2511,80 +2539,16 @@ def generate(
# Set model_kwargs `use_cache` so we can use it later in forward runs
model_kwargs["use_cache"] = generation_config.use_cache

# 9. go into different generation modes
if isinstance(custom_generate, Callable):
result = custom_generate(
self,
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
**generation_mode_kwargs,
**model_kwargs,
)
elif generation_mode == GenerationMode.ASSISTED_GENERATION:
if generation_config.num_return_sequences > 1:
raise ValueError(
"num_return_sequences has to be 1 when doing assisted generate, "
f"but is {generation_config.num_return_sequences}."
)
if batch_size > 1:
raise ValueError("assisted generate is only supported for batch_size = 1")
if not model_kwargs["use_cache"]:
raise ValueError("assisted generate requires `use_cache=True`")
if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
raise ValueError("assisted generate is not supported with Static cache classes`")
if self._is_stateful:
# In assisted generation we need the ability to confirm whether the model would pick certain tokens,
# which is not possible with stateful models (they can't reset to a previous subset of generated text)
raise ValueError(
f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
)

# 10. Get the candidate generator, given the parameterization
candidate_generator = self._get_candidate_generator(
generation_config=generation_config,
input_ids=input_ids,
inputs_tensor=inputs_tensor,
assistant_model=generation_mode_kwargs.pop("assistant_model", None),
logits_processor=logits_processor,
target_tokenizer=generation_mode_kwargs.pop("tokenizer", None),
assistant_tokenizer=generation_mode_kwargs.pop("assistant_tokenizer", None),
model_kwargs=model_kwargs,
)

# 11. run assisted generate
result = self._assisted_decoding(
input_ids,
candidate_generator=candidate_generator,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
**generation_mode_kwargs,
**model_kwargs,
)

elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 10. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
result = self._sample(
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
**generation_mode_kwargs,
**model_kwargs,
)

elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
# 10. run beam sample
result = self._beam_search(
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
**generation_mode_kwargs,
**model_kwargs,
)
# 9. Call generation mode
result = decoding_method(
self,
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
**generation_mode_kwargs,
**model_kwargs,
)
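Every built-in mode, and any user-supplied `custom_generate` callable, now flows through this single call site. A hedged sketch of a callable with a compatible signature; it is a bare greedy loop that ignores caching and most generation features, not something the PR ships:

# Hypothetical example: a minimal callable matching the unified call above
# (the model comes in as the first argument, followed by the prepared objects).
import torch

def my_greedy_decoding(model, input_ids, logits_processor, stopping_criteria, generation_config, **model_kwargs):
    unfinished = torch.ones(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
    while unfinished.any() and input_ids.shape[1] < generation_config.max_length:
        next_logits = model(input_ids).logits[:, -1, :]  # no KV cache, recomputes every step
        scores = logits_processor(input_ids, next_logits)
        next_tokens = scores.argmax(dim=-1)
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        unfinished = unfinished & ~stopping_criteria(input_ids, scores)
    return input_ids

# Assumed usage: outputs = model.generate(**inputs, custom_generate=my_greedy_decoding)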

# Convert to legacy cache format if requested
if (
@@ -3466,12 +3430,15 @@ def _beam_search(
def _assisted_decoding(
self,
input_ids: torch.LongTensor,
candidate_generator: CandidateGenerator,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
inputs_tensor: torch.FloatTensor = None,
assistant_model: Optional["PreTrainedModel"] = None,
assistant_tokenizer: Optional["PreTrainedTokenizerBase"] = None,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
@@ -3483,9 +3450,6 @@ def _assisted_decoding(
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
candidate_generator (`CandidateGenerator`):
A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
more information, the documentation of [`CandidateGenerator`] should be read.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
@@ -3500,6 +3464,15 @@
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
inputs_tensor (`torch.FloatTensor`, *optional*):
The input tensor for generation. For decoder models, usually `input_ids`. For encoder-decoder models,
the tensor that produced `model_kwargs["encoder_outputs"]`.
assistant_model (`PreTrainedModel`, *optional*):
The model used to assist the generation process. If not provided, the main model will be used.
assistant_tokenizer (`PreTrainedTokenizerBase`, *optional*):
The tokenizer used for the assistant model. If not provided, the token space is assumed to be the same.
tokenizer (`PreTrainedTokenizerBase`, *optional*):
The tokenizer used for the main model. If not provided, the token space is assumed to be the same.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
@@ -3511,6 +3484,26 @@
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
"""
# The cache must be dynamic for assisted generation, and the check must happen AFTER preparing the cache
if not model_kwargs["use_cache"]:
raise ValueError("assisted generate requires `use_cache=True`")
if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"] or (
"past_key_values" in model_kwargs
and hasattr(model_kwargs["past_key_values"], "layers")
and any(getattr(l, "is_compileable", False) for l in model_kwargs["past_key_values"].layers)
):
raise ValueError("assisted generate is not supported with Static cache classes`")
# Get the candidate generator, given the parameterization
candidate_generator = self._get_candidate_generator(
generation_config=generation_config,
input_ids=input_ids,
inputs_tensor=inputs_tensor,
assistant_model=assistant_model,
logits_processor=logits_processor,
target_tokenizer=tokenizer,
assistant_tokenizer=assistant_tokenizer,
model_kwargs=model_kwargs,
)
# init values
do_sample = generation_config.do_sample
output_attentions = generation_config.output_attentions
@@ -3535,6 +3528,8 @@

# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape[:2]
if batch_size > 1:
raise ValueError("assisted generate is only supported for batch_size = 1")
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

2 changes: 1 addition & 1 deletion src/transformers/models/dia/generation_dia.py
@@ -278,7 +278,7 @@ def _main_generate_loop(
generation_mode = generation_config.get_generation_mode(assistant_model)

self._validate_model_kwargs(model_kwargs.copy())
self._validate_generation_mode(generation_mode, generation_mode_kwargs)
self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)

# 2. Set generation parameters if not already defined
if synced_gpus is None:
Expand Down