diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index fa2b82ab4c2ba4..f202e2fb2aab81 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1481,6 +1481,7 @@ def _prepare_cache_for_generation(
         model_kwargs: Dict,
         assistant_model: "PreTrainedModel",
         batch_size: int,
+        max_cache_length: int,
         device: torch.device,
     ) -> bool:
         """
@@ -1547,8 +1548,8 @@ def _prepare_cache_for_generation(
                     )
                 model_kwargs[cache_name] = self._get_cache(
                     cache_implementation=generation_config.cache_implementation,
-                    batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
-                    max_cache_len=generation_config.max_length,
+                    batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
+                    max_cache_len=max_cache_length,
                     device=device,
                     model_kwargs=model_kwargs,
                 )
@@ -1888,7 +1889,16 @@ def generate(
         # TODO (joao): remove `user_defined_cache` after v4.47 (remove default conversion to legacy format)
         cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
         user_defined_cache = model_kwargs.get(cache_name)
-        self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, device)
+        max_cache_length = generation_config.max_length
+        if (
+            inputs_tensor.shape[1] != input_ids_length
+            and model_input_name == "inputs_embeds"
+            and not self.config.is_encoder_decoder
+        ):
+            max_cache_length += inputs_tensor.shape[1]
+        self._prepare_cache_for_generation(
+            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
+        )
 
         # 8. determine generation mode
         generation_mode = generation_config.get_generation_mode(assistant_model)
@@ -1936,8 +1946,8 @@ def generate(
                 raise ValueError("assisted generate is only supported for batch_size = 1")
             if not model_kwargs["use_cache"]:
                 raise ValueError("assisted generate requires `use_cache=True`")
-            if generation_config.cache_implementation == "static":
-                raise ValueError("assisted generate is not supported with `static_cache`")
+            if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
+                raise ValueError("assisted generate is not supported with static cache classes")
             if self._is_stateful:
                 # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
                 # which is not possible with stateful models (they can't reset to a previous subset of generated text)
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 439ea58ae97767..65507795c84dd8 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1453,6 +1453,9 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
             model = model_class(config).to(torch_device).eval()
             signature = inspect.signature(model.forward).parameters.keys()
 
+            # no cache, as some models require special cache classes to be initialized outside of forward
+            model.generation_config.use_cache = False
+
             # Without padding
             model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
             next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
@@ -1593,6 +1596,59 @@ def test_generate_from_inputs_embeds_decoder_only(self):
                 outputs_from_embeds_wo_ids.tolist(),
             )
 
+    @pytest.mark.generate
+    def test_generate_from_inputs_embeds_with_static_cache(self):
+        """
+        Test that generating from `inputs_embeds` works with StaticCache and that `max_cache_length` is computed
+        correctly in `generate()`. We force the model to not stop generation until max length is reached,
+        to verify that the cache length is set correctly and that we don't go out of bounds when slicing the cache.
+        """
+        for model_class in self.all_generative_model_classes:
+            if not model_class._supports_static_cache:
+                self.skipTest(reason="This model does not support the static cache format")
+
+            config, input_ids, attention_mask = self._get_input_ids_and_config()
+            if config.is_encoder_decoder:
+                self.skipTest(reason="This model is an encoder-decoder and uses an encoder-decoder cache")
+
+            model = model_class(config).to(torch_device).eval()
+            if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
+                self.skipTest(reason="This model does not support `inputs_embeds` in generation")
+
+            model.config.use_cache = True
+            model.config.is_decoder = True
+            batch_size, seq_length = input_ids.shape
+            max_cache_len = 30
+
+            # here we force the model to not stop at eos and to generate until max length is reached
+            model.generation_config.eos_token_id = model.config.eos_token_id = -1
+            generation_kwargs = {
+                "max_length": max_cache_len,
+                "cache_implementation": "static",
+                "return_dict_in_generate": True,  # Required to return `past_key_values`
+            }
+
+            head_dim = (
+                model.config.head_dim
+                if hasattr(model.config, "head_dim")
+                else model.config.hidden_size // model.config.num_attention_heads
+            )
+            num_key_value_heads = (
+                model.config.num_attention_heads
+                if getattr(config, "num_key_value_heads", None) is None
+                else model.config.num_key_value_heads
+            )
+            num_hidden_layers = config.num_hidden_layers
+
+            inputs_embeds = model.get_input_embeddings()(input_ids)
+            outputs = model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs)
+
+            # we should get `max_length` in the cache shape, not `max_length - embeds_length`
+            cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
+            self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
+            self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
+            self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
+
     @pytest.mark.generate
     def test_generate_continue_from_past_key_values(self):
         # Tests that we can continue generating from past key values, returned from a previous `generate` call
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 433bcd5da9a45f..918ed847f83d9e 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -16,9 +16,10 @@
 
 import unittest
 
+from parameterized import parameterized
 from pytest import mark
 
-from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline
 from transformers.testing_utils import (
     require_flash_attn,
     require_read_token,
@@ -59,7 +60,7 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    all_generative_model_classes = ()
+    all_generative_model_classes = (Gemma2ForCausalLM,) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
             "feature-extraction": Gemma2Model,
@@ -89,6 +90,101 @@ def test_model_outputs_equivalence(self, **kwargs):
     def test_eager_matches_sdpa_inference(self):
         pass
 
+    @parameterized.expand([("random",), ("same",)])
+    @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
+    def test_assisted_decoding_matches_greedy_search(self, assistant_type):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
+    def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
+    def test_assisted_decoding_sample(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding")
+    def test_dola_decoding_sample(self):
+        pass
+
+    @parameterized.expand([(1, False), (1, True), (4, False)])
+    @unittest.skip("Gemma2 has HybridCache and doesn't support the old tuple cache format at all")
+    def test_new_cache_format(self, num_beams, do_sample):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support continuing generation from past key values")
+    def test_generate_continue_from_past_key_values(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support low_memory generation")
+    def test_beam_search_low_memory(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
+    def test_contrastive_generate(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
+    def test_contrastive_generate_dict_outputs_use_cache(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support contrastive generation")
+    def test_contrastive_generate_low_memory(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't be supported.")
+    def test_generate_with_static_cache(self):
+        pass
+
+    @unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't be supported.")
+    def test_generate_from_inputs_embeds_with_static_cache(self):
+        pass
+
+    # overwrite because HybridCache has a fixed length for key/values
+    def _check_attentions_for_generate(
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
+        )
+        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
+
+        for idx, iter_attentions in enumerate(attentions):
+            tgt_len = min_length + idx if not use_cache else 1
+            src_len = min_length + idx if not use_cache else max_length
+
+            expected_shape = (
+                batch_size * num_beam_groups,
+                config.num_attention_heads,
+                tgt_len,
+                src_len,
+            )
+            # check attn size
+            self.assertListEqual(
+                [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
+            )
+
+    # overwrite because HybridCache has a fixed length for key/values
+    def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
+        self.assertIsInstance(past_key_values, HybridCache)
+
+        # check shape key, value (batch, head, max_seq_length, head_features)
+        head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+        num_key_value_heads = (
+            config.num_attention_heads
+            if getattr(config, "num_key_value_heads", None) is None
+            else config.num_key_value_heads
+        )
+        num_hidden_layers = config.num_hidden_layers
+
+        # we should get `max_length` in the cache shape, not `max_length - embeds_length`
+        # `+1` because the test in the Mixin subtracts 1, which is needed for the tuple cache
+        static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim)
+        static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean]
+        self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers)
+        self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape)
+
     @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
     def test_sdpa_equivalence(self):
         pass
@@ -203,6 +299,5 @@ def test_model_9b_flash_attn(self):
 
         output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
         output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
-        print(output_text)
 
         self.assertEqual(output_text, EXPECTED_TEXTS)
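For context, a minimal usage sketch of the path this patch fixes: generating from `inputs_embeds` with `cache_implementation="static"`, where the static cache must be sized for the embedded prompt plus the newly generated tokens. The checkpoint id, prompt, and generation settings below are illustrative assumptions and not part of the patch; any decoder-only model that supports StaticCache should behave the same way.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint (assumption): a decoder-only model that supports StaticCache.
model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

prompt = "The theory of special relativity states that"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Pass the prompt as embeddings only -- the case the new `max_cache_length` logic covers.
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)

# Before the fix, the static cache was allocated with `max_cache_len=generation_config.max_length`,
# which does not account for the embedded prompt length, so slicing the cache could go out of bounds.
# With the fix, `inputs_tensor.shape[1]` is added, so the cache covers prompt + generated tokens.
outputs = model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=inputs.attention_mask,
    max_new_tokens=32,
    cache_implementation="static",
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```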