
Commit a6b7562

back to storage inside Cache()
1 parent e80c68a commit a6b7562


46 files changed: +444 / -348 lines

docs/source/en/cache_explanation.md

Lines changed: 2 additions & 2 deletions
@@ -89,8 +89,8 @@ Layers can be of different types (e.g. `DynamicLayer`, `StaticLayer`, `SlidingWi
 The simplest is a `DynamicLayer` that grows as more tokens are processed. The sequence length dimension (`seq_len`) increases with each new token:
 
 ```py
-cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
-cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)
+cache.key_cache[idx] = torch.cat([cache.key_cache[idx], key_states], dim=-2)
+cache.value_cache[idx] = torch.cat([cache.value_cache[idx], value_states], dim=-2)
 ```
 
 Other layers like `StaticLayer` and `SlidingWindowLayer` have a fixed sequence length that is set when the cache is created. This makes them compatible with `torch.compile`. In the case of `SlidingWindowLayer`, existing tokens are shifted out of the cache when a new token is added.
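
The documentation change above swaps the attribute path but keeps the growth logic the same: keys and values are concatenated along the sequence-length dimension. Below is a minimal, self-contained sketch of that pattern using the `key_cache` / `value_cache` lists this commit moves back to; `ToyDynamicCache` and the tensor shapes are illustrative stand-ins, not the actual `DynamicCache` implementation.

```py
# Minimal sketch of a dynamically growing KV cache using per-layer
# `key_cache` / `value_cache` lists. Illustrative only, not the real class.
import torch


class ToyDynamicCache:
    def __init__(self, num_layers: int):
        # One entry per decoder layer; tensors are created lazily on first update.
        self.key_cache = [None] * num_layers
        self.value_cache = [None] * num_layers

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
        if self.key_cache[layer_idx] is None:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            # Concatenate along the sequence-length dimension (dim=-2),
            # mirroring the snippet in cache_explanation.md.
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


# Usage: (batch, num_heads, seq_len, head_dim) tensors, one new token at a time.
cache = ToyDynamicCache(num_layers=2)
k = torch.randn(1, 4, 1, 64)
v = torch.randn(1, 4, 1, 64)
keys, values = cache.update(k, v, layer_idx=0)
print(keys.shape)  # torch.Size([1, 4, 1, 64]); grows to [1, 4, 2, 64] after another update
```

Running the update twice on the same layer shows `seq_len` growing from 1 to 2, which is exactly the behavior the docs describe for `DynamicLayer`.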

src/transformers/cache_utils.py

Lines changed: 265 additions & 168 deletions
Large diffs are not rendered by default.

src/transformers/integrations/executorch.py

Lines changed: 6 additions & 6 deletions
@@ -282,8 +282,8 @@ def __init__(self, model: PreTrainedModel):
             dtype=self.model.dtype,
         )
         for i in range(len(self.static_cache)):
-            self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
-            self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False)
+            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
 
     def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
         """
@@ -413,8 +413,8 @@ def __init__(
 
         # Register all key and value cache tensors as buffers
         for i in range(len(self.cache)):
-            self.register_buffer(f"key_cache_{i}", self.cache.layers[i].keys, persistent=False)
-            self.register_buffer(f"value_cache_{i}", self.cache.layers[i].values, persistent=False)
+            self.register_buffer(f"key_cache_{i}", self.cache.key_cache[i], persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.cache.value_cache[i], persistent=False)
 
     def forward(
         self,
@@ -559,8 +559,8 @@ def __init__(self, model, max_static_cache_length, batch_size):
 
         # Register cache buffers to make them exportable
         for i in range(len(self.static_cache)):
-            self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
-            self.register_buffer(f"value_cache_{i}", self.static_cache.layers[i].values, persistent=False)
+            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
 
     def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
         # Get outputs from decoder
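
All three hunks above apply the same pattern: every per-layer cache tensor is registered on the export wrapper as a named, non-persistent buffer so export tooling can track it. Below is a hedged sketch of that pattern on a toy module; the preallocated tensor lists stand in for `StaticCache.key_cache` / `value_cache` and `ToyWrapper` is not the real executorch wrapper.

```py
# Hedged sketch of the buffer-registration pattern in the diff above.
import torch
from torch import nn


class ToyWrapper(nn.Module):
    def __init__(self, num_layers: int, batch: int, heads: int, max_len: int, head_dim: int):
        super().__init__()
        # Stand-ins for StaticCache.key_cache / value_cache: preallocated tensors.
        key_cache = [torch.zeros(batch, heads, max_len, head_dim) for _ in range(num_layers)]
        value_cache = [torch.zeros(batch, heads, max_len, head_dim) for _ in range(num_layers)]
        for i in range(num_layers):
            # persistent=False keeps the buffers out of state_dict while still
            # registering them on the module, matching the diff above.
            self.register_buffer(f"key_cache_{i}", key_cache[i], persistent=False)
            self.register_buffer(f"value_cache_{i}", value_cache[i], persistent=False)


wrapper = ToyWrapper(num_layers=2, batch=1, heads=4, max_len=128, head_dim=64)
print([name for name, _ in wrapper.named_buffers()])
# ['key_cache_0', 'value_cache_0', 'key_cache_1', 'value_cache_1']
```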

src/transformers/models/bart/modeling_bart.py

Lines changed: 2 additions & 2 deletions
@@ -230,8 +230,8 @@ def forward(
         current_states = key_value_states if is_cross_attention else hidden_states
         if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = curr_past_key_value.layers[self.layer_idx].keys
-            value_states = curr_past_key_value.layers[self.layer_idx].values
+            key_states = curr_past_key_value.key_cache[self.layer_idx]
+            value_states = curr_past_key_value.value_cache[self.layer_idx]
         else:
             key_states = self.k_proj(current_states)
             value_states = self.v_proj(current_states)
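
The BART hunk above (and the identical hunks in the BigBird-Pegasus, BioGPT, Blenderbot, and Blenderbot-Small files that follow) touches the cross-attention fast path: once the encoder keys/values have been projected and cached, later decoding steps read them back instead of recomputing them. A minimal sketch of that branch is shown below; `ToyCrossAttentionKV`, the `SimpleNamespace` cache, and the explicit `is_updated` flag are simplified stand-ins for the real modeling code.

```py
# Hedged sketch of the cross-attention key/value reuse branch.
from types import SimpleNamespace

import torch
from torch import nn


class ToyCrossAttentionKV(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

    def get_key_value(self, encoder_states, cache, layer_idx, is_updated):
        if cache is not None and is_updated:
            # Encoder states never change during decoding, so the projected
            # keys/values can be read straight from the per-layer cache lists.
            key_states = cache.key_cache[layer_idx]
            value_states = cache.value_cache[layer_idx]
        else:
            # First decoding step: project once; the real code then stores the
            # result in the cache for later steps.
            key_states = self.k_proj(encoder_states)
            value_states = self.v_proj(encoder_states)
        return key_states, value_states


attn = ToyCrossAttentionKV(embed_dim=16)
enc = torch.randn(1, 5, 16)
cache = SimpleNamespace(key_cache=[None], value_cache=[None])
k, v = attn.get_key_value(enc, cache, layer_idx=0, is_updated=False)
cache.key_cache[0], cache.value_cache[0] = k, v
k2, v2 = attn.get_key_value(enc, cache, layer_idx=0, is_updated=True)
assert torch.equal(k, k2)  # cached projections are reused on later steps
```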

src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py

Lines changed: 2 additions & 2 deletions
@@ -1293,8 +1293,8 @@ def forward(
         current_states = key_value_states if is_cross_attention else hidden_states
         if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = curr_past_key_value.layers[self.layer_idx].keys
-            value_states = curr_past_key_value.layers[self.layer_idx].values
+            key_states = curr_past_key_value.key_cache[self.layer_idx]
+            value_states = curr_past_key_value.value_cache[self.layer_idx]
         else:
             key_states = self.k_proj(current_states)
             value_states = self.v_proj(current_states)

src/transformers/models/biogpt/modeling_biogpt.py

Lines changed: 2 additions & 2 deletions
@@ -207,8 +207,8 @@ def forward(
         current_states = key_value_states if is_cross_attention else hidden_states
         if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = curr_past_key_value.layers[self.layer_idx].keys
-            value_states = curr_past_key_value.layers[self.layer_idx].values
+            key_states = curr_past_key_value.key_cache[self.layer_idx]
+            value_states = curr_past_key_value.value_cache[self.layer_idx]
         else:
             key_states = self.k_proj(current_states)
             value_states = self.v_proj(current_states)

src/transformers/models/blenderbot/modeling_blenderbot.py

Lines changed: 2 additions & 2 deletions
@@ -229,8 +229,8 @@ def forward(
         current_states = key_value_states if is_cross_attention else hidden_states
         if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = curr_past_key_value.layers[self.layer_idx].keys
-            value_states = curr_past_key_value.layers[self.layer_idx].values
+            key_states = curr_past_key_value.key_cache[self.layer_idx]
+            value_states = curr_past_key_value.value_cache[self.layer_idx]
         else:
             key_states = self.k_proj(current_states)
             value_states = self.v_proj(current_states)

src/transformers/models/blenderbot_small/modeling_blenderbot_small.py

Lines changed: 2 additions & 2 deletions
@@ -213,8 +213,8 @@ def forward(
         current_states = key_value_states if is_cross_attention else hidden_states
         if is_cross_attention and past_key_value is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = curr_past_key_value.layers[self.layer_idx].keys
-            value_states = curr_past_key_value.layers[self.layer_idx].values
+            key_states = curr_past_key_value.key_cache[self.layer_idx]
+            value_states = curr_past_key_value.value_cache[self.layer_idx]
         else:
             key_states = self.k_proj(current_states)
             value_states = self.v_proj(current_states)

src/transformers/models/dia/modeling_dia.py

Lines changed: 2 additions & 2 deletions
@@ -356,8 +356,8 @@ def forward(
         is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
         if past_key_values is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
-            value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
+            key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx]
+            value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx]
         else:
             key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
             value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)

src/transformers/models/dia/modular_dia.py

Lines changed: 2 additions & 2 deletions
@@ -182,8 +182,8 @@ def forward(
         is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
         if past_key_values is not None and is_updated:
             # reuse k,v, cross_attentions
-            key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
-            value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
+            key_states = past_key_values.cross_attention_cache.key_cache[self.layer_idx]
+            value_states = past_key_values.cross_attention_cache.value_cache[self.layer_idx]
         else:
             key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
             value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
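
The two Dia hunks differ from the seq2seq pattern above only in where the cache lives: the cross-attention keys/values sit on a nested `cross_attention_cache` object, whose per-layer tensors are again read through `key_cache` / `value_cache`. The sketch below only illustrates that attribute layout; `ToyCache` and `ToyEncoderDecoderCache` are assumed stand-ins for the real cache classes, not their implementation.

```py
# Hedged sketch of the nested cache layout the Dia diffs read from.
import torch


class ToyCache:
    def __init__(self, num_layers: int):
        self.key_cache = [None] * num_layers
        self.value_cache = [None] * num_layers


class ToyEncoderDecoderCache:
    def __init__(self, num_layers: int):
        # Separate caches for self-attention and cross-attention layers,
        # plus a per-layer flag recording whether cross-attention K/V were
        # already projected (assumed layout, mirroring the diff).
        self.self_attention_cache = ToyCache(num_layers)
        self.cross_attention_cache = ToyCache(num_layers)
        self.is_updated = {i: False for i in range(num_layers)}


past_key_values = ToyEncoderDecoderCache(num_layers=1)
layer_idx = 0

# First step: project encoder states once and remember them.
past_key_values.cross_attention_cache.key_cache[layer_idx] = torch.randn(1, 4, 5, 64)
past_key_values.cross_attention_cache.value_cache[layer_idx] = torch.randn(1, 4, 5, 64)
past_key_values.is_updated[layer_idx] = True

# Later steps: the branch in the diff reuses the stored tensors.
if past_key_values is not None and past_key_values.is_updated.get(layer_idx):
    key_states = past_key_values.cross_attention_cache.key_cache[layer_idx]
    value_states = past_key_values.cross_attention_cache.value_cache[layer_idx]
print(key_states.shape)  # torch.Size([1, 4, 5, 64])
```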
