Commit d97a02d

remove CacheProcessorList
1 parent 04d7a0b commit d97a02d

src/transformers/cache_utils.py

Lines changed: 26 additions & 67 deletions
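For orientation, the caller-facing effect of the commit, sketched under the assumption that these names are importable from `transformers.cache_utils` and that `DynamicCache` forwards its keyword arguments to `Cache.__init__` (as `OffloadedCache` does in a hunk further down); illustrative only, not part of the diff:

from transformers.cache_utils import DynamicCache, OffloadedCacheProcessor

# Before this commit, processors were wrapped in a CacheProcessorList:
#   cache = DynamicCache(processors=CacheProcessorList([OffloadedCacheProcessor()]))
# After this commit, a single optional processor is passed straight through:
cache = DynamicCache(processor=OffloadedCacheProcessor())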
@@ -86,45 +86,6 @@ def post_update(
         return key_tensors, value_tensors
 
 
-class CacheProcessorList(list):
-    """
-    list of cache processors that can be applied to a cache.
-    """
-
-    def init(self, cache: "Cache", **kwargs) -> None:
-        """Initialize all processors in the list."""
-        for processor in self:
-            processor.init(cache, **kwargs)
-
-    def pre_update(
-        self,
-        cache: "Cache",
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[dict[str, Any]] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Apply pre_update hook for all processors."""
-        for processor in self:
-            key_states, value_states = processor.pre_update(cache, key_states, value_states, layer_idx, cache_kwargs)
-        return key_states, value_states
-
-    def post_update(
-        self,
-        cache: "Cache",
-        key_tensors: torch.Tensor,
-        value_tensors: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[dict[str, Any]] = None,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        """Apply post_update hook for all processors."""
-        for processor in self:
-            key_tensors, value_tensors = processor.post_update(
-                cache, key_tensors, value_tensors, layer_idx, cache_kwargs
-            )
-        return key_tensors, value_tensors
-
-
 class KVList:
     """Efficiently simulates layer-indexed key or value lists from a layered cache.
     This allows for BC access, e.g., cache.key_cache[idx] or cache.value_cache[idx]."""
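With the list wrapper removed, chaining more than one processor becomes a caller-side concern if it is ever needed. A minimal sketch of such a composite, reusing the hook signatures of the removed class above; the `ChainedCacheProcessor` name is hypothetical, and subclassing `CacheProcessor` this way (and its import path) is an assumption:

from transformers.cache_utils import CacheProcessor  # assumed import path

class ChainedCacheProcessor(CacheProcessor):  # hypothetical, user-side helper
    """Applies several processors in sequence, mirroring the removed CacheProcessorList."""

    def __init__(self, processors):
        self.processors = list(processors)

    def init(self, cache, **kwargs):
        for processor in self.processors:
            processor.init(cache, **kwargs)

    def pre_update(self, cache, key_states, value_states, layer_idx, cache_kwargs=None):
        for processor in self.processors:
            key_states, value_states = processor.pre_update(
                cache, key_states, value_states, layer_idx, cache_kwargs
            )
        return key_states, value_states

    def post_update(self, cache, key_tensors, value_tensors, layer_idx, cache_kwargs=None):
        for processor in self.processors:
            key_tensors, value_tensors = processor.post_update(
                cache, key_tensors, value_tensors, layer_idx, cache_kwargs
            )
        return key_tensors, value_tensors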
@@ -228,8 +189,8 @@ class Cache:
         config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*):
             Model configuration for shape/device info, or DDP-distributed cache data for compatibility.
             If DDP-distributed cache data, must be an iterable of (key_states, value_states) tuples for each layer.
-        processors (`CacheProcessorList`, *optional*):
-            List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list.
+        processor (`CacheProcessor`, *optional*):
+            Cache processor to apply (e.g., quantization, offloading).
         pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*):
             Pattern of cache layer types to use. Defaults to `(DynamicLayer,)`. Must be a tuple whose length divides
             the total number of layers. The pattern repeats to fill all layers. Examples: `(StaticLayer,)` for a
@@ -258,13 +219,13 @@ def __init__(
         config_or_ddp_cache_data: Optional[
             Union[PretrainedConfig, Iterable[tuple[torch.Tensor, torch.Tensor]]]
         ] = None,
-        processors: Optional[CacheProcessorList] = None,
+        processor: Optional[CacheProcessor] = None,
         pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None,
         *args,
         **kwargs,
     ):
         self.layers: list[CacheLayer] = []
-        self.processors = processors if processors is not None else CacheProcessorList()
+        self.processor = processor
         pattern_block = pattern_block or self.pattern_block or (DynamicLayer,)
 
         if isinstance(config_or_ddp_cache_data, PretrainedConfig):
@@ -280,7 +241,8 @@ def __init__(
             assert pattern_block == (DynamicLayer,), "torch DDP is only supported for DynamicCache"
             for key_states, value_states in _distributed_cache_data:
                 self.layers.append(DynamicLayer.from_kv(key_states, value_states))
-            self.processors.init(self, **kwargs)
+            if self.processor is not None:
+                self.processor.init(self, **kwargs)
             return
         else:
             model_config = kwargs.pop("config", None)
@@ -292,7 +254,8 @@ def __init__(
             layer = layer_type(self.config.to_layer(idx))
             self.layers.append(layer)
 
-        self.processors.init(self, **kwargs)
+        if self.processor is not None:
+            self.processor.init(self, **kwargs)
 
     def grow_layers_to(self, layer_idx):
         while len(self.layers) <= layer_idx:
@@ -335,12 +298,16 @@ def update(
         Return:
             A tuple containing the updated key and value states.
         """
-        key_states, value_states = self.processors.pre_update(self, key_states, value_states, layer_idx, cache_kwargs)
+        if self.processor is not None:
+            key_states, value_states = self.processor.pre_update(
+                self, key_states, value_states, layer_idx, cache_kwargs
+            )
         self.grow_layers_to(layer_idx)
         key_tensors, value_tensors = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
-        key_tensors, value_tensors = self.processors.post_update(
-            self, key_tensors, value_tensors, layer_idx, cache_kwargs
-        )
+        if self.processor is not None:
+            key_tensors, value_tensors = self.processor.post_update(
+                self, key_tensors, value_tensors, layer_idx, cache_kwargs
+            )
         return key_tensors, value_tensors
 
     def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
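Because the hooks now run inside `Cache.update()`, downstream code only constructs the cache and passes it to generation as usual. A hedged end-to-end sketch, not part of this diff: the checkpoint name is a placeholder, passing a cache via `past_key_values=` is assumed to be supported by `generate`, and offloading generally assumes an accelerator is available:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import OffloadedCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The KV cache", return_tensors="pt")
# OffloadedCache now carries a single OffloadedCacheProcessor whose pre/post
# hooks fire inside cache.update() on every decoding step.
outputs = model.generate(**inputs, past_key_values=OffloadedCache(), max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))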
@@ -1015,8 +982,7 @@ class OffloadedCache(DynamicCache):
 
     def __init__(self, config: Optional[CacheConfig] = None) -> None:
         # Create the underlying cache with offload processor
-        processors = CacheProcessorList([OffloadedCacheProcessor()])
-        super().__init__(processors=processors, config=config)
+        super().__init__(processor=OffloadedCacheProcessor(), config=config)
 
 
 class StaticLayer(CacheLayer):
@@ -1115,7 +1081,7 @@ class StaticCache(Cache):
 
     Parameters:
         config_or_ddp_cache_data (`Union`, *optional*): Model configuration for shape/device info, or DDP-distributed cache data for compatibility.
-        processors (`Optional`, *optional*): List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list.
+        processor (`Optional`, *optional*): Cache processor to apply (e.g., quantization, offloading).
         pattern_block (`Optional`, *optional*): Pattern of cache layer types to use. Defaults to `(StaticLayer,)` for backward compatibility.
 
 
@@ -1429,7 +1395,7 @@ class HybridCache(Cache):
 
     Parameters:
         config_or_ddp_cache_data (`PretrainedConfig` or `Iterable`, *optional*): Model configuration for shape/device info. No DDP-distributed cache data is supported.
-        processors (`CacheProcessorList`, *optional*): List of cache processors to apply (e.g., quantization, offloading). Defaults to empty list.
+        processor (`CacheProcessor`, *optional*): Cache processor to apply (e.g., quantization, offloading).
         pattern_block (`tuple[Type[CacheLayer], ...]`, *optional*): Pattern of cache layer types to use. Defaults to `(SlidingWindowLayer, StaticLayer, ..., StaticLayer)`
             for backward compatibility.
     Example:
@@ -1455,7 +1421,7 @@ class HybridCache(Cache):
     def __init__(
         self,
         config_or_ddp_cache_data=None,
-        processors: Optional[CacheProcessorList] = None,
+        processor: Optional[CacheProcessor] = None,
         pattern_block: Optional[tuple[type["CacheLayer"], ...]] = None,
         *args,
         **kwargs,
@@ -1469,7 +1435,7 @@ def __init__(
             self.is_sliding = [False] * model_config.num_hidden_layers
 
         pattern_block = tuple(SlidingWindowLayer if sl else StaticLayer for sl in self.is_sliding)
-        super().__init__(config_or_ddp_cache_data, processors, pattern_block, *args, **kwargs)
+        super().__init__(config_or_ddp_cache_data, processor, pattern_block, *args, **kwargs)
 
 
 class HybridChunkedCache(Cache):
@@ -1878,18 +1844,14 @@ def __init__(
         offload_device: Union[str, torch.device] = "cpu",
         layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
-        # Create offload processor
-        processors = CacheProcessorList([OffloadedCacheProcessor(offload_device)])
-
-        # Initialize the base StaticCache with the processor
         super().__init__(
             config=config,
             max_batch_size=max_batch_size,
             max_cache_len=max_cache_len,
             device=device,
             dtype=dtype,
             layer_device_map=layer_device_map,
-            processors=processors,
+            processor=OffloadedCacheProcessor(offload_device),
         )
 
 
@@ -2230,8 +2192,7 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None:
         else:
             raise ValueError(f"Unknown quantization backend `{cache_config.backend}`")
 
-        processors = CacheProcessorList([processor])
-        super().__init__(processors=processors)
+        super().__init__(processor=processor)
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
@@ -2240,7 +2201,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
         # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
         # this part of code otherwise fails when used to verify attn_weight shape in some models
-        return self.processors[0]._seen_tokens if layer_idx == 0 else self.processors[0]._seen_tokens - 1
+        return self.processor._seen_tokens if layer_idx == 0 else self.processor._seen_tokens - 1
 
 
 class QuantoQuantizedCache(QuantizedCache):
@@ -2283,8 +2244,7 @@ class QuantoQuantizedCache(QuantizedCache):
     """
 
     def __init__(self, cache_config: QuantizedCacheConfig) -> None:
-        processors = CacheProcessorList([QuantoQuantizedCacheProcessor(cache_config)])
-        Cache.__init__(self, processors=processors)
+        Cache.__init__(self, processor=QuantoQuantizedCacheProcessor(cache_config))
 
 
 class HQQQuantizedCache(QuantizedCache):
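A similarly hedged sketch for the quantized variants under the new constructor; it assumes `QuantizedCacheConfig` is importable from `transformers`, that it accepts these fields, and that the optimum-quanto backend is installed. Illustrative only, not part of the diff:

from transformers import QuantizedCacheConfig  # assumed import path
from transformers.cache_utils import QuantoQuantizedCache

# The constructor now builds a single QuantoQuantizedCacheProcessor and passes
# it to Cache.__init__ via the new `processor` argument.
cache = QuantoQuantizedCache(QuantizedCacheConfig(backend="quanto", nbits=4))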
@@ -2327,8 +2287,7 @@ class HQQQuantizedCache(QuantizedCache):
     """
 
     def __init__(self, cache_config: QuantizedCacheConfig) -> None:
-        processors = CacheProcessorList([HQQQuantizedCacheProcessor(cache_config)])
-        Cache.__init__(self, processors=processors)
+        Cache.__init__(self, processor=HQQQuantizedCacheProcessor(cache_config))
 
 
 class SinkCache(Cache):
