
Conversation

@Cyrilvallez (Member) commented Jul 30, 2025

What does this PR do?

Big simplifications everywhere, but most notably:

  • all caches are initialized lazily -> no more device issues with device_map (which used to break the static dynamo addresses through device movement), no dimension issues with TP, and a much simpler preparation for generate (all properties are derived at the first update) -> simpler and more efficient (no device copies); see the sketch after this list
  • early_initialization provides a way to initialize everything before update is called -> this is needed for export, as we can't trace correctly if initialization is lazy
  • removed CacheProcessor -> QuantizedProcessor becomes QuantizedLayers instead, and offloading alone does not justify the Processor boilerplate -> much easier to have offloading as part of the Layer and Cache themselves (it is also much more robust now regarding devices)
  • Hybrid and HybridChunked now check chunk_attention_size correctly again (this check was lost before, which would break Llama4)
  • code much easier to follow and understand -> more maintainable
  • this is also a big step towards completely removing the cache_position, which would simplify the library a lot, and will come in a follow-up PR
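To make the first two bullets concrete, here is a minimal, purely illustrative sketch of the lazy/early initialization idea (the class and signatures below are hypothetical and only mirror the PR's naming, they are not the library's actual code):

import torch

class LazyKVLayer:
    """Toy cache layer: buffers are allocated on the first `update`, so shape, dtype and
    device are derived from the incoming states instead of being configured up front."""

    def __init__(self, max_cache_len):
        self.max_cache_len = max_cache_len
        self.keys = None
        self.values = None

    def lazy_initialization(self, key_states):
        # Derive everything (batch, heads, head_dim, dtype, device) from the first real states seen.
        batch, heads, _, head_dim = key_states.shape
        shape = (batch, heads, self.max_cache_len, head_dim)
        self.keys = torch.zeros(shape, dtype=key_states.dtype, device=key_states.device)
        self.values = torch.zeros(shape, dtype=key_states.dtype, device=key_states.device)

    def early_initialization(self, batch, heads, head_dim, dtype, device):
        # Same allocation, but done explicitly before any update (what export/tracing needs).
        self.lazy_initialization(torch.empty(batch, heads, 0, head_dim, dtype=dtype, device=device))

    def update(self, key_states, value_states, cache_position):
        if self.keys is None:
            self.lazy_initialization(key_states)
        self.keys[:, :, cache_position] = key_states
        self.values[:, :, cache_position] = value_states
        return self.keys, self.values

# e.g.:
layer = LazyKVLayer(max_cache_len=128)
k = v = torch.randn(1, 8, 10, 64)
keys, values = layer.update(k, v, cache_position=torch.arange(10))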

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Cyrilvallez Cyrilvallez changed the title Simplify/make more explicit the caching logic Rework the Cache logic to make it simpler and more general Jul 31, 2025
@Cyrilvallez Cyrilvallez changed the title Rework the Cache logic to make it simpler and more general Refactor the Cache logic to make it simpler and more general Aug 1, 2025
@Cyrilvallez (Member Author)

Slow tests are the same on this PR and on main for llama (arguably the most important model), mistral (the most tested sliding-window model), gemma2 (the most tested hybrid model), gemma3 (hybrid model), llama4 (hybrid chunked) - so all scenarios are green!

@gante (Member) commented Aug 5, 2025

@Cyrilvallez 🔥🔥 PR

nano requests:

  1. add early_initialization and lazy_initialization to the documented methods (in [[autodoc]] Cache)
  2. We can't throw informative exceptions at compilation time. But we can mitigate related problems with comments: In the docstring of lazy_initialization, let's mention that this function can never be called at compilation time, and that early_initialization should be called ahead of compilation instead 🙏

1 + 2 = if users open issues about manual compilation, we can link to the docstrings :D
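For request 1, the addition to the cache docs should look roughly like this (hypothetical excerpt; the exact method list may differ):

[[autodoc]] Cache
    - update
    - early_initialization
    - lazy_initialization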

@Cyrilvallez (Member Author) commented Aug 5, 2025

Hey @gante! Done with 1. 👌
Concerning 2., we actually CAN correctly compile even lazy_initialization, that's the beauty of it! Of course, this is not efficient and leads to recompiles, as prefill should not be compiled for performance reasons anyway, but it works!
You can try on the following snippet:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "google/gemma-2-2b"
torch_device = 0

EXPECTED_COMPLETIONS = [
    " the people, the food, the culture, the history, the music, the art, the architecture",
    ", green, yellow, orange, purple, pink, brown, black, white, gray, silver",
]

input_text = [
    "This is a nice place. " * 800 + "I really enjoy the scenery,",  # This is larger than 4096 tokens
    "A list of colors: red, blue",  # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")  # left padding for batched decoder-only generation
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)

model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="sdpa", torch_dtype=torch.bfloat16, device_map=torch_device
)
# Compile the full forward! So call to `lazy_initialization` will be compiled as well
model.compile(fullgraph=True)

# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
assert input_size > model.config.sliding_window

out = model.generate(**inputs, max_new_tokens=20, cache_implementation="hybrid")[:, input_size:]
output_text = tokenizer.batch_decode(out)

assert output_text == EXPECTED_COMPLETIONS

and you'll see it works! Only compilation with CUDA graphs, e.g. mode="reduce-overhead", will fail, but that mode is not guaranteed to give better perfs anyway!
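For readers who want to see this "recompiles, but still works" behavior in isolation (without loading a model), here is a tiny self-contained toy, not transformers code, just an illustration of a lazily-allocated buffer under torch.compile(fullgraph=True):

import torch

class LazyCacheLayer:
    # Toy stand-in for a lazily-initialized cache layer: the buffer is only allocated
    # the first time `update` is called, i.e. inside the compiled region.
    def __init__(self):
        self.buf = None

    def update(self, x):
        if self.buf is None:
            self.buf = torch.zeros_like(x)  # "lazy initialization"
        self.buf = self.buf + x
        return self.buf

@torch.compile(fullgraph=True)
def step(layer, x):
    return layer.update(x)

layer = LazyCacheLayer()
x = torch.ones(4)
print(step(layer, x))  # first call traces the init branch without error
print(step(layer, x))  # the `is None` guard now fails -> recompile, but the result stays correct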

@Cyrilvallez Cyrilvallez changed the title Refactor the Cache logic to make it simpler and more general [core] Refactor the Cache logic to make it simpler and more general Aug 5, 2025
@gante (Member) commented Aug 5, 2025

@Cyrilvallez fair! (I would still make a note of it though, for power users that don't rely on generate 😮 )

@ArthurZucker (Collaborator) left a comment

Nothing much to say: very very clean! 🤗

  • compiling the full call may not be the best: w.r.t. accelerate and hooks in general, I would rather just compile the call!

Comment on lines +60 to +64
def offload(self):
    """Offload this layer's data to CPU device."""
    if self.keys is not None:
        self.keys = self.keys.to("cpu", non_blocking=True)
        self.values = self.values.to("cpu", non_blocking=True)
Collaborator:

I think for CUDA we needed / had better perfs with a different stream

Comment on lines +736 to +737
offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
Collaborator:

nice comment

Comment on lines +785 to +786
with self.prefetch_stream if _is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream):
    self.layers[layer_idx].prefetch()
Collaborator:

ah ok there's the stream
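For context, the stream trick discussed in these two hunks, issuing the copies on a side CUDA stream so they overlap with compute on the default stream, looks roughly like this in generic PyTorch (a sketch, not the PR's actual code; `layer` here is any object holding `keys`/`values` tensors):

import torch

prefetch_stream = torch.cuda.Stream()  # assumes a CUDA device is available

def offload_layer(layer):
    # GPU -> CPU copy; pinned CPU memory is needed for the reverse copy to be truly asynchronous.
    layer.keys = layer.keys.to("cpu", non_blocking=True)
    layer.values = layer.values.to("cpu", non_blocking=True)

def prefetch_layer(layer, device="cuda"):
    # Issue the CPU -> GPU copies on a side stream so they can overlap with compute
    # running on the default stream.
    with torch.cuda.stream(prefetch_stream):
        layer.keys = layer.keys.to(device, non_blocking=True)
        layer.values = layer.values.to(device, non_blocking=True)

def wait_for_prefetch():
    # Before using the prefetched tensors, make the default stream wait for the copies.
    torch.cuda.current_stream().wait_stream(prefetch_stream)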

@ArthurZucker (Collaborator)

run-slow: bamba, dia, falcon_h1, gptj, granitemoehybrid, jamba, kyutai_speech_to_text, lfm2, musicgen, musicgen_melody, rag, zamba, zamba2

github-actions bot (Contributor) commented Aug 8, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/bamba', 'models/dia', 'models/falcon_h1', 'models/gptj', 'models/granitemoehybrid', 'models/jamba', 'models/kyutai_speech_to_text', 'models/lfm2', 'models/musicgen', 'models/musicgen_melody', 'models/rag', 'models/zamba', 'models/zamba2']
quantizations: [] ...

github-actions bot (Contributor) commented Aug 8, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: bamba, dia, falcon_h1, gptj, granitemoehybrid, jamba, kyutai_speech_to_text, lfm2, musicgen, musicgen_melody, rag, zamba, zamba2

@Cyrilvallez (Member Author)

Slow tests are similar to main, merging!

@Cyrilvallez Cyrilvallez merged commit dc11a3c into main Aug 8, 2025
23 of 25 checks passed
@Cyrilvallez Cyrilvallez deleted the explicit-cache branch August 8, 2025 12:47
BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Aug 12, 2025
Resolves current CI errors with prefix tuning.

Due to some recent changes in transformers (surfaced by
huggingface/transformers#39797), checking
hasattr(cache, "max_cache_len") results in an error:

>>> cache = DynamicCache()
>>> hasattr(cache, "max_cache_len")
Traceback (most recent call last):
  File "/home/name/work/forks/transformers/foo.py", line 9, in <module>
    hasattr(cache, "max_cache_len")
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/forks/transformers/src/transformers/cache_utils.py", line 916, in max_cache_len
    return max(values)
           ^^^^^^^^^^^
ValueError: max() iterable argument is empty

This has been reported and will be fixed in transformers. On the PEFT
side, it is safest to check the cache type and avoid accessing this
attribute in the first place, which is what this PR does.

Moreover, that PR also changed the argument order used to initialize
HybridCache (this will probably also be reverted in transformers), which is
also taken into account in this PR by only using keyword arguments.
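The guard described in this commit message can be illustrated with a short sketch (not PEFT's actual code, just the idea of checking the cache type instead of probing the attribute):

from transformers import DynamicCache

def get_max_cache_len(cache):
    # DynamicCache grows as needed and has no fixed capacity, so don't touch
    # `max_cache_len` on it at all; for other caches, read the attribute if present.
    if isinstance(cache, DynamicCache):
        return None
    return getattr(cache, "max_cache_len", None)

print(get_max_cache_len(DynamicCache()))  # None, without hitting the ValueError shown above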
@Cyrilvallez Cyrilvallez mentioned this pull request Aug 13, 2025
This was referenced Aug 15, 2025
BenjaminBossan added a commit to huggingface/peft that referenced this pull request Aug 21, 2025
Resolves current CI errors with prefix tuning.

Due to some recent changes in transformers (surfaced by
huggingface/transformers#39797), checking
hasattr(cache, "max_cache_len") results in an error. This PR fixes it.

Moreover, that PR also changed the argument order used to initialize
HybridCache (this will probably also be reverted in transformers), which is
also taken into account in this PR by only using keyword arguments.

Finally, HybridCache will be deprecated and later removed, so move the
import inside a version guard.
csqaiub added a commit to csqaiub/peft that referenced this pull request Sep 28, 2025