@@ -268,7 +268,7 @@ def forward(
268268 "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
269269 "with a layer index."
270270 )
271- kv_seq_len += past_key_value .get_seq_length ( self .layer_idx )
271+ kv_seq_len += past_key_value .get_usable_length ( kv_seq_len , self .layer_idx )
272272 cos , sin = self .rotary_emb (value_states , seq_len = kv_seq_len )
273273 query_states , key_states = apply_rotary_pos_emb (query_states , key_states , cos , sin , position_ids )
274274
@@ -363,7 +363,7 @@ def forward(

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

         # Because the input can be padded, the absolute sequence length depends on the max position id.
         rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
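
Note on the two replacements above: `get_seq_length` only reports how many tokens the cache currently stores for a layer, while `get_usable_length` also accounts for a possible maximum cache size, so that `kv_seq_len` (and hence the RoPE table) matches the keys/values that will actually be kept. A rough sketch of that distinction, paraphrasing the base `Cache` semantics rather than quoting its code:

from typing import Optional


def usable_length_sketch(previous_seq_length: int, new_seq_length: int, max_length: Optional[int]) -> int:
    # Unbounded cache: every stored token is usable, so this matches get_seq_length().
    if max_length is None or previous_seq_length + new_seq_length <= max_length:
        return previous_seq_length
    # Bounded cache (e.g. a sliding-window / sink cache): entries that will be evicted to make
    # room for the incoming tokens must not be counted, otherwise kv_seq_len would overshoot.
    return max_length - new_seq_length


assert usable_length_sketch(10, 1, None) == 10
assert usable_length_sketch(10, 1, 8) == 7
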
@@ -850,15 +850,13 @@ def forward(
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

-        seq_length_with_past = seq_length
         past_key_values_length = 0

         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
-            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_values_length = past_key_values.get_usable_length(seq_length)

         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
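
For illustration only, a minimal sketch (not code from this diff) of how the position ids for the new tokens follow directly from the cache once legacy tuples are wrapped, which is why the separate `seq_length_with_past` bookkeeping above could be dropped. It assumes `Cache` and `DynamicCache` are importable from the top-level `transformers` package, as in recent releases:

import torch
from transformers import Cache, DynamicCache  # assumption: top-level exports


def position_ids_for_new_tokens(past_key_values, seq_length: int, device) -> torch.Tensor:
    # Mirror the hunk above: wrap legacy tuple caches, then ask the cache how many of its
    # entries remain usable once seq_length new tokens are appended.
    if not isinstance(past_key_values, Cache):
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    past_key_values_length = past_key_values.get_usable_length(seq_length)
    return torch.arange(
        past_key_values_length, past_key_values_length + seq_length, dtype=torch.long, device=device
    ).unsqueeze(0)
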
@@ -1092,8 +1090,10 @@ def prepare_inputs_for_generation(
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None

             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1107,10 +1107,13 @@ def prepare_inputs_for_generation(
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]

         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
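
To see what the new cropping rule above does: `max_cache_length` is `None` for an unbounded or legacy cache (so nothing changes), and only when a bounded cache is about to overflow is the attention mask trimmed to the columns the cache will actually retain. A tiny sketch with hypothetical numbers, not taken from the PR:

import torch

# Hypothetical setup: a bounded cache that can hold 6 entries, 4 tokens already cached,
# and 3 new input tokens arriving.
max_cache_length = 6
cache_length = 4
input_ids = torch.ones(1, 3, dtype=torch.long)
attention_mask = torch.ones(1, cache_length + input_ids.shape[1], dtype=torch.long)  # 7 columns

if (
    max_cache_length is not None
    and attention_mask is not None
    and cache_length + input_ids.shape[1] > max_cache_length
):
    # Keep only the trailing columns that correspond to entries the cache will actually retain.
    attention_mask = attention_mask[:, -max_cache_length:]

print(attention_mask.shape)  # torch.Size([1, 6])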