Commit ed93bf8

get_seq_length in StaticCacheXLA uses an out-of-place index_select operation.
This is necessary for XLA, as the tensors are not materialized yet.
1 parent 670d432 commit ed93bf8

File tree

1 file changed (+7, -1)

src/transformers/cache_utils.py

Lines changed: 7 additions & 1 deletion
@@ -1020,7 +1020,13 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
         # limit the check to the first batch member and head dimension.
         # TODO: deprecate this function in favor of `cache_position`
-        raise NotImplementedError("StaticCacheXLA is not implemented yet")
+        key_cache = self.key_cache[layer_idx]
+        device = key_cache.device
+
+        item = key_cache.index_select(0, torch.tensor(0, device=device))
+        head = item.index_select(1, torch.tensor(0, device=device))
+
+        return head.any(dim=-1).sum()
 
     def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states."""
