Skip to content

Commit 1ad0a9a

Browse files
committed
Use Index_copy method to update static cache inplace and avoid recompilation during each iteration in XLA
1 parent 10c06b3 commit 1ad0a9a

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

src/transformers/cache_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88

99
from .configuration_utils import PretrainedConfig
10-
from .utils import is_hqq_available, is_quanto_available, logging
10+
from .utils import is_hqq_available, is_quanto_available, logging, is_torch_xla_available
1111

1212
if is_quanto_available():
1313
from quanto import QBitsTensor, qint2, qint4
@@ -791,6 +791,23 @@ def update(
791791
k_out = self.key_cache[layer_idx]
792792
v_out = self.value_cache[layer_idx]
793793

794+
if is_torch_xla_available(): # If torch_xla is available, do out-of-place operation on KV_Cache and create a new list
795+
k_out = k_out.index_copy(2, cache_position, key_states)
796+
v_out = v_out.index_copy(2, cache_position, value_states)
797+
798+
updated_key_cache = [
799+
k_out if i == layer_idx else self.key_cache[i] for i in range(len(self.key_cache))
800+
]
801+
802+
updated_value_cache = [
803+
v_out if i == layer_idx else self.value_cache[i] for i in range(len(self.value_cache))
804+
]
805+
806+
self.key_cache = updated_key_cache
807+
self.value_cache = updated_value_cache
808+
809+
return k_out, v_out
810+
794811
k_out.index_copy_(2, cache_position, key_states)
795812
v_out.index_copy_(2, cache_position, value_states)
796813

0 commit comments

Comments
 (0)