Merged
Changes from all commits
44 commits
1c3cbcc
Squash for refactor: Replace monolithic cache classes with modular La…
manueldeprada Jun 29, 2025
04d7a0b
fix quantized, add tests
manueldeprada Jun 30, 2025
26c28af
remove CacheProcessorList
manueldeprada Jun 30, 2025
16a6624
raushan review, arthur review
manueldeprada Jul 2, 2025
aec9ccd
joao review: minor things
manueldeprada Jul 2, 2025
e80c68a
remove cache configs, make CacheLayer a mixin (joaos review)
manueldeprada Jul 4, 2025
27916bc
back to storage inside Cache()
manueldeprada Jul 9, 2025
fd83e14
remove cachebase for decorator
manueldeprada Jul 10, 2025
c200447
no more __getattr__
manueldeprada Jul 10, 2025
f327570
Merge branch 'main' of github.com:huggingface/transformers into cache…
manueldeprada Jul 10, 2025
5b1b1f1
fix tests
manueldeprada Jul 10, 2025
58dbcfe
joaos review except docs
manueldeprada Jul 11, 2025
0c6d2ff
fix ast deprecations for python 3.14: replace node.n by node.value an…
manueldeprada Jun 19, 2025
6a77408
Revert "back to storage inside Cache()"
manueldeprada Jul 14, 2025
13ec4a4
cyril review
manueldeprada Jul 14, 2025
7029a90
simplify cache export
manueldeprada Jul 14, 2025
5fa9901
Merge branch 'main' of github.com:huggingface/transformers into cache…
manueldeprada Jul 14, 2025
dd7458b
fix lfm2 cache
manueldeprada Jul 14, 2025
dc08253
HybridChunked to layer
manueldeprada Jul 14, 2025
a952124
BC proxy object for cache.key_cache[i]=...
manueldeprada Jul 14, 2025
dbbc4d5
reorder classes
manueldeprada Jul 14, 2025
4bb48fc
bfff come on LFM2
manueldeprada Jul 14, 2025
00b1f96
better tests for hybrid and hybridChunked
manueldeprada Jul 15, 2025
def346e
Merge branch 'main' of github.com:huggingface/transformers into cache…
manueldeprada Jul 15, 2025
38e8603
complete coverage for hybrid chunked caches (prefill chunking)
manueldeprada Jul 15, 2025
34a3022
reimplementing HybridChunked
manueldeprada Jul 15, 2025
4222653
cyril review
manueldeprada Jul 16, 2025
063459c
Merge branch 'main' of github.com:huggingface/transformers into cache…
manueldeprada Jul 17, 2025
1acc648
fix ci
manueldeprada Jul 17, 2025
ca39ffe
docs for cache refactor
manueldeprada Jul 18, 2025
731d0b7
docs
manueldeprada Jul 11, 2025
574b820
Merge branch 'main' of github.com:huggingface/transformers into cache…
manueldeprada Jul 18, 2025
a479470
oopsie
manueldeprada Jul 18, 2025
9c0bdcc
oopsie
manueldeprada Jul 18, 2025
04091ed
Merge branch 'main' into cache-refactor-1
manueldeprada Jul 21, 2025
8642bb6
Merge branch 'main' of github.com:huggingface/transformers into cache…
manueldeprada Jul 21, 2025
83968bd
Merge branch 'cache-refactor-1' of https://github.com/manueldeprada/t…
manueldeprada Jul 21, 2025
0c4700d
fix after merge
manueldeprada Jul 21, 2025
b3a35e9
cyril review
manueldeprada Jul 21, 2025
e4878ad
arthur review
manueldeprada Jul 22, 2025
38fb99d
Merge branch 'main' of github.com:huggingface/transformers into cache…
manueldeprada Jul 22, 2025
8df1595
opsie
manueldeprada Jul 22, 2025
ad65a02
fix lfm2
manueldeprada Jul 22, 2025
d9fbb04
opsie2
manueldeprada Jul 22, 2025
16 changes: 6 additions & 10 deletions docs/source/en/cache_explanation.md
@@ -82,22 +82,18 @@ When you use Transformers' [`Cache`] class, the self-attention module performs s

## Cache storage implementation

The actual storage of key-value pairs varies between cache implementations. As an example, consider the [`DynamicCache`].
Caches are structured as a list of layers, where each layer contains a key and value cache. The key and value caches are tensors with the shape `[batch_size, num_heads, seq_len, head_dim]`.

Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `SlidingWindowLayer`), which mostly changes how sequence length is handled and how the cache is updated.

In [`DynamicCache`], the key-value pairs are stored as two lists of tensors. Each tensor in the lists have the shape `[batch_size, num_heads, seq_len, head_dim]`.
- `key_cache`: A list of tensors, one for each layer.
- `value_cache`: A list of tensors, one for each layer.
The simplest is a `DynamicLayer` that grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token:

When new tokens are processed:

1. For each layer, the new key and value states are concatenated with the existing cache.
```py
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)
```

2. The cache grows dynamically as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token.
Other layer types like `StaticLayer` and `SlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `SlidingWindowLayer`, existing tokens are shifted out of the cache when a new token is added.
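
The snippet below is a minimal sketch of how this layer structure can be inspected after a single forward pass (the checkpoint is only a stand-in; any causal LM that supports the `Cache` API behaves the same way):

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("The key-value cache stores", return_tensors="pt").to(model.device)
cache = DynamicCache()

with torch.no_grad():
    model(**inputs, past_key_values=cache, use_cache=True)

# One entry per decoder layer; each layer holds keys and values of shape
# [batch_size, num_heads, seq_len, head_dim]
print(len(cache.layers))
print(cache.layers[0].keys.shape)
```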

The example below demonstrates how to create a generation loop with [`DynamicCache`]. As discussed, the attention mask is a concatenation of past and current token values and `1` is added to the cache position for the next token.

77 changes: 52 additions & 25 deletions docs/source/en/internal/generation_utils.md
@@ -356,66 +356,93 @@ A [`Constraint`] can be used to force the generation to include specific tokens

## Caches

[[autodoc]] Cache
[[autodoc]] CacheLayerMixin
- update
- get_seq_length
- get_mask_sizes
- get_max_cache_shape
- reset
- reorder_cache

[[autodoc]] CacheConfig
- update
[[autodoc]] DynamicLayer
- update
- crop
- batch_repeat_interleave
- batch_select_indices

[[autodoc]] QuantizedCacheConfig
- validate
[[autodoc]] StaticLayer
- update

[[autodoc]] DynamicCache
[[autodoc]] SlidingWindowLayer
- update

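For orientation, a layer can also be exercised on its own. This is a sketch under the assumption that a layer's `update` mirrors `Cache.update` (key and value states in, cached key and value tensors out):

```py
import torch
from transformers import DynamicLayer

layer = DynamicLayer()
keys = torch.randn(1, 8, 4, 64)    # [batch_size, num_heads, seq_len, head_dim]
values = torch.randn(1, 8, 4, 64)

# The first update seeds the cache; later updates concatenate along seq_len
cached_keys, cached_values = layer.update(keys, values)
cached_keys, cached_values = layer.update(torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64))
print(cached_keys.shape)  # expected: torch.Size([1, 8, 5, 64])
```
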
[[autodoc]] CacheProcessor
- pre_update
- post_update

[[autodoc]] OffloadedCacheProcessor
- pre_update

[[autodoc]] QuantizedCacheProcessor
- post_update

[[autodoc]] QuantoQuantizedCacheProcessor
- post_update

[[autodoc]] HQQQuantizedCacheProcessor
- post_update

[[autodoc]] Cache
- update
[Review comment from a maintainer: "Missing: Cache (and its methods)"]

- get_seq_length
- get_mask_sizes
- get_max_cache_shape
- reset
- reorder_cache
- crop
- batch_repeat_interleave
- batch_select_indices

[[autodoc]] DynamicCache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] QuantizedCache
- update
- get_seq_length

[[autodoc]] QuantoQuantizedCache

[[autodoc]] QuantoQuantizedCacheProcessor

[[autodoc]] HQQQuantizedCache

[[autodoc]] HQQQuantizedCacheProcessor

[[autodoc]] OffloadedCache
- update
- prefetch_layer
- evict_previous_layer

[[autodoc]] StaticCache
- update
- get_seq_length
- reset

[[autodoc]] OffloadedStaticCache
- update
- get_seq_length
- reset

[[autodoc]] HybridCache
- update
- get_seq_length
- reset

[[autodoc]] HybridChunkedCache

[[autodoc]] SlidingWindowCache
- update
- reset

[[autodoc]] EncoderDecoderCache
- get_seq_length
- to_legacy_cache
- from_legacy_cache
- reset
- reorder_cache

[[autodoc]] MambaCache
- update_conv_state
- update_ssm_state
- reset

[[autodoc]] CacheConfig

[[autodoc]] QuantizedCacheConfig


## Watermark Utils

[[autodoc]] WatermarkingConfig
7 changes: 3 additions & 4 deletions docs/source/en/kv_cache.md
@@ -134,7 +134,7 @@ The [`QuantizedCache`] reduces memory requirements by quantizing the KV values t
> [!WARNING]
> Quantizing the cache can harm latency if the context length is short and there is enough GPU memory available for generation without enabling cache quantization. Try to find a balance between memory efficiency and latency.

Enable [`QuantizedCache`] by configuring `cache_implementation="quantized"` in [`GenerationConfig`], and indicate the quantization backend in [`QuantizedCacheConfig`]. Any additional quantization related parameters should also be passed either as a dict or an instance of [`QuantizedCacheConfig`]. You should use the default values for these additional parameters unless you're running out-of-memory. In that case, consider decreasing the residual length.
Enable [`QuantizedCache`] by configuring `cache_implementation="quantized"` in [`GenerationConfig`]. The quantization backend, as well as any additional quantization-related parameters, should be passed as a dict. You should use the default values for these additional parameters unless you're running out-of-memory. In that case, consider decreasing the residual length.
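
A minimal sketch of the dict form follows, assuming a loaded `model` and tokenized `inputs` as in the backend-specific examples below; the `backend` and `nbits` keys are the common options, not an exhaustive list:

```py
out = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=20,
    cache_implementation="quantized",
    cache_config={"backend": "quanto", "nbits": 4},
)
```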

<hfoptions id="quantized-cache">
<hfoption id="HQQQuantizedCache">
@@ -143,7 +143,7 @@ For [`HQQQuantizedCache`], we recommend setting the `axis-key` and `axis-value`

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
Expand All @@ -161,7 +161,7 @@ For [`QuantoQuantizedCache`], we recommend setting the `axis-key` and `axis-valu

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
@@ -275,7 +275,6 @@ from transformers.cache_utils import (
StaticCache,
SlidingWindowCache,
QuantoQuantizedCache,
QuantizedCacheConfig,
)

model_id = "meta-llama/Llama-2-7b-chat-hf"
6 changes: 0 additions & 6 deletions docs/source/ko/internal/generation_utils.md
@@ -345,12 +345,6 @@ generation_output[:2]
[[autodoc]] Cache
- update

[[autodoc]] CacheConfig
- update

[[autodoc]] QuantizedCacheConfig
- validate

[[autodoc]] DynamicCache
- update
- get_seq_length
13 changes: 13 additions & 0 deletions src/transformers/__init__.py
@@ -365,15 +365,28 @@
]
_import_structure["activations"] = []
_import_structure["cache_utils"] = [
"CacheLayerMixin",
"DynamicLayer",
"StaticLayer",
"SlidingWindowLayer",
"ChunkedSlidingLayer",
"CacheProcessor",
"OffloadedCacheProcessor",
"QuantizedCacheProcessor",
"QuantoQuantizedCacheProcessor",
"HQQQuantizedCacheProcessor",
"Cache",
"CacheConfig",
"DynamicCache",
"EncoderDecoderCache",
"HQQQuantizedCache",
"HQQQuantizedCacheProcessor",
"HybridCache",
"HybridChunkedCache",
"OffloadedCache",
"OffloadedStaticCache",
"QuantizedCache",
"QuantoQuantizedCacheProcessor",
"QuantizedCacheConfig",
"QuantoQuantizedCache",
"SinkCache",
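
For reference, a quick smoke test of the public surface this `_import_structure` change exposes (a sketch that assumes this branch is installed; it only checks that the new names import and relate as expected):

```py
from transformers import (
    Cache,
    CacheLayerMixin,
    DynamicCache,
    DynamicLayer,
    SlidingWindowLayer,
    StaticLayer,
)

cache = DynamicCache()
print(isinstance(cache, Cache))                         # expected: True
print(issubclass(DynamicLayer, CacheLayerMixin))        # expected: True
print(issubclass(SlidingWindowLayer, CacheLayerMixin))  # expected: True
```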