
Conversation

@gante
Contributor

@gante gante commented Mar 12, 2025

What does this PR do?

See title.

For instance, Gemma 3 models have cache_implementation="hybrid" by default but, if we pass generation_config=GenerationConfig() (i.e. default parameters), the code will crash because a hybrid cache is not used. In other words, let's assume a user wants to use the base parameterization set by the model creators, and use model-specific defaults as opposed to global defaults.

Original testing script (by @NathanHB), crashing on main:

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch

def main():
    model = "google/gemma-3-1b-it"
    revision = "e735e8d98f6d2ccdb3bdfc43ac1c252bebb2527f"
    dtype = "bfloat16"
    tokenizer = AutoTokenizer.from_pretrained(model)
    model = AutoModelForCausalLM.from_pretrained(model, revision=revision, torch_dtype=dtype, device_map="cuda:0")
    prompt = """Solve the following math problem efficiently and clearly.  The last line of your response should be of the following format: 'Therefore, the final answer is: $\boxed{ANSWER}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering.

Alice chooses a set $A$ of positive integers. Then Bob lists all finite nonempty sets $B$ of positive integers with the property that the maximum element of $B$ belongs to $A$. Bob's list has 2024 sets. Find the sum of the elements of A.
        """.strip()

    chat = [{
        "content": prompt,
        "role": "user",
    }]

    inputs = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    print(inputs)
    inputs = tokenizer(inputs, return_tensors="pt").to(model.device)
    print("=== DECODING ===")
    generation_config = GenerationConfig(max_new_tokens=2048, temperature=1.0, do_sample=True, top_k=64, top_p=0.95)
    outputs = model.generate(**inputs, generation_config=generation_config)
    outputs = tokenizer.decode(outputs[0], skip_special_tokens=False)

    print(outputs)

if __name__ == "__main__":
    main()

@github-actions
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@github-actions github-actions bot marked this pull request as draft March 12, 2025 18:19
@gante gante marked this pull request as ready for review March 12, 2025 18:20
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

default_list: Union[LogitsProcessorList, StoppingCriteriaList],
custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
) -> Union[LogitsProcessorList, StoppingCriteriaList]:
"""
Contributor Author

@gante gante Mar 12, 2025

The changes in this function are secondary to the main change:

  • whisper breaks because it both sets custom logits processors AND has the default flags in the generation config that would instantiate them
  • after the original change (inherit defaults from the model's generation config), we were throwing an exception here
  • after this secondary change, we only throw a warning and discard the logits processor instance created inside .generate() (i.e. we assume the user knows what they are doing when they pass logits_processors to .generate(), instead of crashing); a sketch of this behavior follows below
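
For future reference, a minimal sketch of the intended behavior, not the actual transformers code; detecting duplicates by processor type is an assumption made for illustration:

import warnings

def merge_criteria_or_processors(default_list, custom_list):
    """Hypothetical sketch: keep user-supplied processors, warn about (and drop)
    the equivalent ones instantiated from generation config flags."""
    final_list = type(default_list)()
    custom_types = {type(custom) for custom in custom_list}
    for default in default_list:
        if type(default) in custom_types:
            # Previously this situation raised an exception; now we warn and prefer
            # the instance the user explicitly passed to `.generate()`.
            warnings.warn(
                f"A custom {type(default).__name__} was passed to `.generate()`, "
                "overriding the one created from the generation config flags.",
                UserWarning,
            )
        else:
            final_list.append(default)
    final_list.extend(custom_list)
    return final_list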

Member

A tiny comment wouldn't hurt for future us; it isn't very easy to see why we do this without reading the PR description.

Also, I am not sure if this is required: aren't we restricting custom logits processors to only those that cannot be configured by the generation config? Something that is only defined by users for their use case

Contributor Author

Also, I am not sure if this is required: aren't we restricting custom logits processors to only those that cannot be configured by the generation config? Something that is only defined by users for their use case

The user always had the option of unsetting a flag and passing the corresponding processor. This change makes it less restrictive: if they pass both a flag and the corresponding processor, we keep the processor and ignore the flag (previously, we would throw an exception)

I'm not a super fan of the new behavior; I think the exception is less ambiguous (and thus preferable). However, it's needed to coexist with the main change in this PR, which is (IMO) more important.

I'll add a comment to clarify the evolution of this function, in case we need to trace back a related change :)

Member

yes, same feeling here. I would even restrict it to only custom logits processors in v5 to free us from the burden of maintaining correct priority/ordering etc. What is being fixed looks like a Pandora's box 😿

Contributor Author

@gante gante Mar 14, 2025

I would even restrict it to only custom logits processors in v5 to free us from the burden of maintaining correct priority/ordering etc.

I definitely don't want this; generate would become very uncomfortable to use 😅 A user always has the option of disabling generation_config flags and passing the processors in the order they want (see the example below). But most users don't want that level of control, and yet may still want an external processor
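
To illustrate that option, a hedged example (the model and parameter values are arbitrary) of skipping the generation-config flag and passing the processor directly:

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessorList,
    NoRepeatNGramLogitsProcessor,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Instead of setting `no_repeat_ngram_size` in the generation config, pass the
# processor explicitly and control its position in the list ourselves.
outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    logits_processor=LogitsProcessorList([NoRepeatNGramLogitsProcessor(ngram_size=2)]),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))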

Contributor Author

The reason transformers is so popular is that we enable complex use cases with a few lines of code

Member

Yes, I agree that this gives users more freedom to order processors the way they want by disabling the generation config. Though it is not very clear to me as a user what happens under the hood when I pass my own temperature processor or use the config. IMO we need a better docs page for advanced usage if we allow that much freedom and expect users to know what they are doing

Users almost never 100% know what they are doing, hence the open issues on GH 😆

@gante gante requested review from zucchini-nlp and removed request for ArthurZucker March 12, 2025 18:58
# The two outputs must match and their shape must be as expected
self._check_similar_generate_outputs(low_output, high_output)

@pytest.mark.generate
Contributor Author

(a few tests that failed in previous commits were incorrectly marked :) when parameterized is used, most decorators should be used after it)

Member

btw, how important is it to mark a skipped test as generate? I have no idea where we use those marks, except when trying to run only generation tests, in which case the skip doesn't help much

Contributor Author

not very important, it's more for long-term bookkeeping with the automated CI reports (how many tests we have per mark, how much time we spend on each mark, % of skips, ...)

@gante gante changed the title [Generation] When passing a custom generation_config, overwrite default values with the model's base generation_config [Generation, Gemma 3] When passing a custom generation_config, overwrite default values with the model's base generation_config Mar 13, 2025
Member

@zucchini-nlp zucchini-nlp left a comment

Interesting bug, given that the default cache implementation is None. So with a user-defined config, we're overriding everything back to None?

This PR can work as a short-term solution, but we're over-complicating things too much in generate imo. Setting generation config values to None is fine for most cases, but I see at least one edge case. Suppose a model saves a generation config with cache_implementation='static' and the user wants to override it by passing a config with cache_implementation explicitly set to None, because the user wants a dynamic cache. The future solution wouldn't work for this case.

The custom vs model config issue is also relevant to the pretrained config, and we kinda solve that issue by asking users to pass a kwargs dict, i.e. we're 100% sure which values the user wants to override. Just leaving it here as a random thought hehe, this would address the above edge case.

I totally agree we need a robust solution, but it seems like whatever we do might have edge cases and will be breaking 😢

@gante
Contributor Author

gante commented Mar 14, 2025

@zucchini-nlp I'm also not happy with the state of parameterization, but I disagree with a few of your points. Let me split my comment into parts, starting with why I believe this is the right change in the short term.

First, an overview of our current status:

  1. model Config and GenerationConfig are intertwined and we can't fully separate them without breaking BC (and it would be very breaking, old working code parametrizes generate through Config).
  2. We don't have per-model GenerationConfig, and we are piggybacking default parameterization through Config. More on this later, see long-term plans at the end.
  3. GenerationConfig has many flags, and it's not reasonable to expect the user to read the full docs.
  4. Likewise, it's not reasonable that a user knows the full compatibility between the model and generate.

Short-term issue

When a user creates a model Config, the initialization is, in essence, a diff to the model's default parameterization. In other words, if we do

config = LlamaConfig(hidden_size=512)

config will have all other fields set to model-specific defaults, because we have per-model Config classes. However, if we do

generation_config = GenerationConfig(do_sample=True)

the initialization is model-agnostic. In other words, initializing GenerationConfig this way loses all model-specific parameterization, ⚠️ even if the model config sets generate-specific args ⚠️. This shouldn't happen for two reasons:
a. The two configs have different assumptions: one follows the model's defaults, the other ignores them.
b. Because it lacks model information, a default GenerationConfig may cause generate to crash

This leaves us with two short-term solutions:
i. Add more model-level validation to inform the user which flags they need to change. This validation would need to be set in each modeling class.
ii. Shift away from the model-agnostic assumption of GenerationConfig where possible (i.e. when generation_config has visibility of the model).

This PR is (ii.) above.
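
Conceptually, (ii.) amounts to something like the following sketch; this is not the actual implementation, and comparing against the global defaults is a simplification for illustration:

from transformers import GenerationConfig

def apply_model_defaults(user_config: GenerationConfig, model_defaults: GenerationConfig) -> GenerationConfig:
    """Hypothetical sketch: fields the user left at their global default value
    inherit the model's own generation defaults instead."""
    global_defaults = GenerationConfig().to_dict()
    merged = user_config.to_dict()
    for key, model_value in model_defaults.to_dict().items():
        # The user never moved this flag away from the global default -> take the model default
        if key in global_defaults and merged.get(key) == global_defaults[key]:
            merged[key] = model_value
    return GenerationConfig(**merged)

# e.g. with a Gemma-3-like model default, a plain GenerationConfig() ends up hybrid:
model_defaults = GenerationConfig(cache_implementation="hybrid")
print(apply_model_defaults(GenerationConfig(), model_defaults).cache_implementation)  # "hybrid"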

Now, into your comment

Interesting bug, given that the default cache implementation is None. So with user defined config, we're overriding everything back to None?

This PR does not do that. If a user passes GenerationConfig(cache_implementation=None) or GenerationConfig() and the model has cache_implementation = foo by default, the resulting cache_implementation will be foo. The global default is replaced by the model default.

Suppose a model saves a generation config with cache_implementation='static' and the user wants to override it by passing a config with cache_implementation explicitly set to None, because the user wants a dynamic cache. The future solution wouldn't work for this case

It is simple to fix: we add the specific case of dynamic to the list of accepted values for cache_implementation, so users can override it. But there are definitely edge cases (e.g. the model owner saves num_beams=4 and the user sets num_beams=1). I consider these edge cases a smaller problem than the problem this PR fixes, which is a shift in the right direction (start any parameterization from model-specific defaults, like we do in Config).

In general, we haven't been rigorous in following the good pattern of defaulting to None, with None corresponding to "not set by the user", which would have made a PR like this free of conflicts.

Custom vs model config issue is also relevant to pretrained config, and we kinda solve that issue by asking users to pass kwargs dict, i.e. we're 100% sure which values user wants to override. Just leaving here as random though hehe, this would address the above edge case

That would move us in the opposite direction from the one I want 😉 With a well-defined config we can activate more advanced features like caching, hashing, etc., which are useful for advanced use cases (multi-device, compilation, ...)


Long-term issue

To wrap this very long comment: we're seeing more and more models with generation-specific parameterization, so we will definitely need model-specific GenerationConfig.

In a nutshell, the long-term plan is:

  1. Create model-specific GenerationConfig -- not so much to add new flags, but to hold the right defaults
  2. Create AutoGenerationConfig, whose from_pretrained() loads the right class
  3. Discourage the use of the generic GenerationConfig
  4. Now that we have model-specific GenerationConfig, deprecate setting any form of generate flags from Config and finally break BC with a long deprecation cycle
  5. Delete all shenanigans like this PR

But this will be a long piece of work, not something I can sort out in a few hours :) Meanwhile, I'd like us to move towards model-level generation defaults whenever possible.
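
Purely as a hypothetical illustration of steps 1 and 2 above (none of these classes exist in transformers today, and all names are made up; the sketch dispatches on an explicit model type rather than from_pretrained() to stay self-contained):

from transformers import GenerationConfig

class Gemma3GenerationConfig(GenerationConfig):
    """Hypothetical model-specific config: same flags, model-specific defaults."""
    def __init__(self, **kwargs):
        kwargs.setdefault("cache_implementation", "hybrid")
        super().__init__(**kwargs)

# Hypothetical registry, analogous to the existing Auto* classes
GENERATION_CONFIG_MAPPING = {"gemma3": Gemma3GenerationConfig}

class AutoGenerationConfig:
    """Hypothetical dispatcher: picks the config class from the model type."""
    @classmethod
    def for_model_type(cls, model_type: str, **kwargs) -> GenerationConfig:
        config_cls = GENERATION_CONFIG_MAPPING.get(model_type, GenerationConfig)
        return config_cls(**kwargs)

config = AutoGenerationConfig.for_model_type("gemma3", do_sample=True)
print(config.cache_implementation)  # "hybrid", without the user having to know about it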

@zucchini-nlp
Member

even if the model config sets generate-specific args

OMG, so many dependencies in the generation config. Yeah, this comment makes total sense and I agree with the solution for the short term. We have been adding a lot of stuff without checking for robustness on edge cases like gemma2. My main concern was about the long-term plan to make the generation config stable across model types, as we'll be getting more models with hardcoded generation values. At least on the VLM side, we use a static cache for image generation. Interestingly, using a dynamic cache degrades quality for some models, no idea why yet

Now that we have model-specific GenerationConfig, deprecate setting any form of generate flags from Config and finally break BC with a long deprecation cycle

THIS! ❤️ Love the plan, and looking forward to getting things sorted out. I feel like even just adding model-specific generation configs will solve many issues/hacky workarounds

This PR does not do that. If a user passes GenerationConfig(cache_implementation=None) or GenerationConfig() and the model has cache_implementation = foo by default, the resulting cache_implementation will be foo. The global default is replaced by the model default.

Yeah, I meant that before this PR we were setting all caches to None, which caused issues when generating. I was just trying to get at the root cause

Member

@zucchini-nlp zucchini-nlp left a comment

Approving, with a note to refactor in the long term. Thanks for digging into this issue, and for the detailed plan 💛

@gante
Contributor Author

gante commented Mar 15, 2025

@zucchini-nlp fyi, before merging, I've added:

  • a TODO with the long-term plan
  • the option to pass cache_implementation="dynamic"
  • a warning when the changes of this PR kick in (i.e. when the model default overwrites flags), explaining what to do if that is not desired (see the example below)
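
For example (illustrative; the token budget is arbitrary), a user who does not want the model's default hybrid cache can now opt out explicitly:

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "google/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = tokenizer("Hello!", return_tensors="pt")

# cache_implementation="dynamic" explicitly opts out of the model's default
# ("hybrid" here), instead of the flag being silently replaced by the model default.
generation_config = GenerationConfig(cache_implementation="dynamic", max_new_tokens=20)
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))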

[Screenshot: the new warning shown when model defaults overwrite generation_config values]

@gante gante merged commit fc8764c into huggingface:main Mar 15, 2025
23 checks passed
Collaborator

@ArthurZucker ArthurZucker left a comment

nice that you jumped quickly on this one! thanks

Comment on lines +561 to +562
def test_generation_beyond_sliding_window_with_generation_config(self):
"""
Collaborator

very very nice thanks!

@gante gante deleted the inherit_values_from_base_generation_config branch March 17, 2025 10:36
@yaswanth19 yaswanth19 mentioned this pull request Mar 18, 2025
@gante gante added the for patch Tag issues / labels that should be included in the next patch label Mar 19, 2025