
Commit 9d67ac1

Warn users to pass cache_position when calling the forward pass
Calling torch.arange(past_length) with a past_length that keeps changing causes recompilation in XLA
1 parent ed93bf8 commit 9d67ac1

File tree

1 file changed: +1 −0 lines changed


src/transformers/cache_utils.py

Lines changed: 1 addition & 0 deletions
@@ -1020,6 +1020,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
         # limit the check to the first batch member and head dimension.
         # TODO: deprecate this function in favor of `cache_position`
+        logger.debug("Use cache_position parameter in your model for better performance.")
         key_cache = self.key_cache[layer_idx]
         device = key_cache.device
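
For context, a minimal sketch (not part of this commit) of how a caller can supply cache_position explicitly during generation. It assumes a causal LM whose forward() accepts cache_position, as the debug message above suggests; the checkpoint name and prompt are illustrative only. Passing the positions in means the model does not have to derive them from the cache length, which is where the torch.arange(past_length) pattern with ever-changing shapes comes from.

    # Minimal sketch, not part of this commit. Assumes a model whose forward()
    # accepts `cache_position`; checkpoint name and prompt are illustrative.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    checkpoint = "meta-llama/Llama-2-7b-hf"  # illustrative only
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)

    inputs = tokenizer("Hello", return_tensors="pt")
    seq_len = inputs.input_ids.shape[1]

    # Prefill: pass positions 0..seq_len-1 explicitly so the model does not
    # reconstruct them via torch.arange(past_length), whose shape would change
    # every step and force XLA recompilation.
    out = model(**inputs, cache_position=torch.arange(seq_len), use_cache=True)

    # One decode step: a single, explicitly supplied position.
    next_token = out.logits[:, -1:].argmax(-1)
    out = model(
        input_ids=next_token,
        past_key_values=out.past_key_values,
        cache_position=torch.tensor([seq_len]),
        use_cache=True,
    )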
