[cache refactor] Move all the caching logic to a per-layer approach #39106
Conversation
Squash for refactor: Replace monolithic cache classes with modular LayeredCache (huggingface#38077)
- Introduces CacheLayer and Cache base classes
- Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers
- Implements method/attr dispatch across layers to reduce boilerplate
- Adds CacheProcessor hooks for offloading, quantization, etc.
- Updates and passes tests
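To make the structure concrete, here is a minimal sketch of the per-layer idea from the commit message (illustrative only; the real `CacheLayer`/`Cache` classes in `cache_utils.py` carry many more methods, and the cross-layer dispatch is implemented generically to cut boilerplate):

```python
import torch


class CacheLayer:
    """Holds the key/value states of a single decoder layer."""

    def __init__(self):
        self.keys = None
        self.values = None

    def update(self, key_states, value_states):
        # Append the new states along the sequence dimension.
        if self.keys is None:
            self.keys, self.values = key_states, value_states
        else:
            self.keys = torch.cat([self.keys, key_states], dim=-2)
            self.values = torch.cat([self.values, value_states], dim=-2)
        return self.keys, self.values

    def reset(self):
        self.keys, self.values = None, None


class Cache:
    """Cache-wide operations simply dispatch to every layer."""

    def __init__(self, num_layers):
        self.layers = [CacheLayer() for _ in range(num_layers)]

    def update(self, key_states, value_states, layer_idx):
        return self.layers[layer_idx].update(key_states, value_states)

    def reset(self):
        for layer in self.layers:
            layer.reset()
```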
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
force-pushed from a2fe24c to 04d7a0b
@zucchini-nlp I noticed I was breaking QuantizedCaches. It was hard to spot because no test covered it.
force-pushed from d97a02d to 26c28af
Super super cool, glad to see the cache being refactored. Left a few comments in the quant cache, I think it is not the same as in current main.
Can we also check that generation with main vs. the PR branch is identical when low-bit quantizing and generating much longer than the residual length? Not in a test case, but as a sanity check. It's been a while since I looked at this cache.
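A sanity check along those lines could look like the snippet below (a sketch, not part of the PR: the model id and token counts are arbitrary, and the `cache_config` dict follows the documented quantized-cache generation API; `residual_length` defaults to 128, so 512 new tokens goes well past it). Run it once on main and once on the PR branch, then diff the outputs:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # any small causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("The theory of relativity states that", return_tensors="pt").to(model.device)

# Greedy decoding well past the residual length, with a low-bit quantized cache
# (backend "quanto" requires the optimum-quanto package).
out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=512,
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 2},
)
print(tokenizer.decode(out[0]))  # diff this between main and the PR branch
```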
src/transformers/cache_utils.py
Outdated
"config and it's not set to None." | ||
) | ||
# Adjust max_cache_len for sliding window layers (they can't be larger than sliding window) | ||
max_cache_len = max_cache_len or config.max_position_embeddings |
`config.max_position_embeddings` doesn't always reflect the actual max length a model can handle, and I think sometimes it's filled with nonsense values. Maybe we should have a default `max_cache_length` instead, wdyt?
hmmm, good catch. This comes from transformers/src/transformers/cache_utils.py, line 1615 at 1283877:
self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings
This is only relevant for StaticCaches, which are initialized purposely for torch.compile, so it is probably very uncommon for this param to be uninitialized. The problem with setting a super-high default is that we will allocate an equally big tensor for the static cache. Are the nonsense values usually too big? Or what do you mean?
Yeah, I meant the values are super large. For example, examining a random Llama config shows 10M tokens. Though technically with RoPE it can go to an arbitrarily large sequence length, maybe we should have a default of 1024/2048 tokens in static caches?
https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/config.json#L28
I realize it's not really an issue of this PR, but since we started the major clean-up, it's time to bring up the discussion 😄
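For scale, a quick back-of-envelope computation (illustrative Llama-like dimensions; only the 10M figure comes from the thread) shows why allocating a static cache at `config.max_position_embeddings` can be prohibitive:

```python
# Static cache size = (keys + values) x layers x KV heads x head_dim x max_len x bytes/element.
# Assumed dims: 32 layers, 8 KV heads, head_dim 128, fp16 (2 bytes per element).
def static_cache_bytes(max_len, num_layers=32, num_kv_heads=8, head_dim=128, bytes_per_el=2):
    return 2 * num_layers * num_kv_heads * head_dim * max_len * bytes_per_el

print(f"{static_cache_bytes(10_000_000) / 1e9:,.0f} GB")  # ~1,311 GB at 10M tokens
print(f"{static_cache_bytes(2048) / 1e9:.3f} GB")         # ~0.268 GB at 2048 tokens
```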
maybe set it to something like 2048 with a warning? cc @gante
Thanks! A good improvement 🤗
src/transformers/cache_utils.py
Outdated
"""

def __init__(
good for BC, but maybe we want to move forward with just specifying `cache_processor="offloaded"`? wdyt @Cyrilvallez
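In other words, the suggestion is to prefer the second spelling below (`cache_processor="offloaded"` is the hypothetical kwarg from this comment, not a settled API):

```python
from transformers import DynamicCache, OffloadedCache

# Kept for backward compatibility:
cache = OffloadedCache()

# Direction suggested in this comment (hypothetical kwarg, API not settled here):
cache = DynamicCache(cache_processor="offloaded")
```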
all done @ArthurZucker! Thanks for the detailed review, learning a lot from you.
Good first steps! There's still some work to be done; I'm separating my feedback below by sections.
Dependencies: Let's merge #38086 before we continue this review, this diff is very polluted 👀
Complexity: This PR introduces a lot of complexity, which is the root of many long-term maintenance problems. Let's try to simplify as much as we can 🙏 (e.g. kill `KVProxy`).
Cache configs: From a model implementation perspective, all it needs to define is the type of cache layers it expects (a list) and, in some edge cases, a few additional cache kwargs (a dict). I think we would do ourselves a favor if we don't use `CacheConfig` classes: fewer classes to maintain, simpler cache-related `model.config` parameters, fewer tests. See complexity comment above. [Let's move this discussion to Slack]
Documentation: In the diff, we remove the docs for all methods of the caches. This means users won't be able to easily access information about available methods. Since we expect users to import and use `Cache` (or subclasses of it), we should make sure `Cache` has all its methods in the docs. We could also write something like "See the documentation of Cache for shared methods" on all other caches' docstrings, to avoid repeating the same information in our docs.
Docstrings: Some classes are missing the `__init__` arguments in their docstring (e.g. `CacheLayer`), and some non-dunder methods don't have docstrings with args/return (e.g. `from_kv`). Let's make sure we dot our i's and cross our t's after we settle on the interfaces :)
@manueldeprada some comments above are regarding code blocks that no longer exist, feel free to just resolve them if they no longer apply 🤗
force-pushed from d684339 to 4c03e0f
force-pushed from 5dc5fb4 to a6b7562
force-pushed from a6b7562 to 27916bc
Thanks, looks nice now! Small nits
the file is super long, we might split `cache_utils` into `configuration_utils|utils|layers`? @gante
lmk! 500 out of 2500 LOC are deprecated, including the configuration classes.
We could split the remaining 2000 LOC into caches (1000 LOC), layers (400 LOC), and processors (600 LOC).
Last nit
[For maintainers] Suggested jobs to run (before merge): run-slow: bamba, bart, bigbird_pegasus, biogpt, blenderbot, blenderbot_small, dia, falcon_h1, gemma3n, gptj, granitemoehybrid, informer, jamba, lfm2, longt5, m2m_100
Alright, merging! 🤗🚀
Nice work!
Our vLLM CI tests for hybrid models that compare against transformers are broken since we upgraded to the latest transformers, and I've traced the issue to this PR. Are there any changes in this PR that could cause the generated tokens to be different?
@Cyrilvallez I can confirm that using current main, the test I'm using to debug is passing. I will check the rest now, but it looks good.
Can confirm all correctness tests pass using latest main. There is one test still failing, but due to a different issue that I will report separately.
Correction: it looks like all tests involving mamba2 are passing, but tests involving mamba1 still have mismatching output.
Can you provide the gen code that produces the mismatch?
@tdoublep which mamba1 tests? I ran all in models/language/generation/test_hybrid.py::test_models and they pass with Cyril's PR.
@manueldeprada Apologies, I should have followed up here. I can't reproduce that mamba1 failure anymore; it must have been something else related to my dev env. I can confirm that the latest transformers main looks good from the vLLM hybrid test perspective.
[cache refactor] Move all the caching logic to a per-layer approach (huggingface#39106)
* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (huggingface#38077)
  - Introduces CacheLayer and Cache base classes
  - Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers
  - Implements method/attr dispatch across layers to reduce boilerplate
  - Adds CacheProcessor hooks for offloading, quantization, etc.
  - Updates and passes tests
* fix quantized, add tests
* remove CacheProcessorList
* raushan review, arthur review
* joao review: minor things
* remove cache configs, make CacheLayer a mixin (joaos review)
* back to storage inside Cache()
* remove cachebase for decorator
* no more __getattr__
* fix tests
* joaos review except docs
* fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant`; more verbose exceptions in `fix_docstring` on docstring formatting issues
* Revert "back to storage inside Cache()" (reverts commit 27916bc)
* cyril review
* simplify cache export
* fix lfm2 cache
* HybridChunked to layer
* BC proxy object for cache.key_cache[i]=...
* reorder classes
* bfff come on LFM2
* better tests for hybrid and hybridChunked
* complete coverage for hybrid chunked caches (prefill chunking)
* reimplementing HybridChunked
* cyril review
* fix ci
* docs for cache refactor
* docs
* oopsie
* oopsie
* fix after merge
* cyril review
* arthur review
* opsie
* fix lfm2
* opsie2
This PR completes Part 1 of the cache refactor tracked in #38077.
Summary:
- `Cache` is structured as a list of layers.
- Cache-wide methods (`reset()`, `crop()`, `batch_split()`) now auto-propagate to layers.

Implementation details:
- `cache.key_cache` and `cache.value_cache` are exposed through `KVProxy` to efficiently return a layer-indexed list of keys or values and keep BC (see the sketch after this list).
- Offloading and quantization are handled by a `CacheProcessor`. In the future, it can be expanded to a `CacheProcessorList` if needed.
- `MambaCache` moved to `modeling_mamba.py` (#38086).
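The BC proxy can be pictured as a list-like view over the per-layer storage (a sketch of the idea, not the PR's exact code):

```python
class KVProxy:
    """List-like view so legacy code such as `cache.key_cache[idx]` (reads and
    in-place writes) keeps working while tensors live inside each layer."""

    def __init__(self, layers, attr):
        self.layers = layers  # list of per-layer cache objects
        self.attr = attr      # "keys" or "values"

    def __getitem__(self, layer_idx):
        return getattr(self.layers[layer_idx], self.attr)

    def __setitem__(self, layer_idx, tensor):
        setattr(self.layers[layer_idx], self.attr, tensor)

    def __len__(self):
        return len(self.layers)
```

A cache can then expose `key_cache` as `KVProxy(self.layers, "keys")`, so both reads and writes like `cache.key_cache[i] = new_keys` land on the right layer without duplicating storage.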