Docs: add more cross-references to the KV cache docs (huggingface#33323)
* add more cross-references

* nit

* import guard

* more import guards

* nit

* Update src/transformers/generation/configuration_utils.py
gante authored and itazap committed Sep 20, 2024
1 parent 548d485 commit f6eab64
Showing 29 changed files with 99 additions and 57 deletions.
20 changes: 10 additions & 10 deletions docs/source/en/kv_cache.md
@@ -51,11 +51,11 @@ More concretely, key-value cache acts as a memory bank for these generative mode


See an example below for how to implement your own generation loop.

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

>>> model_id = "meta-llama/Llama-2-7b-chat-hf"
>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -69,10 +69,10 @@ More concretely, key-value cache acts as a memory bank for these generative mode
>>> max_new_tokens = 10

>>> for _ in range(max_new_tokens):
- ... outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
+ ... outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
... # Greedily sample one next token
... next_token_ids = outputs.logits[:, -1:].argmax(-1)
- ... generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
+ ... generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
...
... # Prepare inputs for the next generation step by leaving unprocessed tokens; in our case we have only one new token
... # and expanding attn mask for the new token, as explained above
@@ -222,7 +222,7 @@ before successfully generating 40 beams.

### Static Cache

- Since the "DynamicCache" dynamically grows with each generation step, it prevents you from taking advantage of JIT optimizations. The [`~StaticCache`] pre-allocates
+ Since the "DynamicCache" dynamically grows with each generation step, it prevents you from taking advantage of JIT optimizations. The [`~StaticCache`] pre-allocates
a specific maximum size for the keys and values, allowing you to generate up to the maximum length without having to modify cache size. Check the below usage example.

For more examples with Static Cache and JIT compilation, take a look at [StaticCache & torchcompile](./llm_optims#static-kv-cache-and-torchcompile)
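
For a quick sense of the API, a minimal sketch of requesting the static cache by name through `generate` might look like the following (re-using the Llama checkpoint from the earlier example; the prompt and generation arguments are only illustrative):

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> model_id = "meta-llama/Llama-2-7b-chat-hf"
>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

>>> # "static" makes `generate` pre-allocate the cache instead of growing it step by step
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="static")
>>> print(tokenizer.decode(out[0], skip_special_tokens=True))
```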
@@ -267,7 +267,7 @@ This will use the [`~OffloadedStaticCache`] implementation instead.

As the name suggests, this cache type implements a sliding window over previous keys and values, retaining only the last `sliding_window` tokens. It should be used with models like Mistral that support sliding window attention. Additionally, similar to Static Cache, this one is JIT-friendly and can be used with the same compile techniques as Static Cache.

- Note that you can use this cache only for models that support sliding window, e.g. Mistral models.
+ Note that you can use this cache only for models that support sliding window, e.g. Mistral models.
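
As a minimal sketch (assuming a Mistral checkpoint, since it supports sliding-window attention; the prompt and arguments are illustrative), the sliding-window cache can also be requested by name:

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> model_id = "mistralai/Mistral-7B-v0.1"
>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> inputs = tokenizer("Yesterday I was on a rock concert and", return_tensors="pt").to(model.device)

>>> # Only the most recent `sliding_window` tokens are kept in the cache
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=30, cache_implementation="sliding_window")
>>> print(tokenizer.decode(out[0], skip_special_tokens=True))
```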


```python
@@ -324,7 +324,7 @@ We have seen how to use each of the cache types when generating. What if you wan

The general format when doing iterative generation is as below. First you have to initialize an empty cache of the type you want, and then you can start feeding in new prompts iteratively. Keeping track of the dialogue history and formatting can be done with chat templates; read more on that in [chat_templating](./chat_templating).

- In case you are using Sink Cache, you have to crop your inputs to that maximum length because Sink Cache can generate text longer than its maximum window size, but it expects the first input to not exceed the maximum cache length.
+ In case you are using Sink Cache, you have to crop your inputs to that maximum length because Sink Cache can generate text longer than its maximum window size, but it expects the first input to not exceed the maximum cache length.


```python
@@ -354,9 +354,9 @@ In case you are using Sink Cache, you have to crop your inputs to that maximum l
... inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
... if isinstance(past_key_values, SinkCache):
... inputs = {k: v[:, -max_cache_length:] for k, v in inputs.items()}
- ...
+ ...
... input_length = inputs["input_ids"].shape[1]
- ...
+ ...
... outputs = model.generate(**inputs, do_sample=False, max_new_tokens=256, past_key_values=past_key_values)
... completion = tokenizer.decode(outputs[0, input_length: ], skip_special_tokens=True)
... messages.append({"role": "assistant", "content": completion})
@@ -400,4 +400,4 @@ Sometimes you would want to first fill-in cache object with key/values for certa

>>> print(responses)
['<s> You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTitle: The Ultimate Guide to Travelling: Tips, Tricks, and', '<s> You are a helpful assistant. What is the capital of France?\n\nYes, the capital of France is Paris.</s>']
- ```
+ ```
2 changes: 1 addition & 1 deletion docs/source/en/llm_optims.md
@@ -24,7 +24,7 @@ This guide will show you how to use the optimization techniques available in Tra

During decoding, an LLM computes the key-value (kv) values for each input token, and since it is autoregressive, it computes the same kv values each time because the generated output becomes part of the input. This is not very efficient because you're recomputing the same kv values each time.

- To optimize this, you can use a kv-cache to store the past keys and values instead of recomputing them each time. However, since the kv-cache grows with each generation step and is dynamic, it prevents you from taking advantage of [`torch.compile`](./perf_torch_compile), a powerful optimization tool that fuses PyTorch code into fast and optimized kernels.
+ To optimize this, you can use a kv-cache to store the past keys and values instead of recomputing them each time. However, since the kv-cache grows with each generation step and is dynamic, it prevents you from taking advantage of [`torch.compile`](./perf_torch_compile), a powerful optimization tool that fuses PyTorch code into fast and optimized kernels. We have an entire guide dedicated to kv-caches [here](./kv_cache).

The *static kv-cache* solves this issue by pre-allocating the kv-cache size to a maximum value which allows you to combine it with `torch.compile` for up to a 4x speed up. Your speed up may vary depending on the model size (larger models have a smaller speed up) and hardware.
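
As a rough sketch of the pattern (the checkpoint, prompt, and compile flags below are placeholders; the actual speed-up depends on your model and hardware):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"  # assumed checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

# Pre-allocate the kv-cache so its shape stays fixed across decoding steps
model.generation_config.cache_implementation = "static"

# Compile the forward pass; "reduce-overhead" can capture CUDA graphs for decoding
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("The theory of special relativity states ", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```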

2 changes: 1 addition & 1 deletion docs/source/en/llm_tutorial_optimization.md
@@ -662,7 +662,7 @@ Using the key-value cache has two advantages:
- Significant increase in computational efficiency, as fewer computations are performed compared to computing the full \\( \mathbf{QK}^T \\) matrix. This leads to an increase in inference speed
- The maximum required memory does not increase quadratically with the number of generated tokens, but only linearly.

- > One should *always* make use of the key-value cache as it leads to identical results and a significant speed-up for longer input sequences. Transformers has the key-value cache enabled by default when making use of the text pipeline or the [`generate` method](https://huggingface.co/docs/transformers/main_classes/text_generation).
+ > One should *always* make use of the key-value cache as it leads to identical results and a significant speed-up for longer input sequences. Transformers has the key-value cache enabled by default when making use of the text pipeline or the [`generate` method](https://huggingface.co/docs/transformers/main_classes/text_generation). We have an entire guide dedicated to caches [here](./kv_cache).
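
A minimal sketch of the claim above (using a small GPT-2 checkpoint purely for illustration): greedy decoding with and without the cache produces identical tokens, so the cache only changes speed and memory, never the output.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # assumed small checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The key-value cache", return_tensors="pt")

# `use_cache=True` is the default in `generate`; it is spelled out here only for the comparison
with_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=True)
without_cache = model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=False)
assert torch.equal(with_cache, without_cache)
```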
<Tip warning={true}>

37 changes: 34 additions & 3 deletions src/transformers/generation/configuration_utils.py
@@ -43,11 +43,34 @@
logger = logging.get_logger(__name__)
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
NEEDS_CACHE_CONFIG = {}
+ NEED_SETUP_CACHE_CLASSES_MAPPING = {}
+ QUANT_BACKEND_CLASSES_MAPPING = {}
+ ALL_CACHE_IMPLEMENTATIONS = []

if is_torch_available():
-     from ..cache_utils import QuantizedCacheConfig
+     from ..cache_utils import (
+         HQQQuantizedCache,
+         HybridCache,
+         MambaCache,
+         OffloadedStaticCache,
+         QuantizedCacheConfig,
+         QuantoQuantizedCache,
+         SlidingWindowCache,
+         StaticCache,
+     )

NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
+     NEED_SETUP_CACHE_CLASSES_MAPPING = {
+         "static": StaticCache,
+         "offloaded_static": OffloadedStaticCache,
+         "sliding_window": SlidingWindowCache,
+         "hybrid": HybridCache,
+         "mamba": MambaCache,
+     }
+     QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
+     ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + list(
+         QUANT_BACKEND_CLASSES_MAPPING.keys()
+     )


class GenerationMode(ExplicitEnum):
@@ -70,7 +93,7 @@ class GenerationMode(ExplicitEnum):

class GenerationConfig(PushToHubMixin):
# no-format
r"""
rf"""
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
@@ -146,7 +169,10 @@ class GenerationConfig(PushToHubMixin):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
cache_implementation (`str`, *optional*, default to `None`):
-             Cache class that should be used when generating.
+             Name of the cache class that will be instantiated in `generate`, for faster decoding. Possible values are:
+             {ALL_CACHE_IMPLEMENTATIONS}. We support other cache types, but they must be manually instantiated and
+             passed to `generate` through the `past_key_values` argument. See our
+             [cache documentation](https://huggingface.co/docs/transformers/en/kv_cache) for further information.
cache_config (`CacheConfig` or `dict`, *optional*, default to `None`):
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
it will be converted to its respective `CacheConfig` internally.
@@ -699,6 +725,11 @@ def validate(self, is_init=False):
)

# 5. check cache-related arguments
+         if self.cache_implementation is not None and self.cache_implementation not in ALL_CACHE_IMPLEMENTATIONS:
+             raise ValueError(
+                 f"Invalid `cache_implementation` ({self.cache_implementation}). Choose one of: "
+                 f"{ALL_CACHE_IMPLEMENTATIONS}"
+             )
if self.cache_config is not None:
cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
if cache_class is None:
23 changes: 6 additions & 17 deletions src/transformers/generation/utils.py
@@ -29,15 +29,8 @@
Cache,
DynamicCache,
EncoderDecoderCache,
-     HQQQuantizedCache,
-     HybridCache,
-     MambaCache,
OffloadedCache,
-     OffloadedStaticCache,
QuantizedCacheConfig,
-     QuantoQuantizedCache,
-     SlidingWindowCache,
-     StaticCache,
)
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
Expand Down Expand Up @@ -68,7 +61,12 @@
_prepare_attention_mask,
_prepare_token_type_ids,
)
- from .configuration_utils import GenerationConfig, GenerationMode
+ from .configuration_utils import (
+     NEED_SETUP_CACHE_CLASSES_MAPPING,
+     QUANT_BACKEND_CLASSES_MAPPING,
+     GenerationConfig,
+     GenerationMode,
+ )
from .logits_process import (
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
@@ -118,15 +116,6 @@
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

- NEED_SETUP_CACHE_CLASSES_MAPPING = {
-     "static": StaticCache,
-     "offloaded_static": OffloadedStaticCache,
-     "sliding_window": SlidingWindowCache,
-     "hybrid": HybridCache,
-     "mamba": MambaCache,
- }
- QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}


@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
3 changes: 2 additions & 1 deletion src/transformers/models/bloom/modeling_bloom.py
@@ -551,7 +551,8 @@ def _init_weights(self, module: nn.Module):
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
-             - a [`~cache_utils.Cache`] instance;
+             - a [`~cache_utils.Cache`] instance, see our
+               [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
3 changes: 2 additions & 1 deletion src/transformers/models/codegen/modeling_codegen.py
@@ -429,7 +429,8 @@ def _init_weights(self, module):
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
-             - a [`~cache_utils.Cache`] instance;
+             - a [`~cache_utils.Cache`] instance, see our
+               [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
3 changes: 2 additions & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -721,7 +721,8 @@ def _init_weights(self, module):
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
-             - a [`~cache_utils.Cache`] instance;
+             - a [`~cache_utils.Cache`] instance, see our
+               [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
3 changes: 2 additions & 1 deletion src/transformers/models/dbrx/modeling_dbrx.py
@@ -948,7 +948,8 @@ def _init_weights(self, module: nn.Module):
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
-             - a [`~cache_utils.Cache`] instance;
+             - a [`~cache_utils.Cache`] instance, see our
+               [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
3 changes: 2 additions & 1 deletion src/transformers/models/falcon/modeling_falcon.py
@@ -836,7 +836,8 @@ def forward(
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
-             - a [`~cache_utils.Cache`] instance;
+             - a [`~cache_utils.Cache`] instance, see our
+               [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
3 changes: 2 additions & 1 deletion src/transformers/models/gemma/modeling_gemma.py
@@ -726,7 +726,8 @@ def _init_weights(self, module):
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
-             - a [`~cache_utils.Cache`] instance;
+             - a [`~cache_utils.Cache`] instance, see our
+               [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
3 changes: 2 additions & 1 deletion src/transformers/models/gemma2/modeling_gemma2.py
@@ -716,7 +716,8 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
-             - a [`~cache_utils.Cache`] instance;
+             - a [`~cache_utils.Cache`] instance, see our
+               [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
3 changes: 2 additions & 1 deletion src/transformers/models/git/modeling_git.py
@@ -578,7 +578,8 @@ def _init_weights(self, module):
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
-             - a [`~cache_utils.Cache`] instance;
+             - a [`~cache_utils.Cache`] instance, see our
+               [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
3 changes: 2 additions & 1 deletion src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -610,7 +610,8 @@ def _init_weights(self, module):
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
-             - a [`~cache_utils.Cache`] instance;
+             - a [`~cache_utils.Cache`] instance, see our
+               [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
3 changes: 2 additions & 1 deletion src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -820,7 +820,8 @@ def forward(
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
-             - a [`~cache_utils.Cache`] instance;
+             - a [`~cache_utils.Cache`] instance, see our
+               [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
3 changes: 2 additions & 1 deletion src/transformers/models/gptj/modeling_gptj.py
@@ -621,7 +621,8 @@ def _init_weights(self, module):
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
-             - a [`~cache_utils.Cache`] instance;
+             - a [`~cache_utils.Cache`] instance, see our
+               [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
3 changes: 2 additions & 1 deletion src/transformers/models/granite/modeling_granite.py
@@ -720,7 +720,8 @@ def _init_weights(self, module):
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
-             - a [`~cache_utils.Cache`] instance;
+             - a [`~cache_utils.Cache`] instance, see our
+               [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.