8 changes: 4 additions & 4 deletions src/transformers/generation/configuration_utils.py
@@ -73,7 +73,7 @@
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
ALL_CACHE_IMPLEMENTATIONS = (
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded"]
list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(CACHE_CONFIG_MAPPING.keys()) + ["offloaded", "dynamic"]
)


@@ -175,16 +175,16 @@ class GenerationConfig(PushToHubMixin):
cache_implementation (`str`, *optional*, default to `None`):
Name of the cache class that will be instantiated in `generate`, for faster decoding. Possible values are:

- `"dynamic"`: [`DynamicCache`]
- `"static"`: [`StaticCache`]
- `"offloaded_static"`: [`OffloadedStaticCache`]
- `"sliding_window"`: [`SlidingWindowCache`]
- `"hybrid"`: [`HybridCache`]
- `"mamba"`: [`MambaCache`]
- `"quantized"`: [`QuantizedCache`]

We support other cache types, but they must be manually instantiated and
passed to `generate` through the `past_key_values` argument. See our
[cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
If none is specified, we will use the default cache for the model (which is often [`DynamicCache`]). See
our [cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
it will be converted to its respective `CacheConfig` internally.
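
A minimal usage sketch of the new `"dynamic"` value (checkpoint and prompt are illustrative, not taken from this PR):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Illustrative checkpoint; any causal LM exposes the same interface.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The capital of France is", return_tensors="pt")

# "dynamic" now explicitly selects DynamicCache, the usual default cache.
generation_config = GenerationConfig(max_new_tokens=20, cache_implementation="dynamic")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```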
65 changes: 47 additions & 18 deletions src/transformers/generation/utils.py
@@ -1177,21 +1177,37 @@ def _merge_criteria_processor_list(
default_list: Union[LogitsProcessorList, StoppingCriteriaList],
custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
) -> Union[LogitsProcessorList, StoppingCriteriaList]:
"""
gante (Contributor Author), Mar 12, 2025:

The changes in this function are secondary to the main change:

  • whisper breaks because it both sets custom logits processors AND has the default flags in the generation config to instantiate them
  • after the original change (inherit defaults from the model's generation config), we were throwing an exception here
  • after this secondary change, we only throw a warning and discard the logits processor instance created inside .generate() (i.e. it assumes the user knows what they are doing when they pass logits_processors to .generate(), instead of crashing)

Member:

A tiny comment wouldn't hurt for future us; it isn't easy to see why we do this without reading the PR description.

Also, I am not sure if this is required; aren't we restricting custom logits processors to only those that cannot be configured by the generation config, i.e. something that is only defined by users for their use case?

gante (Contributor Author):

> Also, I am not sure if this is required; aren't we restricting custom logits processors to only those that cannot be configured by the generation config, i.e. something that is only defined by users for their use case?

The user always had the option of unsetting a flag and passing the corresponding processor. This change makes it less restrictive: if they pass both a flag and the corresponding processor, we keep the processor and ignore the flag (previously we would throw an exception).

I'm not a big fan of the new behavior; I think the exception is less ambiguous (and thus preferable). However, it's needed to coexist with the main change in this PR, which is (IMO) more important.

I'll add a comment to clarify the evolution of this function, in case we need to trace back a related change :)

Member:

Yes, same feeling here. I would even restrict it to only custom logits processors in v5 to free us from the burden of maintaining correct priority/ordering, etc. What is being fixed here looks like Pandora's box 😿

gante (Contributor Author), Mar 14, 2025:

> I would even restrict it to only custom logits processors in v5 to free us from the burden of maintaining correct priority/ordering, etc.

I definitely don't want this; generate would become very uncomfortable to use 😅 A user always has the option of disabling generation_config flags and passing the processors in the order they want. But most users don't want that level of control, and yet they may want an external processor.

gante (Contributor Author):

The reason transformers is so popular is that we enable complex use cases with a few lines of code.

Member:

Yes, I agree that this gives users more freedom to order processors the way they want by disabling the generation config. Though, as a user, it is not very clear to me what happens under the hood when I pass my own temperature processor or use the config. IMO we need a better docs page for advanced usage if we allow that much freedom and expect users to know what they are doing.

Users almost never 100% know what they are doing, hence the open issues on GH 😆

Merge user-defined processors/criteria with the ones instantiated inside `generate`. In case the same
processor/criteria is present on both lists, use the user-defined one.

(Note: up to v4.49.0, this function threw an exception if the same logit processor was found twice.)
"""
if len(custom_list) == 0:
return default_list

final_list = type(default_list)()
for default in default_list:
using_custom = False
for custom in custom_list:
if type(custom) is type(default):
object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
raise ValueError(
f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
f" `.generate()`, but it has already been created with the values {default}. {default} has been"
" created by passing the corresponding arguments to generate or by the model's config default"
f" values. If you just want to change the default values of {object_type} consider passing"
f" them as arguments to `.generate()` instead of using a custom {object_type}."
logger.warning_once(
f"A custom {object_type} of type {type(custom)} has been passed to `.generate()`, but it "
f"was also created in `.generate()`, given its parameterization. The custom {type(custom)} "
f"will take precedence. Please check the docstring of {type(custom)} to see related "
"`.generate()` flags."
)
default_list.extend(custom_list)
return default_list
final_list.append(custom)
using_custom = True
break
if not using_custom:
final_list.append(default)

for custom in custom_list:
if custom not in final_list:
final_list.append(custom)
return final_list
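
A hedged sketch of the merge behavior discussed above: when both a flag and a matching user-supplied processor are passed, the user instance now takes precedence and only a warning is emitted (the checkpoint and values below are illustrative, not from this PR):

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    TemperatureLogitsWarper,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# temperature=0.5 would normally create a TemperatureLogitsWarper inside `generate`;
# the explicit instance below takes precedence, and the flag-created one is discarded
# with a warning instead of raising a ValueError as in v4.49.0 and earlier.
custom = LogitsProcessorList([TemperatureLogitsWarper(1.3)])
outputs = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.5,
    logits_processor=custom,
    max_new_tokens=10,
)
```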

def compute_transition_scores(
self,
@@ -1573,17 +1589,28 @@ def _prepare_generation_config(
# exception will be raised in `_validate_model_kwargs`
if not is_torchdynamo_compiling():
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model

# If `generation_config` is provided, let's fallback ALL default values to the model's generation config
# TODO (joao): per-model generation config classes.
if not using_model_generation_config:
if generation_config.bos_token_id is None:
generation_config.bos_token_id = self.generation_config.bos_token_id
if generation_config.eos_token_id is None:
generation_config.eos_token_id = self.generation_config.eos_token_id
if generation_config.pad_token_id is None:
generation_config.pad_token_id = self.generation_config.pad_token_id
if generation_config.decoder_start_token_id is None:
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
modified_values = {}
default_generation_config = GenerationConfig()
for key, default_value in default_generation_config.__dict__.items():
if key.startswith("_"): # metadata
continue
custom_gen_config_value = getattr(generation_config, key)
model_gen_config_value = getattr(self.generation_config, key)
if custom_gen_config_value == default_value and model_gen_config_value != default_value:
modified_values[key] = model_gen_config_value
setattr(generation_config, key, model_gen_config_value)
if len(modified_values) > 0:
logger.warning_once(
f"`generation_config` default values have been modified to match model-specific defaults: "
f"{modified_values}. If this is not desired, please set these values explicitly."
)

# Finally, apply any passed kwargs
model_kwargs = generation_config.update(**kwargs)
else:
model_kwargs = kwargs
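
A small sketch of the fallback this hunk introduces (checkpoint is illustrative): fields left at library defaults in a user-supplied `GenerationConfig` are now inherited from the model's own generation config, with a one-time warning listing the modified values.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("A list of colors:", return_tensors="pt")

# Only max_new_tokens is set here; anything the model's generation_config customizes
# (e.g. a model-specific cache_implementation, eos/pad tokens, sampling flags) is
# inherited rather than silently reset to the library defaults.
user_config = GenerationConfig(max_new_tokens=5)
outputs = model.generate(**inputs, generation_config=user_config)
```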

@@ -1837,6 +1864,8 @@ def _prepare_cache_for_generation(
model_kwargs[cache_name] = cache_class(cache_config)
elif generation_config.cache_implementation == "offloaded":
model_kwargs[cache_name] = OffloadedCache()
elif generation_config.cache_implementation == "dynamic":
model_kwargs[cache_name] = DynamicCache()

# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
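
For comparison, a sketch of the pre-existing manual route mentioned in the docstring, where the cache is instantiated by the user and passed through `past_key_values` (checkpoint is illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Once upon a time", return_tensors="pt")

# Equivalent to cache_implementation="dynamic": pass a ready-made cache instance.
outputs = model.generate(**inputs, past_key_values=DynamicCache(), max_new_tokens=20)
```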
2 changes: 1 addition & 1 deletion tests/generation/test_utils.py
@@ -1162,8 +1162,8 @@ def test_beam_search_low_memory(self):
# The two outputs must match and their shape must be as expected
self._check_similar_generate_outputs(low_output, high_output)

@pytest.mark.generate
gante (Contributor Author):

(A few tests that failed in previous commits were incorrectly marked :) When parameterized is used, most decorators should be placed after it.)

Member:

BTW, how important is it to mark a skipped test as generate? I have no idea where we use those marks, except when trying to run only generation tests, in which case a skip doesn't help much.

gante (Contributor Author):

Not very important; it's more for long-term bookkeeping with the automated CI reports (how many tests we have of each mark, how much time we spend on each mark, % of skips, ...).
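
A minimal sketch of the ordering rule described in this thread (test name and skip reason are illustrative): `@parameterized.expand` is listed first, so that marks and skips placed after it attach to every generated test variant.

```python
import unittest

import pytest
from parameterized import parameterized


class ExampleGenerationTest(unittest.TestCase):
    @parameterized.expand([("random",), ("same",)])
    @pytest.mark.generate
    @unittest.skip("illustrative skip reason")
    def test_assisted_decoding_matches_greedy_search(self, assistant_type):
        pass
```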

@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
# This test ensures that the assisted generation does not introduce output changes over greedy search.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
3 changes: 3 additions & 0 deletions tests/models/aya_vision/test_modeling_aya_vision.py
@@ -16,6 +16,7 @@

import unittest

import pytest
from parameterized import parameterized

from transformers import (
@@ -261,6 +262,7 @@ def test_eager_matches_sdpa_generate(self):
pass

@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -269,6 +271,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass

@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
3 changes: 3 additions & 0 deletions tests/models/cohere2/test_modeling_cohere2.py
@@ -16,6 +16,7 @@

import unittest

import pytest
from packaging import version
from parameterized import parameterized
from pytest import mark
@@ -81,6 +82,7 @@ def test_eager_matches_sdpa_generate(self):
pass

@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -89,6 +91,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass

@pytest.mark.generate
@unittest.skip("Cohere2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
3 changes: 2 additions & 1 deletion tests/models/fuyu/test_modeling_fuyu.py
@@ -299,12 +299,13 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass

@pytest.mark.generate
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
def test_assisted_decoding_matches_greedy_search(self):
pass

@pytest.mark.generate
@unittest.skip("Fuyu doesn't support assisted generation due to the need to crop/extend image patches indices")
def test_assisted_decoding_sample(self):
pass
3 changes: 3 additions & 0 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -16,6 +16,7 @@

import unittest

import pytest
from packaging import version
from parameterized import parameterized
from pytest import mark
@@ -96,6 +97,7 @@ def test_eager_matches_sdpa_generate(self):
pass

@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -104,6 +106,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass

@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
37 changes: 37 additions & 0 deletions tests/models/gemma3/test_modeling_gemma3.py
@@ -16,13 +16,15 @@

import unittest

import pytest
from parameterized import parameterized

from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Gemma3Config,
Gemma3TextConfig,
GenerationConfig,
is_torch_available,
)
from transformers.testing_utils import (
@@ -75,6 +77,7 @@ def test_model_outputs_equivalence(self, **kwargs):
pass

@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -83,6 +86,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass

@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@@ -277,6 +281,7 @@ def test_model_outputs_equivalence(self, **kwargs):
pass

@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -285,6 +290,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass

@pytest.mark.generate
@unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
@@ -551,3 +557,34 @@ def test_generation_beyond_sliding_window(self, attn_implementation: str):

EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)

def test_generation_beyond_sliding_window_with_generation_config(self):
"""
Comment on lines +561 to +562
Collaborator:

very very nice thanks!

Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
"""
model_id = "gg-hf-g/gemma-3-1b-it"
attn_implementation = "sdpa"

input_text = [
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
"A list of colors: red, blue", # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)

model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
).to(torch_device)

# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)

generation_config = GenerationConfig(max_new_tokens=20)

out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
output_text = tokenizer.batch_decode(out)

EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"] # fmt: skip
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
3 changes: 3 additions & 0 deletions tests/models/paligemma2/test_modeling_paligemma2.py
@@ -16,6 +16,7 @@

import unittest

import pytest
from parameterized import parameterized

from transformers import (
@@ -351,6 +352,7 @@ def test_beam_search_low_memory(self):
pass

@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass
@@ -359,6 +361,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
pass

@pytest.mark.generate
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_sample(self):
pass
4 changes: 4 additions & 0 deletions tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
@@ -16,6 +16,8 @@

import unittest

import pytest

from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
@@ -375,6 +377,7 @@ def test_model_parallelism(self):
def test_model_parallel_beam_search(self):
pass

@pytest.mark.generate
@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
def test_assisted_decoding_matches_greedy_search(self):
pass
Expand All @@ -383,6 +386,7 @@ def test_assisted_decoding_matches_greedy_search(self):
def test_left_padding_compatibility(self):
pass

@pytest.mark.generate
@unittest.skip(reason="Relies on `past_key_values` returned by the model. Not supported with recurrent gemma")
def test_assisted_decoding_sample(self):
pass
1 change: 1 addition & 0 deletions tests/models/smolvlm/test_modeling_smolvlm.py
@@ -423,6 +423,7 @@ def test_eager_matches_sdpa_generate(self):
pass

@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip(reason="Cache position is off by one leaving out image tokens, FIXME raushan")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
pass