Commit 71efd4c

Inherit StaticCacheXLA from StaticCache instead, for compatibility with isinstance(past_key_value, StaticCache)
1 parent 9d67ac1 commit 71efd4c
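
The motivation, per the commit message: code paths in `transformers` are selected with `isinstance(past_key_value, StaticCache)`, and a class deriving only from `Cache` fails that check. Below is a minimal sketch of the effect, reusing the class names from the diff with illustrative stand-in bodies (not the real `cache_utils` implementations):

```python
# Toy hierarchy mirroring the names in src/transformers/cache_utils.py.
# Bodies are stand-ins; only the inheritance relationships matter here.

class Cache:
    pass

class StaticCache(Cache):
    pass

class StaticCacheXLA_before(Cache):       # base class before this commit
    pass

class StaticCacheXLA_after(StaticCache):  # base class after this commit
    pass

# The dispatch check named in the commit message:
print(isinstance(StaticCacheXLA_before(), StaticCache))  # False -> static-cache path skipped
print(isinstance(StaticCacheXLA_after(), StaticCache))   # True  -> static-cache path taken
```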

1 file changed (+2, -2)

src/transformers/cache_utils.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -935,7 +935,7 @@ def reset(self):
         self.key_cache.zero_()
         self.value_cache.zero_()
 
-class StaticCacheXLA(Cache):
+class StaticCacheXLA(StaticCache):
     """
     Static Cache class to be used with `torch.compile(model)`.
 
@@ -953,7 +953,7 @@ class StaticCacheXLA(Cache):
     """
 
     def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
-        super().__init__()
+        super().__init__(config, max_batch_size, max_cache_len, device, dtype)
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
```
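
Note that the second hunk follows from the first: with `StaticCache` as the base class, its `__init__` expects `config`, `max_batch_size`, `max_cache_len`, `device`, and `dtype` (the same parameters shown in `StaticCacheXLA.__init__` above), so the bare `super().__init__()` call would no longer be valid, and the arguments are now forwarded to the parent constructor.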

0 commit comments