diff --git a/src/transformers/models/recurrentgemma/configuration_recurrentgemma.py b/src/transformers/models/recurrentgemma/configuration_recurrentgemma.py index f5840de5c34f66..b1feb74808f00e 100644 --- a/src/transformers/models/recurrentgemma/configuration_recurrentgemma.py +++ b/src/transformers/models/recurrentgemma/configuration_recurrentgemma.py @@ -51,7 +51,7 @@ class RecurrentGemmaConfig(PretrainedConfig): Dimension of the hidden representations. intermediate_size (`int`, *optional*, defaults to 24576): Dimension of the MLP representations. - num_heads (`int`, *optional*, defaults to 10): + num_attention_heads (`int`, *optional*, defaults to 10): The number of heads for the attention block and the number of heads/blocks for the block-diagonal layers used in the RG-LRU gates. This number must divide `hidden_size` and `lru_width`. @@ -104,7 +104,7 @@ def __init__( vocab_size=256000, hidden_size=2560, intermediate_size=3 * 2560, - num_heads=10, + num_attention_heads=10, lru_width=None, embeddings_scale_by_sqrt_dim=True, attention_window_size=2048, @@ -124,8 +124,7 @@ def __init__( self.vocab_size = vocab_size self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.num_heads = num_heads - self.head_dim = self.hidden_size // self.num_heads + self.num_attention_heads = num_attention_heads self.lru_width = lru_width if lru_width is not None else hidden_size self.embeddings_scale_by_sqrt_dim = embeddings_scale_by_sqrt_dim self.attention_window_size = attention_window_size @@ -134,7 +133,9 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_theta = rope_theta - self._block_types = block_types + self._block_types = list(block_types) + + self.head_dim = self.hidden_size // self.num_attention_heads super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/recurrentgemma/modeling_recurrentgemma.py b/src/transformers/models/recurrentgemma/modeling_recurrentgemma.py index 277a911a1635d2..17b415ffe0103e 100644 --- a/src/transformers/models/recurrentgemma/modeling_recurrentgemma.py +++ b/src/transformers/models/recurrentgemma/modeling_recurrentgemma.py @@ -103,9 +103,9 @@ def _apply_rope( Returns: Rotated keys or queries in first half (along with original in second half). """ - batch_size, sequence_length = positions.shape + batch_size, sequence_length, *_ = inputs.shape x_rope, x = torch.chunk(inputs, 2, dim=-1) - positions = positions.reshape(batch_size, sequence_length, 1, 1) + positions = positions.reshape(1, sequence_length, 1, 1) freq = torch.arange(x_rope.shape[-1] // 2, device=x.device) freq_exponents = 2 * freq / x_rope.shape[-1] @@ -163,7 +163,7 @@ def _compute_causal_mask( return mask -def _compute_forward_pass_mask( +def compute_forward_pass_mask( segment_pos: torch.Tensor, window_size: int, ) -> torch.Tensor: @@ -179,12 +179,7 @@ def _compute_forward_pass_mask( """ segment_ids = torch.cumsum(segment_pos == 0, dim=-1) positions = torch.arange(segment_pos.shape[-1], device=segment_pos.device) - positions = torch.repeat_interleave( - positions[None], segment_pos.shape[0], dim=0 - ) - return _compute_causal_mask( - positions, positions, window_size, segment_ids, segment_ids - ) + return _compute_causal_mask(positions, positions, window_size, segment_ids, segment_ids) def _compute_cache_mask( @@ -202,14 +197,11 @@ def _compute_cache_mask( inference step with a KV-cache of the local attention. 
""" device = num_tokens.device - q_positions = num_tokens[:, None] + q_positions = num_tokens[None] k_positions = torch.arange(window_size + 1, device=device) - window_size - k_positions = torch.repeat_interleave( - k_positions[None], q_positions.shape[0], dim=0 - ) - k_positions = k_positions + num_tokens[:, None] - return _compute_causal_mask(q_positions, k_positions, window_size, None, - None) + k_positions = k_positions + num_tokens + # Add batch dimension + return _compute_causal_mask(q_positions, k_positions, window_size, None, None) def _update_attention_cache( @@ -271,7 +263,7 @@ def _attention_cache_from_prompt( return dict( keys=torch.concatenate([k_padding, keys[:, -w:]], dim=1), values=torch.concatenate([v_padding, values[:, -w:]], dim=1), - num_tokens=segment_pos[:, -1] + 1, + num_tokens=segment_pos[-1] + 1, ) @@ -385,7 +377,7 @@ def forward( the input sequence. """ b, t, _ = x.shape - assert segment_pos.shape == (b, t), segment_pos.shape + assert segment_pos.shape == (t,), f"{segment_pos.shape} != {(t,)}" # Generate keys, values and queries. queries = self.proj_q(x) @@ -410,16 +402,22 @@ def forward( keys = torch.concatenate([cache["keys"], keys], dim=-3) values = torch.concatenate([cache["values"], values], dim=-3) + if attention_mask is None: + attention_mask = _compute_cache_mask(segment_pos, self.window_size) else: new_cache = _attention_cache_from_prompt( keys, values, segment_pos, self.window_size ) + if attention_mask is None: + attention_mask = compute_forward_pass_mask(segment_pos, self.window_size) + # Compute attention. logits = einops.einsum(queries, keys, "b t n h, b s n h -> b n t s") logits = logits * (self.head_dim ** -0.5) - # Expand for heads axis. - attn_mask = torch.unsqueeze(attention_mask, dim=1) + + # Expand for batch and heads axis. + attn_mask = attention_mask[None, None].type(torch.bool) masked_logits = torch.where(attn_mask, logits, _MIN_LOGITS_VALUE) masked_logits = masked_logits.type(torch.float32) @@ -447,8 +445,7 @@ def init_cache( return dict( keys=torch.zeros(shape, device=device, dtype=dtype), values=torch.zeros(shape, device=device, dtype=dtype), - num_tokens=torch.zeros([batch_size], dtype=torch.int32, - device=device), + num_tokens=torch.zeros([], dtype=torch.int32, device=device), ) @@ -810,6 +807,7 @@ def forward( than the returned updated cache is empty initialized and filled in from the input sequence. """ + assert segment_pos.shape == (x.shape[1],), f"{segment_pos.shape} != {(x.shape[1],)}" raw_x = x inputs_normalized = self.temporal_pre_norm(raw_x) @@ -1057,7 +1055,7 @@ def rnn_scan( assert h0 is None or h0.dtype == acc_dtype # Multiply `a` by the reset. - a = a * ~reset[..., None] + a = a * ~reset if x.shape[1] == 1: # Using scan in sampling mode. @@ -1201,8 +1199,8 @@ def __call__( """ bs, l, _ = x.shape - assert segment_pos.shape == (bs, l), segment_pos.shape - reset = segment_pos == 0 + assert segment_pos.shape == (l,), f"{segment_pos.shape} != {(l,)}" + reset = segment_pos[None, :, None] == 0 # Gates for x and a. gate_x = torch.sigmoid(self.input_gate(x)) @@ -1219,7 +1217,7 @@ def __call__( # Apply gamma normalization to the input. We need to clip the derivatives of # `sqrt` in order to prevent NaNs during training in bfloat16. 
multiplier = SqrtBoundDerivative.apply(1 - a_square) - multiplier = reset[..., None] + ~reset[..., None] * multiplier + multiplier = reset + ~reset * multiplier normalized_x = gated_x * multiplier.type(x.dtype) y, last_h = rnn_scan( @@ -1308,6 +1306,8 @@ def forward( Returns: The output of the convolution and the updated state. """ + assert segment_pos.shape == (x.shape[1],) + if state is not None: # 1. Decoding mode: # - We have access to the previous `self.temporal_width - 1` inputs. @@ -1350,8 +1350,7 @@ def forward( end_idx=end_idx, max_look_ahead=temporal_shift, ) - x_window *= window_mask[:, :, None].type(x.dtype).to( - device=x.device) + x_window *= window_mask[None, :, None].type(x.dtype).to(device=x.device) x_window = self._pad_window(x_window, output_len) @@ -1426,8 +1425,8 @@ def _compute_document_mask( """Creates a mask to prevent mixing of information between documents. Args: - segment_pos: Position of each token in the sequence. In particular, a - zero indicates the start of a new document. + segment_pos: Position of each token in the sequence. In particular, + a zero indicates the start of a new document. start_idx: The starting index of the convolution window. end_idx: The ending index of the convolution window. max_look_ahead: How much to look ahead at most to detect a document @@ -1437,17 +1436,12 @@ def _compute_document_mask( An integer mask where `1` indicates a position that should be included in the convolution, and `0` a position that should be excluded. """ - batch_size = segment_pos.shape[0] not_a_document_boundary = (segment_pos != 0).type(torch.int32) - mask = torch.ones( - (batch_size, end_idx - start_idx), - device=segment_pos.device, - ) + mask = torch.ones((end_idx - start_idx), device=segment_pos.device) for shift in range(1, max_look_ahead + 1): # At each position, look ahead by `shift` tokens to see if a # document boundary is present there. - mask *= not_a_document_boundary[:, - start_idx + shift: end_idx + shift] + mask *= not_a_document_boundary[start_idx + shift: end_idx + shift] return mask def _pad_window( @@ -1568,7 +1562,7 @@ def __init__(self, config: RecurrentGemmaConfig, batch_size, dtype=torch.float16 self.states.append(ResidualBlock.init_cache( batch_size=batch_size, width=config.hidden_size, - num_heads=config.num_heads, + num_heads=config.num_attention_heads, attention_window_size=config.attention_window_size, temporal_block_type=block_type, lru_width=config.lru_width, @@ -1599,7 +1593,7 @@ class GriffinOutput(ModelOutput): """ last_hidden_state: Optional[torch.FloatTensor] = None - state: Optional[GriffinCache] = None + past_key_values: Optional[GriffinCache] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None @@ -1613,7 +1607,7 @@ class GriffinCausalLMOutput(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - state (`GriffinCache`): + past_key_values (`GriffinCache`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. 
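Note on the `state` → `past_key_values` rename in the output classes above: the docstring's claim that the returned `GriffinCache` "can be used in a forward method with the next `input_ids`" is the usual incremental-decoding pattern. Below is a minimal illustrative sketch of that usage, not part of the patch; it assumes the public `RecurrentGemmaForCausalLM` API and the `google/recurrentgemma-2b` checkpoint referenced by the integration tests, and that a single-step forward call with the cached state behaves like other HF causal LMs.

```python
# Illustrative sketch (not part of this patch): reuse the returned GriffinCache
# via `past_key_values` so the next forward call only needs the new token.
import torch
from transformers import AutoTokenizer, RecurrentGemmaForCausalLM

model_id = "google/recurrentgemma-2b"  # checkpoint name taken from the integration tests
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RecurrentGemmaForCausalLM.from_pretrained(model_id)

input_ids = tokenizer("Hello I am doing", return_tensors="pt").input_ids
with torch.no_grad():
    # Prefill: run the whole prompt once and keep the recurrent / local-attention state.
    out = model(input_ids=input_ids, use_cache=True)
    next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)

    # Decode step: feed only the new token together with the cached state.
    out = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)
```

In practice `generate()` drives this loop through `prepare_inputs_for_generation`, which is updated to the same `past_key_values` naming further down in this diff.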
@@ -1627,8 +1621,9 @@ class GriffinCausalLMOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None - state: Optional[GriffinCache] = None + past_key_values: Optional[GriffinCache] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: tuple | None = None # END: adapted from mamba. @@ -1803,7 +1798,7 @@ def __init__(self, config: RecurrentGemmaConfig): ResidualBlock( width=self.config.hidden_size, mlp_expanded_width=self.config.intermediate_size, - num_heads=self.config.num_heads, + num_heads=self.config.num_attention_heads, attention_window_size=self.config.attention_window_size, temporal_block_type=block_type, lru_width=self.config.lru_width, @@ -1834,14 +1829,13 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, - cache: Optional[GriffinCache] = None, + past_key_values: Optional[GriffinCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple, GriffinOutput]: - print(kwargs) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) @@ -1854,7 +1848,10 @@ def forward( ) if cache_position is None: - raise ValueError("You must provide a `cache_position`.") + if input_ids is not None: + cache_position = torch.arange(input_ids.shape[1], device=input_ids.device) + else: + cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( @@ -1877,6 +1874,7 @@ def forward( all_hidden_states = () if output_hidden_states else None new_cache = None + cache = past_key_values for i, residual_block in enumerate(self.blocks): if output_hidden_states: @@ -1885,14 +1883,14 @@ def forward( layer_outputs = self._gradient_checkpointing_func( residual_block.__call__, hidden_states, - cache_position[None], + cache_position, attention_mask, None if cache is None else cache.states[i], ) else: layer_outputs = residual_block( hidden_states, - cache_position[None], + cache_position, attention_mask, None if cache is None else cache.states[i], ) @@ -1920,7 +1918,7 @@ def forward( return GriffinOutput( last_hidden_state=hidden_states, - state=new_cache, + past_key_values=new_cache, hidden_states=all_hidden_states, ) @@ -2039,7 +2037,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - cache: Optional[GriffinCache] = None, + past_key_values: Optional[GriffinCache] = None, labels: Optional[torch.LongTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -2081,7 +2079,7 @@ def forward( input_ids=input_ids, cache_position=cache_position, attention_mask=attention_mask, - cache=cache, + past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_hidden_states=output_hidden_states, @@ -2117,41 +2115,41 @@ def forward( return GriffinCausalLMOutput( loss=loss, logits=logits, - state=outputs.state, + past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, + attentions=(), ) def prepare_inputs_for_generation( self, input_ids, - state: Optional[GriffinCache] = None, + past_key_values: Optional[GriffinCache] = None, cache_position: 
Optional[torch.LongTensor] = None, inputs_embeds=None, attention_mask=None, **kwargs, ): - if state is not None: + if past_key_values is not None or cache_position.shape[0] == 1: input_ids = input_ids[:, -1].unsqueeze(-1) + if cache_position is not None: + cache_position = cache_position[-1:] - if inputs_embeds is not None and state is None: + if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} - model_inputs["state"] = state + model_inputs["past_key_values"] = past_key_values model_inputs["cache_position"] = cache_position - batch_size = input_ids.shape[0] - num_tokens = torch.zeros([batch_size], dtype=torch.int32,) - - if state is not None: + if past_key_values is not None: attn_mask = _compute_cache_mask( - num_tokens, + torch.zeros([], dtype=torch.int32, device=input_ids.device), self.config.attention_window_size, ) else: - attn_mask = _compute_forward_pass_mask( - cache_position[None], + attn_mask = compute_forward_pass_mask( + cache_position, self.config.attention_window_size, ) diff --git a/tests/models/recurrentgemma/test_modeling_recurrentgemma.py b/tests/models/recurrentgemma/test_modeling_recurrentgemma.py index b5bcc00f7e063e..5d7440b7fe06a9 100644 --- a/tests/models/recurrentgemma/test_modeling_recurrentgemma.py +++ b/tests/models/recurrentgemma/test_modeling_recurrentgemma.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Testing suite for the PyTorch RecurrentGemma model. """ +""" Testing suite for the PyTorch RecurrentGemma model. """ import tempfile import unittest @@ -41,6 +41,7 @@ import torch from transformers import RecurrentGemmaForCausalLM, RecurrentGemmaModel + from transformers.models.recurrentgemma.modeling_recurrentgemma import compute_forward_pass_mask class RecurrentGemmaModelTester: @@ -50,14 +51,14 @@ def __init__( batch_size=13, seq_length=48, is_training=True, - use_input_mask=False, + use_input_mask=True, use_token_type_ids=False, use_labels=True, num_hidden_layers=3, vocab_size=99, hidden_size=32, intermediate_size=3 * 32, - num_heads=2, + num_attention_heads=2, lru_width=2 * 32, embeddings_scale_by_sqrt_dim=True, attention_window_size=16, @@ -85,7 +86,7 @@ def __init__( self.vocab_size = vocab_size self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.num_heads = num_heads + self.num_attention_heads = num_attention_heads self.lru_width = lru_width if lru_width is not None else hidden_size self.embeddings_scale_by_sqrt_dim = embeddings_scale_by_sqrt_dim self.attention_window_size = attention_window_size @@ -108,7 +109,10 @@ def prepare_config_and_inputs(self): input_mask = None if self.use_input_mask: - input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device) + input_mask = compute_forward_pass_mask( + torch.arange(self.seq_length).to(torch_device), + self.attention_window_size, + ) token_type_ids = None if self.use_token_type_ids: @@ -133,7 +137,7 @@ def get_config(self): vocab_size=self.vocab_size, hidden_size=self.hidden_size, intermediate_size=self.intermediate_size, - num_heads=self.num_heads, + num_attention_heads=self.num_attention_heads, lru_width=self.lru_width, embeddings_scale_by_sqrt_dim=self.embeddings_scale_by_sqrt_dim, attention_window_size=self.attention_window_size, @@ -145,7 +149,7 @@
pad_token_id=self.pad_token_id, ) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Gemma + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->RecurrentGemma def create_and_check_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): @@ -156,7 +160,7 @@ def create_and_check_model( result = model(input_ids) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->Gemma + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model_as_decoder with Llama->RecurrentGemma def create_and_check_model_as_decoder( self, config, @@ -187,7 +191,7 @@ def create_and_check_model_as_decoder( result = model(input_ids, attention_mask=input_mask) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->Gemma + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_for_causal_lm with Llama->RecurrentGemma def create_and_check_for_causal_lm( self, config, @@ -206,7 +210,7 @@ def create_and_check_for_causal_lm( result = model(input_ids, attention_mask=input_mask, labels=token_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->Gemma + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_decoder_model_past_large_inputs with Llama->RecurrentGemma def create_and_check_decoder_model_past_large_inputs( self, config, @@ -269,7 +273,7 @@ def create_and_check_decoder_model_past_large_inputs( # test that outputs are equal for slice self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) - # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->Gemma + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.prepare_config_and_inputs_for_common with Llama->RecurrentGemma def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -281,20 +285,24 @@ def prepare_config_and_inputs_for_common(self): token_labels, choice_labels, ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + inputs_dict = { + "input_ids": input_ids, + "attention_mask": input_mask, + "labels": token_labels, + } return config, inputs_dict @require_torch class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (RecurrentGemmaModel, RecurrentGemmaForCausalLM) if is_torch_available() else () + all_model_classes = (RecurrentGemmaForCausalLM,) if is_torch_available() else () all_generative_model_classes = (RecurrentGemmaForCausalLM,) if is_torch_available() else () pipeline_model_mapping = ( { - "feature-extraction": RecurrentGemmaModel, - # "text-classification": GemmaForSequenceClassification, + # "feature-extraction": RecurrentGemmaModel, + # "text-classification": 
RecurrentGemmaForSequenceClassification, "text-generation": RecurrentGemmaForCausalLM, - # "zero-shot": GemmaForSequenceClassification, + # "zero-shot": RecurrentGemmaForSequenceClassification, } if is_torch_available() else {} @@ -313,6 +321,8 @@ def is_pipeline_test_to_skip( return True def setUp(self): + # We don't output attentions + self.has_attentions = False self.model_tester = RecurrentGemmaModelTester(self) self.config_tester = ConfigTester(self, config_class=RecurrentGemmaConfig, hidden_size=37) @@ -331,17 +341,16 @@ def test_model_various_embeddings(self): # def test_RecurrentGemma_sequence_classification_model(self): # config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - # print(config) # config.num_labels = 3 # input_ids = input_dict["input_ids"] # attention_mask = input_ids.ne(1).to(torch_device) # sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - # model = GemmaForSequenceClassification(config) + # model = RecurrentGemmaForSequenceClassification(config) # model.to(torch_device) # model.eval() # result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) # self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - # + # def test_RecurrentGemma_sequence_classification_model_for_single_label(self): # config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() # config.num_labels = 3 @@ -349,12 +358,12 @@ def test_model_various_embeddings(self): # input_ids = input_dict["input_ids"] # attention_mask = input_ids.ne(1).to(torch_device) # sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) - # model = GemmaForSequenceClassification(config) + # model = RecurrentGemmaForSequenceClassification(config) # model.to(torch_device) # model.eval() # result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) # self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) - # + # def test_RecurrentGemma_sequence_classification_model_for_multi_label(self): # config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() # config.num_labels = 3 @@ -364,7 +373,7 @@ def test_model_various_embeddings(self): # sequence_labels = ids_tensor( # [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size # ).to(torch.float) - # model = GemmaForSequenceClassification(config) + # model = RecurrentGemmaForSequenceClassification(config) # model.to(torch_device) # model.eval() # result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) @@ -383,6 +392,157 @@ def test_save_load_fast_init_from_base(self): def test_past_key_values_format(self): pass + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_generate_padding_right(self): + import torch + + for model_class in self.all_generative_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + + model.generate(dummy_input, 
attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + with self.assertRaises(ValueError): + _ = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_generate_use_cache(self): + import torch + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: RecurrentGemma apparently does not support right padding + use_cache with FA2. + dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_padding_right(self): + self.skipTest("RecurrentGemma flash attention does not support right padding") + + @require_torch_sdpa + @require_torch_gpu + @slow + def test_sdpa_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_sdpa: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, attn_implementation="sdpa" + ) + model_sdpa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager") + model.to(torch_device) + + dummy_input = inputs_dict[model_class.main_input_name] + dummy_input = dummy_input.to(torch_device) + outputs = model(dummy_input, output_hidden_states=True) + outputs_sdpa = model_sdpa(dummy_input, output_hidden_states=True) + + logits = outputs.hidden_states[-1] + logits_sdpa = outputs_sdpa.hidden_states[-1] + + # gemma sdpa needs a high tolerance + assert torch.allclose(logits_sdpa, logits, atol=3e-3) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + return + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + 
tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager") + model.to(torch_device) + + dummy_input = inputs_dict[model_class.main_input_name] + dummy_input = dummy_input.to(torch_device) + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + # gemma flash attention 2 needs a high tolerance + assert torch.allclose(logits_fa, logits, atol=3e-3) + @require_torch_gpu @slow @@ -391,7 +551,7 @@ class RecurrentGemmaIntegrationTest(unittest.TestCase): input_text = ["Hello I am doing", "Hi today"] def test_model_2b_fp32(self): - model_id = "google/recurrentgemma-2b" + model_id = "google/gemma-2b" EXPECTED_TEXTS = [ "Hello I am doing a project on the 1990s and I need to know what the most popular music", "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat", @@ -407,45 +567,45 @@ def test_model_2b_fp32(self): self.assertEqual(output_text, EXPECTED_TEXTS) - # def test_model_2b_fp16(self): - # model_id = "google/gemma-2b" - # EXPECTED_TEXTS = [ - # "Hello I am doing a project on the 1990s and I need to know what the most popular music", - # "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat", - # ] - # - # model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to( - # torch_device - # ) - # - # tokenizer = AutoTokenizer.from_pretrained(model_id) - # inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) - # - # output = model.generate(**inputs, max_new_tokens=20, do_sample=False) - # output_text = tokenizer.batch_decode(output, skip_special_tokens=True) - # - # self.assertEqual(output_text, EXPECTED_TEXTS) - # - # def test_model_2b_fp16_static_cache(self): - # model_id = "google/gemma-2b" - # EXPECTED_TEXTS = [ - # "Hello I am doing a project on the 1990s and I need to know what the most popular music", - # "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat", - # ] - # - # model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to( - # torch_device - # ) - # - # model.generation_config.cache_implementation = "static" - # - # tokenizer = AutoTokenizer.from_pretrained(model_id) - # inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) - # - # output = model.generate(**inputs, max_new_tokens=20, do_sample=False) - # output_text = tokenizer.batch_decode(output, skip_special_tokens=True) - # - # self.assertEqual(output_text, EXPECTED_TEXTS) + def test_model_2b_fp16(self): + model_id = "google/recurrentgemma-2b" + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1990s and I need to know what the most popular music", + "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat", + ] + + model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to( + torch_device + ) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + 
self.assertEqual(output_text, EXPECTED_TEXTS) + + def test_model_2b_fp16_static_cache(self): + model_id = "google/recurrentgemma-2b" + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1990s and I need to know what the most popular music", + "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat", + ] + + model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to( + torch_device + ) + + model.generation_config.cache_implementation = "static" + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) def test_model_2b_bf16(self): model_id = "google/recurrentgemma-2b" @@ -485,3 +645,159 @@ def test_model_2b_eager(self): output_text = tokenizer.batch_decode(output, skip_special_tokens=True) self.assertEqual(output_text, EXPECTED_TEXTS) + + @require_torch_sdpa + def test_model_2b_sdpa(self): + model_id = "google/recurrentgemma-2b" + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1990s and I need to know what the most popular music", + "Hi today I am going to share with you a very easy and simple recipe of Khichdi", + ] + + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="sdpa" + ) + model.to(torch_device) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) + + @pytest.mark.flash_attn_test + @require_flash_attn + def test_model_2b_flash_attn(self): + model_id = "google/recurrentgemma-2b" + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1990s and I need to know what the most popular music", + "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat", + ] + + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model.to(torch_device) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) + + @require_bitsandbytes + def test_model_2b_4bit(self): + model_id = "google/recurrentgemma-2b" + EXPECTED_TEXTS = [ + "Hello I am doing a project and I need to make a 3d model of a house. 
I have been using", + "Hi today I'd like to share with you my experience with the new wattpad wattpad wattpad wattpad wattpad wattpad wattpad", + ] + + model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, load_in_4bit=True) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) + + @unittest.skip("The test will not fit our CI runners") + def test_model_7b_fp32(self): + model_id = "google/recurrentgemma-7b" + EXPECTED_TEXTS = [ + "Hello my name is ***** ***** I will be assisting you today. I am sorry to hear about your issue. I will", + "Hi,\n\nI have a problem with my 2005 1.6 16", + ] + + model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).to(torch_device) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) + + def test_model_7b_fp16(self): + model_id = "google/recurrentgemma-7b" + EXPECTED_TEXTS = [ + """Hello I am doing a project on a 1999 4.0L 4x4. I""", + "Hi today I am going to show you how to make a simple and easy to make a DIY 3D", + ] + + model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to( + torch_device + ) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) + + def test_model_7b_bf16(self): + model_id = "google/recurrentgemma-7b" + EXPECTED_TEXTS = [ + """Hello I am doing a project on a 1991 240sx and I am trying to find""", + "Hi today I am going to show you how to make a very simple and easy to make a very simple and", + ] + + model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16).to( + torch_device + ) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) + + def test_model_7b_fp16_static_cache(self): + model_id = "google/recurrentgemma-7b" + EXPECTED_TEXTS = [ + """Hello I am doing a project on a 1999 4.0L 4x4. 
I""", + "Hi today I am going to show you how to make a simple and easy to make a DIY 3D", + ] + + model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to( + torch_device + ) + + model.generation_config.cache_implementation = "static" + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) + + @require_bitsandbytes + def test_model_7b_4bit(self): + model_id = "google/recurrentgemma-7b" + EXPECTED_TEXTS = [ + "Hello I am doing a project for my school and I am trying to make a program that will take a number and then", + """Hi today I am going to talk about the new update for the game called "The new update" and I""", + ] + + model = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, load_in_4bit=True) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS)