
Commit 457b478

zucchini-nlp authored and ArthurZucker committed
Fix cache-related tests (#39676)
* fix
* fix kyutai at last
* fix unrelated tests and copies
* update musicgen as well
* revert tensor
* fix old test failures
* why it wasn't added?
1 parent 862cb55 commit 457b478

14 files changed (+89 −38 lines)


src/transformers/generation/utils.py

Lines changed: 1 addition & 1 deletion
@@ -2055,7 +2055,7 @@ def _prepare_cache_for_generation(
             generation_config.cache_implementation = None
 
         generation_config.cache_implementation = generation_config.cache_implementation or getattr(
-            self.config.get_text_config(), "cache_implementation", None
+            self.config.get_text_config(decoder=True), "cache_implementation", None
         )
         if generation_config.cache_implementation is not None:
             if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
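The change above makes the generation-time fallback read `cache_implementation` from the decoder text sub-config rather than the plain text config, which is what matters for composite (encoder-decoder or multimodal) models. A minimal sketch of the lookup pattern, with names taken from the hunk rather than copied verbatim from the library:

# Hedged sketch of the fallback logic shown in the diff.
def resolve_cache_implementation(generation_config, config):
    # Prefer an explicit setting on the generation config; otherwise fall back to
    # whatever the decoder text sub-config declares, or None if it declares nothing.
    return generation_config.cache_implementation or getattr(
        config.get_text_config(decoder=True), "cache_implementation", None
    )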

src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py

Lines changed: 4 additions & 1 deletion
@@ -1215,12 +1215,15 @@ def _prepare_model_inputs(
         cache_methods = [
             "_prepare_cache_for_generation",
             "_get_cache",
-            "_supports_default_dynamic_cache",
             "_get_layer_device_map_for_cache_init",
         ]
         for method in cache_methods:
             setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model))
 
+        setattr(
+            self.codec_model, "_supports_default_dynamic_cache", types.MethodType(lambda x: True, self.codec_model)
+        )
+
         self.codec_model._prepare_cache_for_generation(
             generation_config=self.codec_model.generation_config,
             model_kwargs=temporary_model_kwargs,
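Instead of rebinding the wrapper's `_supports_default_dynamic_cache` onto the codec model, the fix binds a callable that always returns True using `types.MethodType`. A self-contained illustration of that binding trick (the `Codec` class here is hypothetical, not the library code):

import types

class Codec:
    pass

codec = Codec()
# Bind a plain function as an instance method: the first argument receives the
# instance, exactly like `self`, so the attribute behaves like a normal method.
codec._supports_default_dynamic_cache = types.MethodType(lambda self: True, codec)
print(codec._supports_default_dynamic_cache())  # True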

src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py

Lines changed: 4 additions & 1 deletion
@@ -344,12 +344,15 @@ def _prepare_model_inputs(
         cache_methods = [
             "_prepare_cache_for_generation",
             "_get_cache",
-            "_supports_default_dynamic_cache",
             "_get_layer_device_map_for_cache_init",
         ]
         for method in cache_methods:
             setattr(self.codec_model, method, types.MethodType(getattr(self, method).__func__, self.codec_model))
 
+        setattr(
+            self.codec_model, "_supports_default_dynamic_cache", types.MethodType(lambda x: True, self.codec_model)
+        )
+
         self.codec_model._prepare_cache_for_generation(
             generation_config=self.codec_model.generation_config,
             model_kwargs=temporary_model_kwargs,

src/transformers/models/musicgen/modeling_musicgen.py

Lines changed: 26 additions & 4 deletions
@@ -1246,7 +1246,29 @@ def generate(
             input_ids_length=input_ids_length,
         )
 
-        # 6. Prepare `input_ids` which will be used for auto-regressive generation
+        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+
+        # 6. Prepare the cache.
+        # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
+        # - different models have a different cache name expected by the model (default = "past_key_values")
+        # - `max_length`, prepared above, is used to determine the maximum cache length
+        max_cache_length = generation_config.max_length - 1
+        if (
+            input_ids_length.shape[1] != input_ids_length
+            and model_input_name == "inputs_embeds"
+            and not self.config.is_encoder_decoder
+        ):
+            max_cache_length += input_ids_length.shape[1]
+        self._prepare_cache_for_generation(
+            generation_config,
+            model_kwargs,
+            assistant_model=None,
+            batch_size=batch_size,
+            max_cache_length=max_cache_length,
+            device=input_ids_length.device,
+        )
+
+        # 7. Prepare `input_ids` which will be used for auto-regressive generation
         # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
         input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
             input_ids,
@@ -1260,15 +1282,15 @@ def generate(
         # stash the delay mask so that we don't have to recompute it in each forward pass
         model_kwargs["delay_pattern_mask"] = delay_pattern_mask
 
-        # 7. determine generation mode
+        # 8. determine generation mode
        generation_mode = generation_config.get_generation_mode()
 
-        # 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
+        # 9. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
         if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
             logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
             generation_config.guidance_scale = None
 
-        # 9. prepare distribution pre_processing samplers
+        # 10. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
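For context on the code surrounding this hunk: MusicGen predicts several codebooks in parallel and offsets codebook k by k steps, which is what `build_delay_pattern_mask` encodes. A deliberately simplified illustration of that delay idea (not the library's actual helper; the pad value and sizes are arbitrary):

import torch

num_codebooks, seq_len, pad_id = 4, 6, -1
# Codebook k only holds "real" predictions from step k onward; earlier steps stay padded.
pattern = torch.full((num_codebooks, seq_len), pad_id)
for k in range(num_codebooks):
    pattern[k, k:] = 0
print(pattern)
# tensor([[ 0,  0,  0,  0,  0,  0],
#         [-1,  0,  0,  0,  0,  0],
#         [-1, -1,  0,  0,  0,  0],
#         [-1, -1, -1,  0,  0,  0]])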

src/transformers/models/musicgen_melody/modeling_musicgen_melody.py

Lines changed: 25 additions & 3 deletions
@@ -2162,6 +2162,28 @@ def generate(
             input_ids_length=input_ids_length,
         )
 
+        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+
+        # 7. Prepare the cache.
+        # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
+        # - different models have a different cache name expected by the model (default = "past_key_values")
+        # - `max_length`, prepared above, is used to determine the maximum cache length
+        max_cache_length = generation_config.max_length - 1
+        if (
+            inputs_tensor.shape[1] != input_ids_length
+            and model_input_name == "inputs_embeds"
+            and not self.config.is_encoder_decoder
+        ):
+            max_cache_length += inputs_tensor.shape[1]
+        self._prepare_cache_for_generation(
+            generation_config,
+            model_kwargs,
+            assistant_model=None,
+            batch_size=batch_size,
+            max_cache_length=max_cache_length,
+            device=inputs_tensor.device,
+        )
+
         # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
         input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
             input_ids,
@@ -2175,15 +2197,15 @@ def generate(
         if streamer is not None:
             streamer.put(input_ids.cpu())
 
-        # 7. determine generation mode
+        # 8. determine generation mode
         generation_mode = generation_config.get_generation_mode()
 
-        # 8. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
+        # 9. prepare batched CFG externally (to enable coexistence with the unbatched CFG)
         if generation_config.guidance_scale is not None and generation_config.guidance_scale > 1:
             logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
             generation_config.guidance_scale = None
 
-        # 9. prepare distribution pre_processing samplers
+        # 10. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
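Both MusicGen variants now size the cache from `generation_config.max_length` and grow it when the prompt is passed as `inputs_embeds` to a decoder-only model. A hedged sketch of that arithmetic with purely illustrative numbers (variable names follow the melody hunk above):

# Illustrative values only.
max_length = 256          # generation_config.max_length
inputs_embeds_len = 32    # inputs_tensor.shape[1] when prompting via inputs_embeds
input_ids_length = 1      # length of the token prompt actually fed to generate()

max_cache_length = max_length - 1
if inputs_embeds_len != input_ids_length:  # the inputs_embeds / decoder-only case
    max_cache_length += inputs_embeds_len
print(max_cache_length)  # 287: room for the embedded prompt plus the generated positions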

src/transformers/models/rag/modeling_rag.py

Lines changed: 0 additions & 9 deletions
@@ -1204,8 +1204,6 @@ def _reorder_stacked(hidden_states, new_order):
         if isinstance(past_key_values, EncoderDecoderCache):
             reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past)
 
-        if isinstance(past_key_values, EncoderDecoderCache):
-            reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past)
         return reordered_past
 
     def marginalize(self, seq_logits, doc_scores, n_docs=None):
@@ -1593,13 +1591,6 @@ def extend_enc_output(tensor, num_beams=None):
         if generation_config.num_return_sequences > generation_config.num_beams:
             raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
 
-        # 11. interleave input_ids with `num_beams` additional sequences per batch
-        input_ids, model_kwargs = self._expand_inputs_for_generation(
-            input_ids=input_ids,
-            expand_size=generation_config.num_beams,
-            is_encoder_decoder=self.config.is_encoder_decoder,
-            **model_kwargs,
-        )
         return self._beam_search(
             input_ids,
             logits_processor=pre_processor,
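The RAG deletions drop a duplicated cache conversion and a beam-search input expansion that is presumably already performed inside the shared beam-search path, so keeping it would expand the batch twice. A small standalone illustration of why a double expansion is wrong (`repeat_interleave` stands in for `_expand_inputs_for_generation`):

import torch

input_ids = torch.tensor([[1, 2, 3]])
num_beams = 4
expanded_once = input_ids.repeat_interleave(num_beams, dim=0)
expanded_twice = expanded_once.repeat_interleave(num_beams, dim=0)
print(expanded_once.shape)   # torch.Size([4, 3])  -> one row per beam, as intended
print(expanded_twice.shape)  # torch.Size([16, 3]) -> num_beams ** 2 rows if expanded twice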

src/transformers/models/roformer/modeling_roformer.py

Lines changed: 18 additions & 7 deletions
@@ -261,6 +261,17 @@ def forward(
             .transpose(1, 2)
         )
 
+        # Apply RoPE if self attention
+        if not is_cross_attention and sinusoidal_pos is not None:
+            if self.rotary_value:
+                query_layer, key_layer, value_layer = self.apply_rotary_position_embeddings(
+                    sinusoidal_pos, query_layer, key_layer, value_layer
+                )
+            else:
+                query_layer, key_layer = self.apply_rotary_position_embeddings(
+                    sinusoidal_pos, query_layer, key_layer
+                )
+
         if past_key_value is not None:
             # save all key/value_layer to cache to be re-used for fast auto-regressive generation
             cache_position = cache_position if not is_cross_attention else None
@@ -381,13 +392,13 @@ def forward(
     ):
         self_outputs = self.self(
             hidden_states,
-            attention_mask,
-            sinusoidal_pos,
-            head_mask,
-            encoder_hidden_states,
-            past_key_value,
-            output_attentions,
-            cache_position,
+            attention_mask=attention_mask,
+            sinusoidal_pos=sinusoidal_pos,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
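The first RoFormer hunk reapplies rotary position embeddings to the query/key (and optionally value) projections in the self-attention branch. As a rough, self-contained sketch of the rotary idea (not RoFormer's exact `apply_rotary_position_embeddings`), each pair of feature dimensions is rotated by a position-dependent angle:

import torch

def apply_rope(x, sin, cos):
    # x: (..., seq_len, dim) with even dim; sin/cos: (seq_len, dim // 2)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated_even = x1 * cos - x2 * sin
    rotated_odd = x1 * sin + x2 * cos
    return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)

q = torch.randn(2, 8, 16)                                          # (batch, seq_len, dim)
positions = torch.arange(8, dtype=torch.float32).unsqueeze(-1)
inv_freq = 1.0 / (10000 ** (torch.arange(0, 16, 2, dtype=torch.float32) / 16))
angles = positions * inv_freq                                      # (seq_len, dim // 2)
q_rotated = apply_rope(q, angles.sin(), angles.cos())
print(q_rotated.shape)                                             # torch.Size([2, 8, 16])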

src/transformers/models/superglue/modeling_superglue.py

Lines changed: 1 addition & 1 deletion
@@ -274,7 +274,7 @@ def forward(
         # such that the encoder's padding tokens are not attended to.
         is_cross_attention = encoder_hidden_states is not None
         current_states = encoder_hidden_states if is_cross_attention else hidden_states
-        attention_mask = encoder_attention_mask if is_cross_attention else encoder_attention_mask
+        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
 
         batch_size = hidden_states.shape[0]
         key_layer = (
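This one-line SuperGlue fix matters because the old ternary selected `encoder_attention_mask` on both branches, so the self-attention mask was silently dropped. A trivial demonstration of the difference, with strings standing in for the mask tensors:

# Stand-in values; in the model these are attention mask tensors.
attention_mask, encoder_attention_mask = "self_attn_mask", "cross_attn_mask"

is_cross_attention = False
old = encoder_attention_mask if is_cross_attention else encoder_attention_mask
new = encoder_attention_mask if is_cross_attention else attention_mask
print(old)  # cross_attn_mask  (wrong: the self-attention mask is lost)
print(new)  # self_attn_mask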

tests/models/llava_next/test_modeling_llava_next.py

Lines changed: 2 additions & 2 deletions
@@ -515,7 +515,7 @@ def test_small_model_integration_test_full_vision_state_selection(self):
         # test that changing `strategy` won't error out
         model.vision_feature_select_strategy = "full"
 
-        inputs = self.processor(self.prompt, self.image, return_tensors="pt").to(model.device)
+        inputs = self.processor(text=self.prompt, images=self.image, return_tensors="pt").to(model.device)
 
         # verify generation
         output = model.generate(**inputs, max_new_tokens=30)
@@ -536,7 +536,7 @@ def test_granite_vision(self):
         model = LlavaNextForConditionalGeneration.from_pretrained(granite_model_path)
         self.processor = AutoProcessor.from_pretrained(granite_model_path)
         prompt = "<|user|>\n<image>\nWhat is shown in this image?\n<|assistant|>\n"
-        inputs = self.processor(prompt, self.image, return_tensors="pt").to(model.device)
+        inputs = self.processor(text=prompt, images=self.image, return_tensors="pt").to(model.device)
 
         # verify generation
         output = model.generate(**inputs, max_new_tokens=30)
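These test updates pass `text=` and `images=` to the processor by keyword instead of positionally, likely because positional order is not consistent across processor classes, so keyword arguments are the safer call signature. A hedged usage sketch; the checkpoint name, image URL, and prompt format below are assumptions for illustration, not taken from this diff:

import requests
from PIL import Image
from transformers import AutoProcessor

# Assumed public checkpoint and sample image, requiring network access.
processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
# Keywords make the call robust to processors that order (images, text) differently.
inputs = processor(text=prompt, images=image, return_tensors="pt")
print(inputs["input_ids"].shape)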

tests/models/llava_next_video/test_modeling_llava_next_video.py

Lines changed: 3 additions & 1 deletion
@@ -467,7 +467,9 @@ def test_small_model_integration_test_batch_matches_single(self):
             padding=True,
         ).to(torch_device)
 
-        inputs_single = self.processor(self.prompt_video, videos=[self.video], return_tensors="pt").to(torch_device)
+        inputs_single = self.processor(text=self.prompt_video, videos=[self.video], return_tensors="pt").to(
+            torch_device
+        )
 
         # verify generation
         output_batched = model.generate(**inputs_batched, do_sample=False, max_new_tokens=50)
