
Commit c338fd4

[cache refactor] Move all the caching logic to a per-layer approach (#39106)
* Squash for refactor: Replace monolithic cache classes with modular LayeredCache (#38077)
  - Introduces CacheLayer and Cache base classes
  - Ports Static, Dynamic, Offloaded, Quantized, Hybrid, etc. to use layers
  - Implements method/attr dispatch across layers to reduce boilerplate
  - Adds CacheProcessor hooks for offloading, quantization, etc.
  - Updates and passes tests
* fix quantized, add tests
* remove CacheProcessorList
* raushan review, arthur review
* joao review: minor things
* remove cache configs, make CacheLayer a mixin (joaos review)
* back to storage inside Cache()
* remove cachebase for decorator
* no more __getattr__
* fix tests
* joaos review except docs
* fix ast deprecations for python 3.14: replace node.n by node.value and use `ast.Constant`. More verbose exceptions in `fix_docstring` on docstring formatting issues.
* Revert "back to storage inside Cache()". This reverts commit 27916bc.
* cyril review
* simplify cache export
* fix lfm2 cache
* HybridChunked to layer
* BC proxy object for cache.key_cache[i]=...
* reorder classes
* bfff come on LFM2
* better tests for hybrid and hybridChunked
* complete coverage for hybrid chunked caches (prefill chunking)
* reimplementing HybridChunked
* cyril review
* fix ci
* docs for cache refactor
* docs
* oopsie
* oopsie
* fix after merge
* cyril review
* arthur review
* opsie
* fix lfm2
* opsie2
1 parent b16688e commit c338fd4

64 files changed: +2505 −2167 lines


docs/source/en/cache_explanation.md

Lines changed: 6 additions & 10 deletions
@@ -82,22 +82,18 @@ When you use Transformers' [`Cache`] class, the self-attention module performs s
 
 ## Cache storage implementation
 
-The actual storage of key-value pairs varies between cache implementations. As an example, consider the [`DynamicCache`].
+Caches are structured as a list of layers, where each layer contains a key and value cache. The key and value caches are tensors with the shape `[batch_size, num_heads, seq_len, head_dim]`.
 
+Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `SlidingWindowLayer`), which mostly changes how sequence length is handled and how the cache is updated.
 
-In [`DynamicCache`], the key-value pairs are stored as two lists of tensors. Each tensor in the lists have the shape `[batch_size, num_heads, seq_len, head_dim]`.
-- `key_cache`: A list of tensors, one for each layer.
-- `value_cache`: A list of tensors, one for each layer.
+The simplest is a `DynamicLayer` that grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token:
 
-When new tokens are processed:
-
-1. For each layer, the new key and value states are concatenated with the existing cache.
 ```py
-self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
+cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)
 ```
 
-2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token.
+Other layer types like `StaticLayer` and `SlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `SlidingWindowLayer`, existing tokens are shifted out of the cache when a new token is added.
 
 The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token.
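The per-layer access pattern shown in this diff can be exercised directly. The snippet below is a minimal sketch, not part of the commit: it assumes the post-refactor API where a `DynamicCache` exposes `cache.layers` (as the new doc text above states) and the long-standing `Cache.update(key_states, value_states, layer_idx)` entry point.

```py
import torch
from transformers import DynamicCache

cache = DynamicCache()

# One decoding step for layer 0; shapes follow [batch_size, num_heads, seq_len, head_dim].
key_states = torch.randn(1, 8, 1, 64)
value_states = torch.randn(1, 8, 1, 64)
cache.update(key_states, value_states, layer_idx=0)

# Storage is now reachable on the per-layer object rather than on list attributes.
print(type(cache.layers[0]).__name__)  # DynamicLayer
print(cache.layers[0].keys.shape)      # torch.Size([1, 8, 1, 64]); grows along dim=-2
```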

docs/source/en/internal/generation_utils.md

Lines changed: 52 additions & 25 deletions
@@ -356,66 +356,93 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 ## Caches
 
-[[autodoc]] Cache
+[[autodoc]] CacheLayerMixin
     - update
+    - get_seq_length
+    - get_mask_sizes
+    - get_max_cache_shape
+    - reset
+    - reorder_cache
 
-[[autodoc]] CacheConfig
-    - update
+[[autodoc]] DynamicLayer
+    - update
+    - crop
+    - batch_repeat_interleave
+    - batch_select_indices
 
-[[autodoc]] QuantizedCacheConfig
-    - validate
+[[autodoc]] StaticLayer
+    - update
 
-[[autodoc]] DynamicCache
+[[autodoc]] SlidingWindowLayer
+    - update
+
+[[autodoc]] CacheProcessor
+    - pre_update
+    - post_update
+
+[[autodoc]] OffloadedCacheProcessor
+    - pre_update
+
+[[autodoc]] QuantizedCacheProcessor
+    - post_update
+
+[[autodoc]] QuantoQuantizedCacheProcessor
+    - post_update
+
+[[autodoc]] HQQQuantizedCacheProcessor
+    - post_update
+
+[[autodoc]] Cache
     - update
     - get_seq_length
+    - get_mask_sizes
+    - get_max_cache_shape
+    - reset
     - reorder_cache
+    - crop
+    - batch_repeat_interleave
+    - batch_select_indices
+
+[[autodoc]] DynamicCache
     - to_legacy_cache
     - from_legacy_cache
 
 [[autodoc]] QuantizedCache
-    - update
-    - get_seq_length
 
 [[autodoc]] QuantoQuantizedCache
 
+[[autodoc]] QuantoQuantizedCacheProcessor
+
 [[autodoc]] HQQQuantizedCache
 
+[[autodoc]] HQQQuantizedCacheProcessor
+
 [[autodoc]] OffloadedCache
-    - update
-    - prefetch_layer
-    - evict_previous_layer
 
 [[autodoc]] StaticCache
-    - update
-    - get_seq_length
-    - reset
 
 [[autodoc]] OffloadedStaticCache
-    - update
-    - get_seq_length
-    - reset
 
 [[autodoc]] HybridCache
-    - update
-    - get_seq_length
-    - reset
+
+[[autodoc]] HybridChunkedCache
 
 [[autodoc]] SlidingWindowCache
-    - update
-    - reset
 
 [[autodoc]] EncoderDecoderCache
-    - get_seq_length
     - to_legacy_cache
     - from_legacy_cache
-    - reset
-    - reorder_cache
 
 [[autodoc]] MambaCache
     - update_conv_state
     - update_ssm_state
     - reset
 
+[[autodoc]] CacheConfig
+
+[[autodoc]] QuantizedCacheConfig
+
+
 
 ## Watermark Utils
 
 [[autodoc]] WatermarkingConfig
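The reshuffled autodoc list mirrors the split between per-layer objects (`CacheLayerMixin` and its subclasses) and the `Cache` container that forwards to them. Below is a simplified sketch of that dispatch idea, added for illustration; it is not the library's actual implementation, only a stand-in showing why `get_seq_length` now appears on both the layer and the cache.

```py
# Illustrative sketch of layer dispatch; not the real transformers code.
class CacheLayer:
    """Holds the key/value tensors for a single decoder layer."""

    def __init__(self):
        self.keys = None
        self.values = None

    def get_seq_length(self):
        # Empty until the first update; afterwards seq_len lives on dim -2.
        return 0 if self.keys is None else self.keys.shape[-2]


class LayeredCache:
    """Container that forwards per-layer queries to the addressed layer."""

    def __init__(self, layers):
        self.layers = layers

    def get_seq_length(self, layer_idx=0):
        # Cache-level calls delegate to the layer instead of duplicating logic.
        return self.layers[layer_idx].get_seq_length()
```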

docs/source/en/kv_cache.md

Lines changed: 3 additions & 4 deletions
@@ -134,7 +134,7 @@ The [`QuantizedCache`] reduces memory requirements by quantizing the KV values t
 > [!WARNING]
 > Quantizing the cache can harm latency if the context length is short and there is enough GPU memory available for generation without enabling cache quantization. Try to find a balance between memory efficiency and latency.
 
-Enable [`QuantizedCache`] by configuring `cache_implementation="quantized"` in [`GenerationConfig`], and indicate the quantization backend in [`QuantizedCacheConfig`]. Any additional quantization related parameters should also be passed either as a dict or an instance of [`QuantizedCacheConfig`]. You should use the default values for these additional parameters unless you're running out-of-memory. In that case, consider decreasing the residual length.
+Enable [`QuantizedCache`] by configuring `cache_implementation="quantized"` in [`GenerationConfig`], and pass the quantization backend along with any additional quantization-related parameters as a dict. You should use the default values for these additional parameters unless you're running out of memory. In that case, consider decreasing the residual length.
 
 <hfoptions id="quantized-cache">
 <hfoption id="HQQQuantizedCache">
@@ -143,7 +143,7 @@ For [`HQQQuantizedCache`], we recommend setting the `axis-key` and `axis-value`
 
 ```py
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache
 
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
 model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
@@ -161,7 +161,7 @@ For [`QuantoQuantizedCache`], we recommend setting the `axis-key` and `axis-valu
 
 ```py
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache
 
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
 model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
@@ -275,7 +275,6 @@ from transformers.cache_utils import (
     StaticCache,
     SlidingWindowCache,
     QuantoQuantizedCache,
-    QuantizedCacheConfig,
 )
 
 model_id = "meta-llama/Llama-2-7b-chat-hf"
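Since `QuantizedCacheConfig` drops out of the import lists above, the backend and its parameters now travel as a plain dict. A hedged sketch of the updated call site follows; the `cache_config` keys (`backend`, `nbits`) follow the pre-existing quantized-cache examples in kv_cache.md and may differ for other backends.

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto"
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
# Backend and extra quantization parameters are passed as a dict now that
# QuantizedCacheConfig is no longer required.
out = model.generate(
    **inputs,
    max_new_tokens=20,
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 4},
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```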

docs/source/ko/internal/generation_utils.md

Lines changed: 0 additions & 6 deletions
@@ -345,12 +345,6 @@ generation_output[:2]
 [[autodoc]] Cache
     - update
 
-[[autodoc]] CacheConfig
-    - update
-
-[[autodoc]] QuantizedCacheConfig
-    - validate
-
 [[autodoc]] DynamicCache
     - update
     - get_seq_length

src/transformers/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -365,15 +365,28 @@
 ]
 _import_structure["activations"] = []
 _import_structure["cache_utils"] = [
+    "CacheLayerMixin",
+    "DynamicLayer",
+    "StaticLayer",
+    "SlidingWindowLayer",
+    "ChunkedSlidingLayer",
+    "CacheProcessor",
+    "OffloadedCacheProcessor",
+    "QuantizedCacheProcessor",
+    "QuantoQuantizedCacheProcessor",
+    "HQQQuantizedCacheProcessor",
     "Cache",
     "CacheConfig",
     "DynamicCache",
     "EncoderDecoderCache",
     "HQQQuantizedCache",
+    "HQQQuantizedCacheProcessor",
     "HybridCache",
+    "HybridChunkedCache",
     "OffloadedCache",
     "OffloadedStaticCache",
     "QuantizedCache",
+    "QuantoQuantizedCacheProcessor",
    "QuantizedCacheConfig",
    "QuantoQuantizedCache",
    "SinkCache",
