Skip to content

Commit 10c06b3

Browse files
committed
Use the Index_copy method to update static cache inplace and avoid recompilation during each iteration in XLA
1 parent bdb9106 commit 10c06b3

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

src/transformers/cache_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from .configuration_utils import PretrainedConfig
1010
from .utils import is_hqq_available, is_quanto_available, logging
1111

12-
1312
if is_quanto_available():
1413
from quanto import QBitsTensor, qint2, qint4
1514

@@ -792,8 +791,8 @@ def update(
792791
k_out = self.key_cache[layer_idx]
793792
v_out = self.value_cache[layer_idx]
794793

795-
k_out[:, :, cache_position] = key_states
796-
v_out[:, :, cache_position] = value_states
794+
k_out.index_copy_(2, cache_position, key_states)
795+
v_out.index_copy_(2, cache_position, value_states)
797796

798797
return k_out, v_out
799798

0 commit comments

Comments
 (0)