[Generation, Gemma 3] When passing a custom generation_config, overwrite default values with the model's base generation_config
#36684
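In short, the fix makes `generate()` fill any field the user left unset in a custom `GenerationConfig` from the model's own `model.generation_config`, instead of from library-wide defaults. Below is a minimal sketch of the resulting behavior; it mirrors the regression test added in this PR (the model id comes from that test and may require access, and the merge mechanics inside `generate()` are paraphrased, not quoted from the diff):

```python
# Hedged sketch of the behavior this PR targets (not the implementation itself).
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "gg-hf-g/gemma-3-1b-it"  # id used in the new regression test below
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Gemma 3 ships model-specific defaults in its base generation config, e.g.
# cache_implementation="hybrid". The user passes a config that only sets
# max_new_tokens ...
user_config = GenerationConfig(max_new_tokens=20)

# ... and, with this PR, the unset fields are inherited from
# model.generation_config instead of being reset to library defaults, so the
# hybrid cache is still used for prompts longer than the sliding window.
inputs = tokenizer("This is a nice place. I really enjoy the scenery,", return_tensors="pt")
out = model.generate(**inputs, generation_config=user_config)
print(tokenizer.decode(out[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True))
```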
Changes from all commits
First changed file:

```diff
@@ -1162,8 +1162,8 @@ def test_beam_search_low_memory(self):
        # The two outputs must match and their shape must be as expected
        self._check_similar_generate_outputs(low_output, high_output)

    @pytest.mark.generate
```

An inline review thread is attached to the `@pytest.mark.generate` line:
Contributor (Author): (a few tests that failed in previous commits were incorrectly marked :)

Member: btw, how important is it to mark a skipped test as `generate`? I have no idea where we use those marks, unless when trying to run only generation tests, in which case a skip doesn't help much.

Contributor (Author): not very important, it's more for long-term bookkeeping with the automated CI reports (how many tests we have of each mark, how much time we spend on each mark, % of skips, ...).
The hunk continues:

```diff
    @parameterized.expand([("random",), ("same",)])
    @pytest.mark.generate
    def test_assisted_decoding_matches_greedy_search(self, assistant_type):
        # This test ensures that the assisted generation does not introduce output changes over greedy search.
        # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 for more info.
```
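As an aside on the mark discussion above, here is a minimal sketch (the class and test names are placeholders) of why a skipped test still carries `@pytest.mark.generate`: mark-based selection such as `pytest -m generate`, and CI reports that aggregate per-mark counts and durations, still see the test even though it never runs.

```python
# Placeholder example: a skipped test that is still counted under the
# `generate` mark by `pytest -m generate` and by mark-based CI reports.
import unittest

import pytest


class ExampleGenerationTests(unittest.TestCase):
    @pytest.mark.generate
    @unittest.skip("Not supported by this model, but still bookkept as a generation test")
    def test_some_generation_feature(self):
        pass
```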
Second changed file (the Gemma 3 model tests):

```diff
@@ -16,13 +16,15 @@

import unittest

import pytest
from parameterized import parameterized

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Gemma3Config,
    Gemma3TextConfig,
    GenerationConfig,
    is_torch_available,
)
from transformers.testing_utils import (
```
```diff
@@ -75,6 +77,7 @@ def test_model_outputs_equivalence(self, **kwargs):
        pass

    @parameterized.expand([("random",), ("same",)])
    @pytest.mark.generate
    @unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
    def test_assisted_decoding_matches_greedy_search(self, assistant_type):
        pass
```
```diff
@@ -83,6 +86,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
    def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
        pass

    @pytest.mark.generate
    @unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
    def test_assisted_decoding_sample(self):
        pass
```
```diff
@@ -277,6 +281,7 @@ def test_model_outputs_equivalence(self, **kwargs):
        pass

    @parameterized.expand([("random",), ("same",)])
    @pytest.mark.generate
    @unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
    def test_assisted_decoding_matches_greedy_search(self, assistant_type):
        pass
```
```diff
@@ -285,6 +290,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
    def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
        pass

    @pytest.mark.generate
    @unittest.skip("Gemma3 has HybridCache which is not compatible with assisted decoding")
    def test_assisted_decoding_sample(self):
        pass
```
```diff
@@ -551,3 +557,34 @@ def test_generation_beyond_sliding_window(self, attn_implementation: str):

        EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"]  # fmt: skip
        self.assertEqual(output_text, EXPECTED_COMPLETIONS)

    def test_generation_beyond_sliding_window_with_generation_config(self):
        """
        Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
        ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
        """
        model_id = "gg-hf-g/gemma-3-1b-it"
        attn_implementation = "sdpa"

        input_text = [
            "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
            "A list of colors: red, blue",  # This will almost all be padding tokens
        ]
        tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
        inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)

        model = AutoModelForCausalLM.from_pretrained(
            model_id, attn_implementation=attn_implementation, torch_dtype=torch.float16
        ).to(torch_device)

        # Make sure prefill is larger than sliding window
        input_size = inputs.input_ids.shape[-1]
        self.assertTrue(input_size > model.config.sliding_window)

        generation_config = GenerationConfig(max_new_tokens=20)

        out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
        output_text = tokenizer.batch_decode(out)

        EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"]  # fmt: skip
        self.assertEqual(output_text, EXPECTED_COMPLETIONS)
```

Review thread on lines +561 to +562 (the new test's definition):

Collaborator: very very nice, thanks!
Author: The changes in this function are secondary to the main change: `.generate()` now keeps a user-passed `logits_processor` even when the corresponding `generation_config` flag is set (i.e. it assumes the user knows what they are doing when they pass `logits_processor` to `.generate()`, instead of crashing).
Reviewer: A tiny comment wouldn't hurt for future us; it isn't very easy to get why we do this without reading the PR description. Also, I am not sure if this is required: aren't we restricting custom logits processors to only those that cannot be configured by the generation config, i.e. something that is only defined by users for their use case?
Author: The user always had the option of unsetting a flag and passing the corresponding processor. This change makes it less restrictive: if they pass both a flag and the corresponding processor, we keep the processor and ignore the flag (previously we would throw an exception).

I'm not a huge fan of the new behavior; I think the exception is less ambiguous (and thus preferable). However, it's needed to coexist with the main change in this PR, which is (IMO) more important.

I'll add a comment to clarify the evolution of this function, in case we need to trace back a related change :)
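A hedged sketch of the precedence rule described above (the model id and the temperature example are illustrative, not taken from the PR): when the same knob is set both via a `generation_config` flag and via an explicitly passed processor, the explicit processor is kept and the flag is ignored, where previously an exception was raised.

```python
# Illustrative only: overlapping flag + custom processor, per the discussion above.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    LogitsProcessorList,
    TemperatureLogitsWarper,
)

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("A list of colors: red, blue", return_tensors="pt")

# The flag would normally make generate() build its own temperature warper ...
generation_config = GenerationConfig(do_sample=True, temperature=0.7, max_new_tokens=10)

# ... but the user also passes a temperature warper explicitly.
custom = LogitsProcessorList([TemperatureLogitsWarper(0.5)])

# Old behavior: an exception for the duplicated processor.
# New behavior: the explicit TemperatureLogitsWarper(0.5) wins, the flag is ignored.
out = model.generate(**inputs, generation_config=generation_config, logits_processor=custom)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```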
Reviewer: Yes, same feeling here. I would even restrict it to only custom logits processors in v5, to free us from the burden of maintaining correct priority/ordering, etc. Looks like a Pandora's box, what is being fixed 😿
Author: I definitely don't want this; `generate` would become very uncomfortable to use 😅 A user always has the option of disabling `generation_config` flags and passing the processors in the order they want. But most users don't want that level of control, and yet may want an external processor.
Author: The reason `transformers` is so popular is because we enable complex use cases from a few lines of code.
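For completeness, a companion sketch of the "full control" route mentioned above (again with a placeholder model): unset the sampling flags in the `GenerationConfig` and supply the warpers yourself, in whatever order you want.

```python
# Illustrative only: disable the config-driven warpers and order them by hand.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
)

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("A list of colors: red, blue", return_tensors="pt")

# Leaving the sampling knobs unset means generate() adds no warpers of its own.
generation_config = GenerationConfig(
    do_sample=True, temperature=None, top_k=None, top_p=None, max_new_tokens=10
)

# The user controls the order: top-k filtering first, then temperature scaling.
processors = LogitsProcessorList([TopKLogitsWarper(50), TemperatureLogitsWarper(0.7)])

out = model.generate(**inputs, generation_config=generation_config, logits_processor=processors)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```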
Reviewer: Yes, I agree that this gives users more freedom to order processors the way they want by disabling the generation config. Though it is not very clear to me, as a user, what happens under the hood when I pass my own temperature processor or use the config. IMO we need a better docs page for advanced usage if we allow that much freedom and expect users to know what they are doing. Users almost never 100% know what they are doing, hence the open issues on GH 😆