@@ -86,45 +86,6 @@ def post_update(
         return key_tensors, value_tensors


-class CacheProcessorList(list):
-    """
-    list of cache processors that can be applied to a cache.
-    """
-
-    def init(self, cache: "Cache", **kwargs) -> None:
-        """Initialize all processors in the list."""
-        for processor in self:
-            processor.init(cache, **kwargs)
-
-    def pre_update(
-        self,
-        cache: "Cache",
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[dict[str, Any]] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Apply pre_update hook for all processors."""
-        for processor in self:
-            key_states, value_states = processor.pre_update(cache, key_states, value_states, layer_idx, cache_kwargs)
-        return key_states, value_states
-
-    def post_update(
-        self,
-        cache: "Cache",
-        key_tensors: torch.Tensor,
-        value_tensors: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[dict[str, Any]] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Apply post_update hook for all processors."""
-        for processor in self:
-            key_tensors, value_tensors = processor.post_update(
-                cache, key_tensors, value_tensors, layer_idx, cache_kwargs
-            )
-        return key_tensors, value_tensors
-
-
 class KVList:
     """Efficiently simulates layer-indexed key or value lists from a layered cache.
     This allows for BC access, e.g., cache.key_cache[idx] or cache.value_cache[idx]."""
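Note on the refactor shown above: the `CacheProcessorList` wrapper is dropped entirely, and (per the hunks below) a `Cache` now carries at most one optional `CacheProcessor`. Collapsing the list to a single optional processor removes one level of indirection per `update()` call and makes the common no-processor path a plain `None` check. The following is a minimal sketch of a custom processor under the hook signatures visible in this diff (`init`, `pre_update`, `post_update`); the import path and the pass-through behavior are assumptions for illustration, not part of this patch.

from typing import Any, Optional

import torch

# Assumed import path for the base class this PR's docstrings call `CacheProcessor`.
from transformers.cache_utils import CacheProcessor


class ShapeLoggingProcessor(CacheProcessor):
    """Illustrative pass-through processor that records which shapes flow into the cache."""

    def init(self, cache: "Cache", **kwargs) -> None:
        # Called once by Cache.__init__ after the layers are built (see the hunks below).
        self.seen: list[tuple[int, torch.Size]] = []

    def pre_update(
        self,
        cache: "Cache",
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Runs before the layer stores the new states; may return transformed tensors.
        self.seen.append((layer_idx, key_states.shape))
        return key_states, value_states

    def post_update(
        self,
        cache: "Cache",
        key_tensors: torch.Tensor,
        value_tensors: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Runs after the layer update; whatever is returned is what attention consumes.
        return key_tensors, value_tensors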
@@ -228,8 +189,8 @@ class Cache:
         config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*):
             Model configuration for shape/device info, or DDP-distributed cache data for compatibility.
             If DDP-distributed cache data, must be an iterable of (key_states, value_states) tuples for each layer.
-        processors (`CacheProcessorList`, *optional*):
-            List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list.
+        processor (`CacheProcessor`, *optional*):
+            Cache processor to apply (e.g., quantization, offloading).
         pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*):
             Pattern of cache layer types to use. Defaults to `(DynamicLayer,)`. Must be a tuple whose length divides
             the total number of layers. The pattern repeats to fill all layers. Examples: `(StaticLayer,)` for a
@@ -258,13 +219,13 @@ def __init__(
         config_or_ddp_cache_data: Optional[
             Union[PretrainedConfig, Iterable[tuple[torch.Tensor, torch.Tensor]]]
         ] = None,
-        processors: Optional[CacheProcessorList] = None,
+        processor: Optional[CacheProcessor] = None,
         pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None,
         *args,
         **kwargs,
     ):
         self.layers: list[CacheLayer] = []
-        self.processors = processors if processors is not None else CacheProcessorList()
+        self.processor = processor
         pattern_block = pattern_block or self.pattern_block or (DynamicLayer,)

         if isinstance(config_or_ddp_cache_data, PretrainedConfig):
@@ -280,7 +241,8 @@ def __init__(
             assert pattern_block == (DynamicLayer,), "torch DDP is only supported for DynamicCache"
             for key_states, value_states in _distributed_cache_data:
                 self.layers.append(DynamicLayer.from_kv(key_states, value_states))
-            self.processors.init(self, **kwargs)
+            if self.processor is not None:
+                self.processor.init(self, **kwargs)
             return
         else:
             model_config = kwargs.pop("config", None)
@@ -292,7 +254,8 @@ def __init__(
                 layer = layer_type(self.config.to_layer(idx))
                 self.layers.append(layer)

-        self.processors.init(self, **kwargs)
+        if self.processor is not None:
+            self.processor.init(self, **kwargs)

     def grow_layers_to(self, layer_idx):
         while len(self.layers) <= layer_idx:
@@ -335,12 +298,16 @@ def update(
         Return:
             A tuple containing the updated key and value states.
         """
-        key_states, value_states = self.processors.pre_update(self, key_states, value_states, layer_idx, cache_kwargs)
+        if self.processor is not None:
+            key_states, value_states = self.processor.pre_update(
+                self, key_states, value_states, layer_idx, cache_kwargs
+            )
         self.grow_layers_to(layer_idx)
         key_tensors, value_tensors = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
-        key_tensors, value_tensors = self.processors.post_update(
-            self, key_tensors, value_tensors, layer_idx, cache_kwargs
-        )
+        if self.processor is not None:
+            key_tensors, value_tensors = self.processor.post_update(
+                self, key_tensors, value_tensors, layer_idx, cache_kwargs
+            )
         return key_tensors, value_tensors

     def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
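With the guards above, a cache constructed without a processor skips both hooks and behaves exactly as before. A small usage sketch of that path; `DynamicCache` and the `update()` signature are taken from the surrounding code, the tensor shapes are arbitrary demo values.

import torch
from transformers import DynamicCache

cache = DynamicCache()  # self.processor is None, so pre_update/post_update are never called

# (batch, num_heads, seq_len, head_dim) -- arbitrary demo shapes
key_states = torch.randn(1, 8, 4, 64)
value_states = torch.randn(1, 8, 4, 64)

# update() grows the layer list if needed, stores the states, and returns the full cached tensors
key_out, value_out = cache.update(key_states, value_states, layer_idx=0)
print(key_out.shape)  # torch.Size([1, 8, 4, 64]) after the first update for layer 0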
@@ -1015,8 +982,7 @@ class OffloadedCache(DynamicCache):

     def __init__(self, config: Optional[CacheConfig] = None) -> None:
         # Create the underlying cache with offload processor
-        processors = CacheProcessorList([OffloadedCacheProcessor()])
-        super().__init__(processors=processors, config=config)
+        super().__init__(processor=OffloadedCacheProcessor(), config=config)


 class StaticLayer(CacheLayer):
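Since `OffloadedCache` is now just `DynamicCache` plus a single `OffloadedCacheProcessor`, exercising it end to end looks roughly like the sketch below; the checkpoint name is a placeholder and the `generate(past_key_values=...)` usage is standard transformers API rather than something introduced by this diff.

from transformers import AutoModelForCausalLM, AutoTokenizer, OffloadedCache

model_name = "gpt2"  # placeholder checkpoint; any causal LM works the same way
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("KV cache offloading keeps accelerator memory flat:", return_tensors="pt")

# The OffloadedCacheProcessor keeps only the layers currently needed on the compute
# device and parks the rest on CPU, mirroring the previous self-contained implementation.
past_key_values = OffloadedCache()
output = model.generate(**inputs, max_new_tokens=20, past_key_values=past_key_values)
print(tokenizer.decode(output[0], skip_special_tokens=True))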
@@ -1115,7 +1081,7 @@ class StaticCache(Cache):

     Parameters:
         config_or_ddp_cache_data (`Union`, *optional*): Model configuration for shape/device info, or DDP-distributed cache data for compatibility.
-        processors (`Optional`, *optional*): List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list.
+        processor (`Optional`, *optional*): Cache processor to apply (e.g., quantization, offloading).
         pattern_block (`Optional`, *optional*): Pattern of cache layer types to use. Defaults to `(StaticLayer,)` for backward compatibility.


@@ -1429,7 +1395,7 @@ class HybridCache(Cache):

     Parameters:
         config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): Model configuration for shape/device info. No DDP-distributed cache data is supported.
-        processors (`CacheProcessorList`, *optional*): List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list.
+        processor (`CacheProcessor`, *optional*): Cache processor to apply (e.g., quantization, offloading).
         pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): Pattern of cache layer types to use. Defaults to `(SlidingWindowLayer, StaticLayer, ..., StaticLayer)`
             for backward compatibility.
     Example:
@@ -1455,7 +1421,7 @@ class HybridCache(Cache):
     def __init__(
         self,
         config_or_ddp_cache_data=None,
-        processors: Optional[CacheProcessorList] = None,
+        processor: Optional[CacheProcessor] = None,
         pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None,
         *args,
         **kwargs,
@@ -1469,7 +1435,7 @@ def __init__(
             self.is_sliding = [False] * model_config.num_hidden_layers

         pattern_block = tuple(SlidingWindowLayer if sl else StaticLayer for sl in self.is_sliding)
-        super().__init__(config_or_ddp_cache_data, processors, pattern_block, *args, **kwargs)
+        super().__init__(config_or_ddp_cache_data, processor, pattern_block, *args, **kwargs)


 class HybridChunkedCache(Cache):
@@ -1878,18 +1844,14 @@ def __init__(
         offload_device: Union[str, torch.device] = "cpu",
         layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
-        # Create offload processor
-        processors = CacheProcessorList([OffloadedCacheProcessor(offload_device)])
-
-        # Initialize the base StaticCache with the processor
         super().__init__(
             config=config,
             max_batch_size=max_batch_size,
             max_cache_len=max_cache_len,
             device=device,
             dtype=dtype,
             layer_device_map=layer_device_map,
-            processors=processors,
+            processor=OffloadedCacheProcessor(offload_device),
         )


@@ -2230,8 +2192,7 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None:
         else:
             raise ValueError(f"Unknown quantization backend `{cache_config.backend}`")

-        processors = CacheProcessorList([processor])
-        super().__init__(processors=processors)
+        super().__init__(processor=processor)

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
@@ -2240,7 +2201,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
         # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
         # this part of code otherwise fails when used to verify attn_weight shape in some models
-        return self.processors[0]._seen_tokens if layer_idx == 0 else self.processors[0]._seen_tokens - 1
+        return self.processor._seen_tokens if layer_idx == 0 else self.processor._seen_tokens - 1


 class QuantoQuantizedCache(QuantizedCache):
class QuantoQuantizedCache (QuantizedCache ):
@@ -2283,8 +2244,7 @@ class QuantoQuantizedCache(QuantizedCache):
2283
2244
"""
2284
2245
2285
2246
def __init__ (self , cache_config : QuantizedCacheConfig ) -> None :
2286
- processors = CacheProcessorList ([QuantoQuantizedCacheProcessor (cache_config )])
2287
- Cache .__init__ (self , processors = processors )
2247
+ Cache .__init__ (self , processor = QuantoQuantizedCacheProcessor (cache_config ))
2288
2248
2289
2249
2290
2250
class HQQQuantizedCache (QuantizedCache ):
@@ -2327,8 +2287,7 @@ class HQQQuantizedCache(QuantizedCache):
2327
2287
"""
2328
2288
2329
2289
def __init__ (self , cache_config : QuantizedCacheConfig ) -> None :
2330
- processors = CacheProcessorList ([HQQQuantizedCacheProcessor (cache_config )])
2331
- Cache .__init__ (self , processors = processors )
2290
+ Cache .__init__ (self , processor = HQQQuantizedCacheProcessor (cache_config ))
2332
2291
2333
2292
2334
2293
class SinkCache (Cache ):
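All three quantized caches now follow the same one-processor pattern. A hedged construction sketch: `QuantizedCacheConfig` and `QuantoQuantizedCache` appear in this diff, `cache_config.backend` is referenced in the error message above, but the `nbits` field and the need for the optional quanto dependency are assumptions beyond what the patch displays.

from transformers import QuantizedCacheConfig, QuantoQuantizedCache

# `backend` is referenced in the error message above; `nbits` is assumed for illustration.
cache_config = QuantizedCacheConfig(backend="quanto", nbits=4)

# Internally this now reduces to:
#     Cache.__init__(self, processor=QuantoQuantizedCacheProcessor(cache_config))
cache = QuantoQuantizedCache(cache_config)

# get_seq_length() reads the single processor's `_seen_tokens` counter, as patched above.
print(cache.get_seq_length())  # expected 0 before any update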