Commit 4ffdb17

feat(cache): use optimized StaticCache class for XLA
This is an adaptation of work originally contributed to transformers: huggingface/transformers#31129. The original contribution has not been merged yet, but it shows lower memory usage and better performance on XLA, so I think it is worth adding it here, to be integrated into optimum-tpu.
1 parent 7cce24c commit 4ffdb17

2 files changed: +63 -2 lines changed

optimum/tpu/static_cache_xla.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+from transformers import StaticCache
+
+
+class StaticCacheXla(StaticCache):
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
+                to know where to write in the cache.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        cache_position = cache_kwargs.get("cache_position")
+        k_out = self.key_cache[layer_idx]
+        v_out = self.value_cache[layer_idx]
+
+        # `index_copy_(dim, index, source)` functions similarly to `tensor[index] = source`,
+        # but it is used for better generality and it uses less memory on XLA.
+        # For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html
+        k_out.index_copy_(2, cache_position, key_states)
+        v_out.index_copy_(2, cache_position, value_states)
+
+        return k_out, v_out
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states that were seen by the model."""
+        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute,
+        # limit the check to the first batch member and head dimension.
+        # TODO: deprecate this function in favor of `cache_position`
+        key_cache = self.key_cache[layer_idx]
+        device = key_cache.device
+
+        # `index_select(dim, index)` performs the same operation as `item = tensor[..., index, ...]`,
+        # but it is used for better generality and it uses less memory on XLA.
+        # For more information, refer to: https://pytorch.org/cppdocs/notes/tensor_indexing.html
+        item = key_cache.index_select(0, torch.tensor(0, device=device))
+        head = item.index_select(1, torch.tensor(0, device=device))
+
+        return head.any(dim=-1).sum()
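
Note (not part of the commit): a minimal sketch of the indexing pattern the comments above describe, assuming CPU tensors and made-up shapes. On CPU the in-place index_copy_ write produces the same result as an advanced-indexing assignment; the commit prefers index_copy_ because, per the comments, it avoids extra copies and uses less memory on XLA.

import torch

# Toy buffer shaped like one layer's key cache: (batch, heads, max_seq_len, head_dim).
cache = torch.zeros(1, 4, 16, 8)
new_states = torch.randn(1, 4, 2, 8)   # two freshly computed positions
cache_position = torch.tensor([3, 4])  # sequence slots to write

a = cache.clone()
a[:, :, cache_position] = new_states   # advanced-indexing assignment

b = cache.clone()
b.index_copy_(2, cache_position, new_states)  # the form StaticCacheXla.update uses

assert torch.equal(a, b)  # same values; index_copy_ is simply the XLA-friendlier spelling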

text-generation-inference/server/text_generation_server/generator.py

Lines changed: 3 additions & 2 deletions
@@ -12,12 +12,13 @@
 import torch.multiprocessing as mp
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_multiprocessing as xmp
-from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
 from transformers.generation import GenerationConfig

 import optimum.tpu.xla_logger as logger
 from optimum.tpu import AutoModelForCausalLM
 from optimum.tpu.generation import TokenSelector
+from optimum.tpu.static_cache_xla import StaticCacheXla
 from optimum.tpu.xla_mp_comm import AgentMailbox, RootMailbox

 from .generator_base import Generator
@@ -529,7 +530,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:

         extra_args = {}
         if self._supports_static_cache:
-            self.past_key_values = StaticCache(
+            self.past_key_values = StaticCacheXla(
                 config=self.model.config,
                 max_batch_size=len(self.slots),
                 max_cache_len=self.model.config.sequence_length,
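
Note (not part of the commit): a minimal sketch of how StaticCacheXla could be exercised on its own, assuming a tiny LlamaConfig and CPU tensors. The config/max_batch_size/max_cache_len keywords mirror the prefill() call above; the device and dtype arguments, the sizes, and the exact constructor signature depend on the installed transformers version and are illustrative assumptions here.

import torch
from transformers import LlamaConfig

from optimum.tpu.static_cache_xla import StaticCacheXla

# Tiny illustrative config; the generator passes self.model.config instead.
config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=4, num_hidden_layers=2)

cache = StaticCacheXla(
    config=config,
    max_batch_size=1,     # the generator uses len(self.slots)
    max_cache_len=16,     # the generator uses self.model.config.sequence_length
    device="cpu",         # assumption; an xm.xla_device() would be used on TPU
    dtype=torch.float32,  # assumption
)

# One decoding step writing sequence position 0 of layer 0.
# Key/value states are shaped (batch, num_kv_heads, seq_len, head_dim).
head_dim = config.hidden_size // config.num_attention_heads
k = torch.randn(1, config.num_key_value_heads, 1, head_dim)
v = torch.randn(1, config.num_key_value_heads, 1, head_dim)
cache_position = torch.tensor([0])

k_out, v_out = cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": cache_position})
print(k_out.shape)                  # the full static buffer, e.g. torch.Size([1, 4, 16, 16])
print(int(cache.get_seq_length()))  # 1: one non-zero position written so far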
