
Commit ce0bbd5

Generate: SinkCache can handle iterative prompts (#27907)
1 parent: 94c7653

6 files changed (+116, -34 lines)


src/transformers/cache_utils.py

Lines changed: 27 additions & 3 deletions
@@ -38,6 +38,21 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
 
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states, if there is any."""
+        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
+
+    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+        """Given the sequence length of the new inputs, returns the usable length of the cache."""
+        # Cache without size limit -> all cache is usable
+        # Cache with size limit -> if the cache length plus the length of the new inputs is larger than the maximum
+        # cache length, we will need to evict part of the cache (and thus not all of the cache is usable)
+        max_length = self.get_max_length()
+        previous_seq_length = self.get_seq_length(layer_idx)
+        if max_length is not None and previous_seq_length + new_seq_length > max_length:
+            return max_length - new_seq_length
+        return previous_seq_length
+
 
 class DynamicCache(Cache):
     """
@@ -120,6 +135,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
             return 0
         return self.key_cache[layer_idx].shape[-2]
 
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+        return None
+
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
@@ -209,8 +228,11 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
         if len(self.key_cache) <= layer_idx:
             return 0
-        cache_length = self.key_cache[layer_idx].shape[-2]
-        return min(cache_length, self.window_length - 1)
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+        return self.window_length
 
     def update(
         self,
@@ -267,7 +289,9 @@ def update(
 
             # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
             if using_rope:
-                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin)
+                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
+                    key_states, cos[: self.window_length], sin[: self.window_length]
+                )
                 if partial_rotation_size is not None:
                     keys_to_keep, keys_pass = (
                         keys_to_keep[..., :partial_rotation_size],

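The key idea behind the new `get_usable_length` helper: when a cache has a size limit and would overflow, only `max_length - new_seq_length` cached positions remain usable, because the rest will be evicted to make room for the incoming tokens. Below is a minimal standalone sketch of that arithmetic (it reimplements the logic as a plain function instead of importing `transformers`; the function name and the numbers are illustrative):

```python
from typing import Optional


def usable_length(previous_seq_length: int, new_seq_length: int, max_length: Optional[int]) -> int:
    # Mirrors Cache.get_usable_length: a bounded cache that would overflow only exposes
    # max_length - new_seq_length positions; otherwise the full cached length is usable.
    if max_length is not None and previous_seq_length + new_seq_length > max_length:
        return max_length - new_seq_length
    return previous_seq_length


# Unbounded cache (DynamicCache returns None from get_max_length): everything is usable.
print(usable_length(previous_seq_length=300, new_seq_length=20, max_length=None))  # 300

# Bounded cache (SinkCache-style, window_length=256): 256 + 20 > 256, so only 236 cached
# positions are usable once the 20 new tokens arrive.
print(usable_length(previous_seq_length=256, new_seq_length=20, max_length=256))  # 236
```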
src/transformers/models/llama/modeling_llama.py

Lines changed: 12 additions & 7 deletions
@@ -398,7 +398,7 @@ def forward(
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -503,7 +503,7 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -910,7 +910,7 @@ def forward(
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1127,8 +1127,10 @@ def prepare_inputs_for_generation(
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1142,10 +1144,13 @@ def prepare_inputs_for_generation(
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:

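For context on the `prepare_inputs_for_generation` change above: the attention mask is now cropped against the cache's maximum length rather than only when `cache_length < past_length`, so fixed-size caches keep the mask aligned with what they will actually hold. A standalone sketch of just that cropping step (the tensor shapes and the 256 limit are illustrative, not taken from the commit):

```python
import torch

max_cache_length = 256  # e.g. what a SinkCache(window_length=256) reports via get_max_length()
cache_length = 256      # tokens currently stored in the cache
input_ids = torch.ones(1, 30, dtype=torch.long)        # 30 new, unprocessed tokens
attention_mask = torch.ones(1, 330, dtype=torch.long)  # mask over every token seen so far

# Same condition as the diff: if the cache plus the new tokens would exceed the limit,
# keep only the last max_cache_length mask positions.
if (
    max_cache_length is not None
    and attention_mask is not None
    and cache_length + input_ids.shape[1] > max_cache_length
):
    attention_mask = attention_mask[:, -max_cache_length:]

print(attention_mask.shape)  # torch.Size([1, 256])
```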
src/transformers/models/mistral/modeling_mistral.py

Lines changed: 12 additions & 9 deletions
@@ -268,7 +268,7 @@ def forward(
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -363,7 +363,7 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
         rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
@@ -850,15 +850,13 @@ def forward(
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
-        seq_length_with_past = seq_length
         past_key_values_length = 0
 
         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
-            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1092,8 +1090,10 @@ def prepare_inputs_for_generation(
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1107,10 +1107,13 @@ def prepare_inputs_for_generation(
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:

src/transformers/models/persimmon/modeling_persimmon.py

Lines changed: 11 additions & 6 deletions
@@ -295,7 +295,7 @@ def forward(
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         # Partial rotary embedding
@@ -612,7 +612,7 @@ def forward(
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
             seq_length_with_past = seq_length_with_past + past_key_values_length
 
         if position_ids is None:
@@ -831,8 +831,10 @@ def prepare_inputs_for_generation(
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -846,10 +848,13 @@ def prepare_inputs_for_generation(
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:

src/transformers/models/phi/modeling_phi.py

Lines changed: 12 additions & 9 deletions
@@ -334,7 +334,7 @@ def forward(
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         # Partial rotary embedding
@@ -444,7 +444,7 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         # Partial rotary embedding
@@ -855,15 +855,13 @@ def forward(
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
-        seq_length_with_past = seq_length
         past_key_values_length = 0
 
         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
-            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1085,8 +1083,10 @@ def prepare_inputs_for_generation(
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1100,10 +1100,13 @@ def prepare_inputs_for_generation(
                 input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:

tests/test_cache_utils.py

Lines changed: 42 additions & 0 deletions
@@ -187,3 +187,45 @@ def test_sink_cache_hard(self):
         gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
         decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
         self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
+
+    def test_sink_cache_iterative_prompts(self):
+        """Tests that SinkCache supports more than one new token at once, when shifting the cache"""
+        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
+        model = AutoModelForCausalLM.from_pretrained(
+            "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
+        )
+        prompt = (
+            "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
+            "and must-see attractions."
+        )
+
+        # Prepare generation settings
+        cache = SinkCache(window_length=256, num_sink_tokens=4)
+        input_ids = torch.tensor([], device=model.device, dtype=torch.int)
+        for _ in range(3):
+            # Tokenize the prompt with the correct chat template
+            chat = [{"role": "user", "content": prompt}]
+            tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
+                model.device
+            )
+            input_ids = torch.cat((input_ids, tokenized_chat), dim=1)
+
+            # Perform the generation
+            gen_out = model.generate(
+                input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
+            )
+            input_ids = gen_out
+
+        # We went well beyond the cache length
+        self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
+
+        # And it still produces coherent English
+        decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
+        last_output = (
+            "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
+            "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
+            "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
+            "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
+            "was visiting the historic district of Honolulu. Here,"
+        )
+        self.assertTrue(decoded[0].endswith(last_output))

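The new test above doubles as documentation for the user-facing pattern: create one `SinkCache` and pass it to every `generate` call while appending each new prompt to the running `input_ids`. A condensed sketch of that loop (the prompts and generation lengths are illustrative; the checkpoint is the same one the test uses, but any chat model with a chat template should behave similarly):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

model_id = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)

# One fixed-size cache object, reused across every turn of the conversation.
cache = SinkCache(window_length=256, num_sink_tokens=4)
input_ids = torch.tensor([], device=model.device, dtype=torch.int)

for user_message in ["Tell me about Hawaii.", "Now summarize it as a haiku.", "Translate the haiku to French."]:
    chat = [{"role": "user", "content": user_message}]
    new_ids = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(model.device)
    input_ids = torch.cat((input_ids, new_ids), dim=1)
    # Reusing `cache` lets the conversation run well past window_length without re-prefilling.
    input_ids = model.generate(
        input_ids, do_sample=False, max_new_tokens=64, past_key_values=cache, use_cache=True
    )

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
```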