Skip to content

Commit 53e99d1

Browse files
committed
feat(cache): SlidingWindowCache uses index_copy_ to avoid useless copy
Applying the same change done in StaticCache.
1 parent d329ad2 commit 53e99d1

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/transformers/cache_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -969,8 +969,11 @@ def update(
969969
k_out = k_out[:, :, indices]
970970
v_out = v_out[:, :, indices]
971971

972-
k_out[:, :, cache_position] = key_states
973-
v_out[:, :, cache_position] = value_states
972+
# Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
973+
# `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
974+
# operation, that avoids copies and uses less memory.
975+
k_out.index_copy_(2, cache_position, key_states)
976+
v_out.index_copy_(2, cache_position, value_states)
974977

975978
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
976979
self.key_cache[layer_idx].zero_()

0 commit comments

Comments
 (0)