Merged
145 changes: 71 additions & 74 deletions src/transformers/generation/utils.py
@@ -2192,14 +2192,26 @@ def _valid_auto_compile_criteria(self, model_kwargs: dict[str, Any], generation_
has_disk_offload = "disk" in all_model_devices
can_compile &= not has_disk_offload

# Finally: if the user has manually specified compilation options, but compilation is not possible, let's warn
# If the user has manually specified compilation options, but compilation is not possible, let's warn
# them
if generation_config.compile_config is not None and not can_compile:
logger.warning_once(
"You have set `compile_config`, but we are unable to meet the criteria for compilation. Compilation "
"will be skipped."
)

# Finally: if we can compile, disable tokenizers parallelism and check for FA2 + static cache
os.environ["TOKENIZERS_PARALLELISM"] = "0"
# If we use FA2 and a static cache, we cannot compile with fullgraph
if self.config._attn_implementation == "flash_attention_2":
# only raise warning if the user passed an explicit compile-config
if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
logger.warning_once(
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
)
generation_config.compile_config.fullgraph = False

return can_compile
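
Note (not part of the diff): a minimal usage sketch of the path touched above, assuming an FA2-capable causal LM. The checkpoint name is a placeholder, and the CompileConfig import path may differ across versions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import CompileConfig

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-causal-lm",                  # placeholder checkpoint
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
tok = AutoTokenizer.from_pretrained("some-org/some-causal-lm")
inputs = tok("Hello there", return_tensors="pt").to(model.device)

# With FA2 + a static cache, fullgraph=True is downgraded to False and the
# warning_once added in this hunk fires once; compilation itself still happens.
out = model.generate(
    **inputs,
    max_new_tokens=16,
    cache_implementation="static",
    compile_config=CompileConfig(fullgraph=True),
)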

def _get_deprecated_gen_repo(
@@ -2636,7 +2648,7 @@ def generate(
UserWarning,
)

# 8. prepare logits processors and stopping criteria
# 8. Prepare logits processors and stopping criteria
prepared_logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
@@ -2843,40 +2855,21 @@ def _sample(
batch_size, cur_len = input_ids.shape[:2]
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

model_forward = self.__call__
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
if compile_forward:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
# If we use FA2 and a static cache, we cannot compile with fullgraph
if self.config._attn_implementation == "flash_attention_2":
# only raise warning if the user passed an explicit compile-config
if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
logger.warning_once(
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
)
generation_config.compile_config.fullgraph = False
model_forward = self.get_compiled_call(generation_config.compile_config)
model_forward = (
self.get_compiled_call(generation_config.compile_config)
if self._valid_auto_compile_criteria(model_kwargs, generation_config)
else self.__call__
)

if generation_config.prefill_chunk_size is not None:
model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
is_prefill = False
else:
is_prefill = True
prefill_consumed = False
outputs = self._prefill(input_ids, generation_config, model_kwargs)

while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

if is_prefill:
outputs = self(**model_inputs, return_dict=True)
is_prefill = False
else:
if prefill_consumed:
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = model_forward(**model_inputs, return_dict=True)

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
prefill_consumed = True
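
Note (not part of the diff): a stripped-down, self-contained sketch of the new control flow — the prefill forward runs once before the loop and its outputs are consumed on the first iteration. dummy_forward is a stand-in for the real model call; cache and kwargs handling are not modelled.

def dummy_forward(tokens):
    # Stands in for `model_forward(**model_inputs, return_dict=True)`; returns a
    # fake "next token" instead of logits.
    return {"next_token": tokens[-1] + 1}

def sample_sketch(prompt_tokens, max_new_tokens=3):
    outputs = dummy_forward(prompt_tokens)   # prefill: run once, before the loop
    prefill_consumed = False
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        if prefill_consumed:                 # first iteration reuses the prefill outputs
            outputs = dummy_forward(tokens)
        prefill_consumed = True
        tokens.append(outputs["next_token"])  # append the (fake) next token
    return tokens

print(sample_sketch([1, 2, 3]))  # [1, 2, 3, 4, 5, 6]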
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
@@ -3246,7 +3239,6 @@ def _beam_search(
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
"""

# 1. init beam_search values
pad_token_id = generation_config._pad_token_tensor
eos_token_id = generation_config._eos_token_tensor
@@ -3287,8 +3279,6 @@ def _beam_search(
dim=0,
).to(input_ids.device)

model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

# (joao) feature lost in the refactor. Probably won't implement, hurts readability with minimal gains (there
# are newer low-memory alternatives like the offloaded cache)
sequential = generation_config.low_memory
@@ -3350,13 +3340,18 @@ def _beam_search(
)
beam_indices = running_beam_indices.detach().clone()

prefill_consumed = False
flat_running_sequences = input_ids
model_outputs = self._prefill(input_ids, generation_config, model_kwargs)

# 4. run the generation loop
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# a. Forward current tokens, obtain the logits
flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)

model_outputs = self(**model_inputs, return_dict=True)
if prefill_consumed:
# a. Forward current tokens, obtain the logits
flat_running_sequences = self._flatten_beam_dim(running_sequences[:, :, :cur_len])
model_inputs = self.prepare_inputs_for_generation(flat_running_sequences, **model_kwargs)
model_outputs = self(**model_inputs, return_dict=True)
prefill_consumed = True
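
Note (not part of the diff): the flattening performed by _flatten_beam_dim before the forward pass, approximated here with a plain reshape; shapes are illustrative.

import torch

batch_size, num_beams, cur_len = 2, 3, 5
running_sequences = torch.zeros(batch_size, num_beams, cur_len, dtype=torch.long)

# Roughly what self._flatten_beam_dim(running_sequences[:, :, :cur_len]) produces:
# batch and beam dimensions merged so the model sees (batch * beams, seq_len).
flat_running_sequences = running_sequences.reshape(batch_size * num_beams, cur_len)
print(flat_running_sequences.shape)  # torch.Size([6, 5])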

# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
@@ -3839,49 +3834,51 @@ def _assisted_decoding(
else:
return input_ids

def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs):
# Even if we are not compiling the forward, flex is always compiled when used. With chunk prefill, we may
# end up needing just a bit more graphs than the default (which is 8). Doing this avoids very cryptic warnings
torch._dynamo.config.cache_size_limit = 64

chunk_size = generation_config.prefill_chunk_size
# Only chunk up to the token just before last, so that decoding is completely performed outside this function
# (here we simply prefill the cache)
input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)

if "past_key_values" not in model_kwargs:
raise ValueError("Cannot use prefill chunking without a cache")

model_forward = self.forward
# TODO: v5.1: make public once API stabilized
def _prefill(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, model_kwargs):
if generation_config.prefill_chunk_size is None:
model_kwargs = self._get_initial_cache_position(input_ids.shape[1], input_ids.device, model_kwargs)
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
return self(**model_inputs, return_dict=True)
else: # Chunked prefill
# Even if we are not compiling the forward, flex is always compiled when used. With chunked prefill, we may
# end up needing just a bit more graphs than the default (which is 8). Doing this avoids very cryptic warnings
torch._dynamo.config.cache_size_limit = 64

compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
if compile_forward:
model_forward = self.get_compiled_call(generation_config.compile_config)
chunk_size = generation_config.prefill_chunk_size
input_chunks = torch.split(input_ids, chunk_size, dim=-1)

attention_mask = model_kwargs.pop("attention_mask", None)
if "past_key_values" not in model_kwargs:
raise ValueError("Cannot use prefill chunking without a cache")

past_length = 0
for input_chunk in input_chunks:
current_length = past_length + input_chunk.shape[-1]
# Prepare inputs
if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask[:, :current_length]
model_kwargs["cache_position"] = torch.arange(
past_length, current_length, dtype=torch.long, device=input_chunk.device
model_forward = (
self.get_compiled_call(generation_config.compile_config)
if self._valid_auto_compile_criteria(model_kwargs, generation_config)
else self.__call__
)
model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)

outputs = model_forward(**model_inputs, return_dict=True)
attention_mask = model_kwargs.pop("attention_mask", None)
past_length = 0
for input_chunk in input_chunks:
current_length = past_length + input_chunk.shape[-1]
if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask[:, :current_length]
model_kwargs["cache_position"] = torch.arange(
past_length, current_length, dtype=torch.long, device=input_chunk.device
)
model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)

model_kwargs["past_key_values"] = outputs.past_key_values
past_length = current_length
outputs = model_forward(**model_inputs, return_dict=True)

model_kwargs["attention_mask"] = attention_mask
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
_ = model_kwargs.pop("position_ids", None)
model_kwargs["past_key_values"] = outputs.past_key_values
past_length = current_length

return model_kwargs
model_kwargs["attention_mask"] = attention_mask
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
_ = model_kwargs.pop("position_ids", None)
# Latest outputs contain next token logits
return outputs
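
Note (not part of the diff): a self-contained sketch of the chunked-prefill bookkeeping above — how cache_position and the attention-mask slice advance per chunk. No real model or cache is involved.

import torch

input_ids = torch.arange(10).unsqueeze(0)      # fake prompt, shape (1, 10)
attention_mask = torch.ones_like(input_ids)
chunk_size = 4

past_length = 0
for input_chunk in torch.split(input_ids, chunk_size, dim=-1):
    current_length = past_length + input_chunk.shape[-1]
    cache_position = torch.arange(past_length, current_length)
    mask_slice = attention_mask[:, :current_length]
    print(input_chunk.shape[-1], cache_position.tolist(), tuple(mask_slice.shape))
    past_length = current_length

# Decoding then starts at the next position, mirroring
# model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
print((cache_position[-1:] + 1).tolist())      # [10]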


def _speculative_sampling(
11 changes: 5 additions & 6 deletions src/transformers/models/csm/generation_csm.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional, Union

@@ -204,11 +203,11 @@ def _sample(
criterion.max_length -= cur_len
# ============================================

model_forward = self.__call__
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
if compile_forward:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)
model_forward = (
self.get_compiled_call(generation_config.compile_config)
if self._valid_auto_compile_criteria(model_kwargs, generation_config)
else self.__call__
)
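
Note (not part of the diff): the same "compiled call if the criteria pass, plain __call__ otherwise" selection, sketched with vanilla torch.compile instead of the internal get_compiled_call.

import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

model = TinyModel()
can_compile = True  # stands in for self._valid_auto_compile_criteria(...)

# Same shape as the ternary above: compiled forward when possible, plain call otherwise.
model_forward = torch.compile(model) if can_compile else model
print(model_forward(torch.randn(1, 4)).shape)  # torch.Size([1, 4])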

is_prefill = True
while self._has_unfinished_sequences(
49 changes: 29 additions & 20 deletions src/transformers/models/dia/generation_dia.py
Expand Up @@ -278,6 +278,12 @@ def _main_generate_loop(
)
generation_mode = generation_config.get_generation_mode(assistant_model)

if generation_mode not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
raise ValueError(
"Got incompatible mode for generation, should be one of greedy or sampling. "
"Ensure that beam search is de-activated by setting `num_beams=1`."
)

self._validate_model_kwargs(model_kwargs.copy())
self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)

@@ -382,26 +388,29 @@ def _main_generate_loop(
# Prepare inner 2D logic in generation loop
input_ids = input_ids.reshape(-1, input_ids.shape[-1])

# 10. go into different generation modes
if generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
if generation_config.num_return_sequences > 1:
raise ValueError("`num_return_sequences>1` is incompatible with Dia.")

# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
return self._sample(
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
**generation_mode_kwargs,
**model_kwargs,
)
else:
raise ValueError(
"Got incompatible mode for generation, should be one of greedy or sampling. "
"Ensure that beam search is de-activated by setting `num_beams=1`."
)
model_kwargs = self._get_initial_cache_position(input_ids.shape[1], input_ids.device, model_kwargs)
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

# 10. Prefill
model_inputs.update({"output_attentions": generation_config.output_attentions})
model_inputs.update({"output_hidden_states": generation_config.output_hidden_states})
outputs = self(**model_inputs, return_dict=True)

# 11. expand input_ids with `num_return_sequences` additional sequences per batch
if generation_config.num_return_sequences > 1:
raise ValueError("`num_return_sequences>1` is incompatible with Dia.")

# 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
return self._sample(
input_ids,
logits_processor=prepared_logits_processor,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
prefill_outputs=outputs,
**generation_mode_kwargs,
**model_kwargs,
)

@torch.no_grad()
def generate(