From adf81a6d652dffe37e41476b75e9fb290f984bdd Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 10 Sep 2024 12:02:37 +0200 Subject: [PATCH] VLM: fixes after refactor (#32907) * leave only half of the changes * fix tests * [run-slow] llava, llava_next, llava_next_video, vipllava, video_llava * fix tests, first try * [run-slow] llava, llava_next, llava_next_video, vipllava, video_llava * fix, second try * [run-slow] llava, llava_next, llava_next_video, vipllava, video_llava * fix * [run-slow] llava, llava_next, llava_next_video, vipllava, video_llava --- .../models/llava/modeling_llava.py | 8 +- .../models/llava/processing_llava.py | 2 +- .../models/llava_next/modeling_llava_next.py | 13 +- .../llava_next/processing_llava_next.py | 47 ++-- .../llava_next_video/diff_llava_next_video.py | 231 ++++++++--------- .../modeling_llava_next_video.py | 236 ++++++++---------- .../processing_llava_next_video.py | 77 ++++-- .../video_llava/modeling_video_llava.py | 43 ++-- .../video_llava/processing_video_llava.py | 16 +- .../models/vipllava/modeling_vipllava.py | 8 +- tests/models/llava/test_modeling_llava.py | 49 +++- .../llava_next/test_modeling_llava_next.py | 82 +++--- .../test_modeling_llava_next_video.py | 135 ++++++---- .../video_llava/test_modeling_video_llava.py | 129 ++++------ .../models/vipllava/test_modeling_vipllava.py | 9 +- 15 files changed, 581 insertions(+), 504 deletions(-) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index ae53156d9ba2cd..94388af99ec17f 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -476,6 +476,7 @@ def forward( inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( image_features, inputs_embeds, input_ids, attention_mask, labels ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 @@ -506,6 +507,9 @@ def forward( attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[ + -target_length: + ] # TODO: @raushan retain only the new behavior after v4.47 else: @@ -585,9 +589,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if legacy_processing: - model_inputs["pixel_values"] = pixel_values - elif cache_position[0] == 0: + if legacy_processing or cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values"] = pixel_values diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 99244d993b71cb..678724ae95be41 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -136,6 +136,7 @@ def __call__( raise ValueError("Invalid input text. 
Please provide a string, or a list of strings") # try to expand inputs in processing if we have the necessary parts + prompt_strings = text if image_inputs.get("pixel_values") is not None: if self.patch_size is not None and self.vision_feature_select_strategy is not None: # Replace the image token with the expanded image token sequence @@ -150,7 +151,6 @@ def __call__( sample = sample.replace(self.image_token, self.image_token * num_image_tokens) prompt_strings.append(sample) else: - prompt_strings = text logger.warning_once( "Expanding inputs for image tokens in LLaVa should be done in processing. " "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 5fe029f13e7349..18a17c6dcd06b9 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -848,6 +848,7 @@ def forward( position_ids, labels=labels, ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 @@ -877,6 +878,9 @@ def forward( extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[ + -target_length: + ] # TODO: @raushan retain only the new behavior after v4.47 else: @@ -956,12 +960,9 @@ def prepare_inputs_for_generation( **kwargs, ) - if legacy_processing: - model_inputs["pixel_values"] = pixel_values - model_inputs["image_sizes"] = image_sizes - elif cache_position[0] == 0: - # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore - # Otherwise we need pixel values to be passed to model + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + if legacy_processing or cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values model_inputs["image_sizes"] = image_sizes diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py index f84578d1f3466e..2a2df041283ed3 100644 --- a/src/transformers/models/llava_next/processing_llava_next.py +++ b/src/transformers/models/llava_next/processing_llava_next.py @@ -140,30 +140,29 @@ def __call__( elif not isinstance(text, list) and not isinstance(text[0], str): raise ValueError("Invalid input text. Please provide a string, or a list of strings") - if self.patch_size is None or self.vision_feature_select_strategy is None: - prompt_strings = text - logger.warning_once( - "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
- ) - # cannot infer image expansion length if no images are found - elif not image_inputs: - prompt_strings = text - else: - image_sizes = image_inputs["image_sizes"] - height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) - prompt_strings = [] - for image_size, sample in zip(image_sizes, text): - # Replace the image token with the expanded image token sequence - orig_height, orig_width = image_size - num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) - if self.vision_feature_select_strategy == "default": - num_image_tokens -= 1 - - sample = sample.replace(self.image_token, self.image_token * num_image_tokens) - prompt_strings.append(sample) + prompt_strings = text + if image_inputs: + if self.patch_size is None or self.vision_feature_select_strategy is None: + logger.warning_once( + "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. " + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + ) + else: + image_sizes = iter(image_inputs["image_sizes"]) + height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) + prompt_strings = [] + for sample in text: + while self.image_token in sample: + image_size = next(image_sizes) + orig_height, orig_width = image_size + num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + sample = sample.replace(self.image_token, "" * num_image_tokens, 1) + prompt_strings.append(sample) + prompt_strings = [sample.replace("", self.image_token) for sample in prompt_strings] text_inputs = self.tokenizer( prompt_strings, diff --git a/src/transformers/models/llava_next_video/diff_llava_next_video.py b/src/transformers/models/llava_next_video/diff_llava_next_video.py index b4018db586e74e..e765dfb95cc335 100644 --- a/src/transformers/models/llava_next_video/diff_llava_next_video.py +++ b/src/transformers/models/llava_next_video/diff_llava_next_video.py @@ -29,7 +29,6 @@ image_size_to_num_patches, ) -from ...cache_utils import Cache from ...utils import ( logging, replace_return_docstrings, @@ -389,13 +388,17 @@ def forward( # if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - img_token_count = (input_ids == self.config.image_token_index).sum(1).max() - video_token_count = (input_ids == self.config.video_token_index).sum(1).max() - inputs_expanded = ( - img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length + img_token_not_enough = (input_ids == self.config.image_token_index).sum( + 1 + ).max() < self.config.image_seq_length + video_token_not_enough = (input_ids == self.config.video_token_index).sum( + 1 + ).max() < self.config.video_seq_length + inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or ( + video_token_not_enough and pixel_values_videos is not None ) - pixels_present = input_ids.shape[-1] == 1 and pixel_values is not None and pixel_values_videos is not None - legacy_processing = inputs_expanded or pixels_present 
+ pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None) + legacy_processing = inputs_not_expanded or pixels_present image_features = feature_lens = None if pixel_values is not None and pixel_values.size(0) > 0: @@ -414,75 +417,76 @@ def forward( video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + if legacy_processing: + logger.warning_once( + "Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. " + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + ) + if input_ids.shape[1] != 1: + iterator = ( + (image_features, feature_lens, self.config.image_token_index), + (video_features, video_feature_lens, self.config.video_token_index), ) - if input_ids.shape[1] != 1: - iterator = ( - (image_features, feature_lens, self.config.image_token_index), - (video_features, video_feature_lens, self.config.video_token_index), - ) - for features, lens, special_token in zip(iterator): - if features is not None: - ( - inputs_embeds, - attention_mask, - position_ids, - labels, - input_ids, - ) = self._merge_input_ids_with_image_features( - features, - lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, - image_token_index=special_token, - ) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - - # TODO: 
@raushan retain only the new behavior after v4.47 + for features, lens, special_token in iterator: + if features is not None: + ( + inputs_embeds, + attention_mask, + position_ids, + labels, + input_ids, + ) = self._merge_input_ids_with_image_features( + features, + lens, + inputs_embeds, + input_ids, + attention_mask, + position_ids, + labels=labels, + image_token_index=special_token, + ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: - if image_features is not None: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - if video_features is not None: - special_image_mask = ( - (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + # Get the target length + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] + + # TODO: @raushan retain only the new behavior after v4.47 + else: + if image_features is not None: + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + if video_features is not None: + special_image_mask = ( + (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, @@ -493,6 +497,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = outputs[0] @@ -534,58 +539,34 @@ def 
prepare_inputs_for_generation( pixel_values_videos=None, image_sizes=None, attention_mask=None, + cache_position=None, **kwargs, ): - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - else: - cache_length = past_length = past_key_values[0][0].shape[2] - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - elif self.config.image_token_index in input_ids or self.config.video_token_index in input_ids: - input_ids = input_ids[:, input_ids.shape[1] - 1 :] - - # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the - # older attention values, as their corresponding values are not part of the input. - if cache_length < past_length and attention_mask is not None: - attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "pixel_values_videos": pixel_values_videos, - "image_sizes": image_sizes, - } + if input_ids is not None: + img_token_not_enough = (input_ids == self.config.image_token_index).sum( + 1 + ).max() < self.config.image_seq_length + video_token_not_enough = (input_ids == self.config.video_token_index).sum( + 1 + ).max() < self.config.video_seq_length + legacy_processing = (img_token_not_enough and pixel_values is not None) or ( + video_token_not_enough and pixel_values_videos is not None + ) + + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + **kwargs, ) + + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + if legacy_processing or cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + model_inputs["pixel_values_videos"] = pixel_values_videos + model_inputs["image_sizes"] = image_sizes + return model_inputs diff 
--git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 78e8c5a077233b..7d6776738c39fd 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -31,7 +31,6 @@ from ... import PreTrainedModel from ...activations import ACT2FN -from ...cache_utils import Cache from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput from ...utils import ( @@ -767,6 +766,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]: r""" @@ -874,13 +874,17 @@ def forward( # if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - img_token_count = (input_ids == self.config.image_token_index).sum(1).max() - video_token_count = (input_ids == self.config.video_token_index).sum(1).max() - inputs_expanded = ( - img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length + img_token_not_enough = (input_ids == self.config.image_token_index).sum( + 1 + ).max() < self.config.image_seq_length + video_token_not_enough = (input_ids == self.config.video_token_index).sum( + 1 + ).max() < self.config.video_seq_length + inputs_not_expanded = (img_token_not_enough and pixel_values is not None) or ( + video_token_not_enough and pixel_values_videos is not None ) - pixels_present = input_ids.shape[-1] == 1 and pixel_values is not None and pixel_values_videos is not None - legacy_processing = inputs_expanded or pixels_present + pixels_present = input_ids.shape[-1] == 1 and (pixel_values is not None or pixel_values_videos is not None) + legacy_processing = inputs_not_expanded or pixels_present image_features = feature_lens = None if pixel_values is not None and pixel_values.size(0) > 0: @@ -899,75 +903,76 @@ def forward( video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - if legacy_processing: - logger.warning_once( - "Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." + if legacy_processing: + logger.warning_once( + "Expanding inputs for image.video tokens in LLaVa-NeXT-Video should be done in processing. " + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
+ ) + if input_ids.shape[1] != 1: + iterator = ( + (image_features, feature_lens, self.config.image_token_index), + (video_features, video_feature_lens, self.config.video_token_index), ) - if input_ids.shape[1] != 1: - iterator = ( - (image_features, feature_lens, self.config.image_token_index), - (video_features, video_feature_lens, self.config.video_token_index), - ) - for features, lens, special_token in iterator: - if features is not None: - ( - inputs_embeds, - attention_mask, - position_ids, - labels, - input_ids, - ) = self._merge_input_ids_with_image_features( - features, - lens, - inputs_embeds, - input_ids, - attention_mask, - position_ids, - labels=labels, - image_token_index=special_token, - ) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - # Get the target length - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - - # TODO: @raushan retain only the new behavior after v4.47 + for features, lens, special_token in iterator: + if features is not None: + ( + inputs_embeds, + attention_mask, + position_ids, + labels, + input_ids, + ) = self._merge_input_ids_with_image_features( + features, + lens, + inputs_embeds, + input_ids, + attention_mask, + position_ids, + labels=labels, + image_token_index=special_token, + ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: - if image_features is not None: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - if video_features is not None: - special_image_mask = ( - (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + # Retrieve the first layer to inspect the logits and mask out the hidden states that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, 
non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + # Get the target length + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] + + # TODO: @raushan retain only the new behavior after v4.47 + else: + if image_features is not None: + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + if video_features is not None: + special_image_mask = ( + (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, @@ -978,6 +983,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, num_logits_to_keep=num_logits_to_keep, ) @@ -1020,64 +1026,38 @@ def prepare_inputs_for_generation( pixel_values_videos=None, image_sizes=None, attention_mask=None, + cache_position=None, num_logits_to_keep=None, **kwargs, ): - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - else: - cache_length = past_length = past_key_values[0][0].shape[2] - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - elif self.config.image_token_index in input_ids or self.config.video_token_index in input_ids: - input_ids = input_ids[:, input_ids.shape[1] - 1 :] - - # If the cache has seen more tokens than it can hold, then the cache has a size limit. 
Let's discard the - # older attention values, as their corresponding values are not part of the input. - if cache_length < past_length and attention_mask is not None: - attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - if "num_logits_to_keep" != None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "pixel_values": pixel_values, - "pixel_values_videos": pixel_values_videos, - "image_sizes": image_sizes, - } + if input_ids is not None: + img_token_not_enough = (input_ids == self.config.image_token_index).sum( + 1 + ).max() < self.config.image_seq_length + video_token_not_enough = (input_ids == self.config.video_token_index).sum( + 1 + ).max() < self.config.video_seq_length + legacy_processing = (img_token_not_enough and pixel_values is not None) or ( + video_token_not_enough and pixel_values_videos is not None + ) + + model_inputs = self.language_model.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + **kwargs, ) + + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + if legacy_processing or cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + model_inputs["pixel_values_videos"] = pixel_values_videos + model_inputs["image_sizes"] = image_sizes + return model_inputs def _get_image_features(self, pixel_values, image_sizes): diff --git a/src/transformers/models/llava_next_video/processing_llava_next_video.py b/src/transformers/models/llava_next_video/processing_llava_next_video.py index efbb193ba62a9f..e0e4534e42b565 100644 --- a/src/transformers/models/llava_next_video/processing_llava_next_video.py +++ b/src/transformers/models/llava_next_video/processing_llava_next_video.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional, Union from ...feature_extraction_utils import BatchFeature +from ...image_processing_utils import select_best_resolution from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array from ...processing_utils import ProcessorMixin from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy @@ -160,35 +161,29 @@ def __call__( elif not isinstance(text, list) and not isinstance(text[0], str): raise ValueError("Invalid input text. 
Please provide a string, or a list of strings") - print(self.patch_size, self.vision_feature_select_strategy, image_inputs, videos_inputs.keys()) - if self.patch_size is None or self.vision_feature_select_strategy is None: - prompt_strings = text logger.warning_once( "Expanding inputs for image/video tokens in LLaVa-NeXT-Video should be done in processing. " "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." ) - # cannot infer image expansion length if no images/videos are found - elif not image_inputs and not videos_inputs: - prompt_strings = text else: # images expand taking into account num_of_patches in each image if image_inputs: - image_sizes = image_inputs["image_sizes"] + image_sizes = iter(image_inputs["image_sizes"]) height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0])) prompt_strings = [] - for image_size, sample in zip(image_sizes, text): - # Replace the image token with the expanded image token sequence - orig_height, orig_width = image_size - num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) - if self.vision_feature_select_strategy == "default": - num_image_tokens -= 1 - - sample = sample.replace(self.image_token, self.image_token * num_image_tokens) + for sample in text: + while self.image_token in sample: + image_size = next(image_sizes) + orig_height, orig_width = image_size + num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width) + if self.vision_feature_select_strategy == "default": + num_image_tokens -= 1 + sample = sample.replace(self.image_token, "" * num_image_tokens, 1) prompt_strings.append(sample) - text = prompt_strings + text = [sample.replace("", self.image_token) for sample in prompt_strings] # videos are easier, simply get frames and multiply if videos_inputs: @@ -197,23 +192,65 @@ def __call__( num_frames = one_video.shape[0] # frame dim is always after batch dim num_image_tokens = (height // self.patch_size) * (width // self.patch_size) num_video_tokens = num_image_tokens // 4 * num_frames # divide by 4 needed for avg pooling layer - prompt_strings = [] for sample in text: sample = sample.replace(self.video_token, self.video_token * num_video_tokens) prompt_strings.append(sample) + text = prompt_strings text_inputs = self.tokenizer( - prompt_strings, + text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length, ) - print(text_inputs.keys()) - return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + # Copied from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor._get_number_of_features + def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int: + image_grid_pinpoints = self.image_processor.image_grid_pinpoints + + height_best_resolution, width_best_resolution = select_best_resolution( + [orig_height, orig_width], image_grid_pinpoints + ) + scale_height, scale_width = height_best_resolution // height, width_best_resolution // width + + patches_height = height // self.patch_size + patches_width = width // self.patch_size + unpadded_features, newline_features = self._get_unpadded_features( + orig_height, orig_width, patches_height, 
patches_width, scale_height, scale_width + ) + # The base patch covers the entire image (+1 for the CLS) + base_features = patches_height * patches_width + 1 + num_image_tokens = unpadded_features + newline_features + base_features + return num_image_tokens + + # Copied from transformers.models.llava_next.processing_llava_next.LlavaNextProcessor._get_unpadded_features + def _get_unpadded_features(self, height, width, patches_height, patches_width, scale_height, scale_width): + """ + Get number of features for a given image with height/width. LLaVA-NeXT is different from LLaVA + because it divided each image into patches depending on its resolution. Therefore we need to calculate how many + patches an image is divided into and get the number of features from that. + """ + current_height = patches_height * scale_height + current_width = patches_width * scale_width + + original_aspect_ratio = width / height + current_aspect_ratio = current_width / current_height + if original_aspect_ratio > current_aspect_ratio: + new_height = (height * current_width) // width + padding = (current_height - new_height) // 2 + current_height -= padding * 2 + else: + new_width = (width * current_height) // height + padding = (current_width - new_width) // 2 + current_width -= padding * 2 + + unpadded_features = current_height * current_width + newline_features = current_height + return (unpadded_features, newline_features) + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): """ diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index b9263ad15cbf93..08e02d9a702acb 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -529,15 +529,19 @@ def forward( # if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing # not very reliable, but we don't expect one to actually pass 500+ images for one prompt - img_token_count = (input_ids == self.config.image_token_index).sum(1).max() - video_token_count = (input_ids == self.config.video_token_index).sum(1).max() - inputs_expanded = ( - img_token_count < self.config.image_seq_length and video_token_count < self.config.video_seq_length + img_token_not_enough = (input_ids == self.config.image_token_index).sum( + 1 + ).max() < self.config.image_seq_length + video_token_not_enough = (input_ids == self.config.video_token_index).sum( + 1 + ).max() < self.config.video_seq_length + inputs_not_expanded = (img_token_not_enough and pixel_values_images is not None) or ( + video_token_not_enough and pixel_values_videos is not None ) - pixels_present = ( - input_ids.shape[-1] == 1 and pixel_values_images is not None and pixel_values_videos is not None + pixels_present = input_ids.shape[-1] == 1 and ( + pixel_values_images is not None or pixel_values_videos is not None ) - legacy_processing = inputs_expanded or pixels_present + legacy_processing = inputs_not_expanded or pixels_present if pixel_values_images is not None or pixel_values_videos is not None: image_outputs, video_outputs, num_frames = self._get_vision_features( @@ -577,6 +581,7 @@ def forward( labels, num_frames=frames, ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 @@ 
-606,6 +611,9 @@ def forward( attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[ + -target_length: + ] # TODO: @raushan retain only the new behavior after v4.47 else: @@ -678,11 +686,16 @@ def prepare_inputs_for_generation( num_logits_to_keep=None, **kwargs, ): - # Trigger the new behavior if we have more than image embeddings seq length tokens for images - legacy_processing = input_ids is not None and ( - (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length - and (input_ids == self.config.video_token_index).sum(1).max() < self.config.video_seq_length - ) + if input_ids is not None: + img_token_not_enough = (input_ids == self.config.image_token_index).sum( + 1 + ).max() < self.config.image_seq_length + video_token_not_enough = (input_ids == self.config.video_token_index).sum( + 1 + ).max() < self.config.video_seq_length + legacy_processing = (img_token_not_enough and pixel_values_images is not None) or ( + video_token_not_enough and pixel_values_videos is not None + ) model_inputs = self.language_model.prepare_inputs_for_generation( input_ids, @@ -694,11 +707,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if legacy_processing: - model_inputs["pixel_values_images"] = pixel_values_images - model_inputs["pixel_values_videos"] = pixel_values_videos - - elif cache_position[0] == 0: + if legacy_processing or cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values_images"] = pixel_values_images diff --git a/src/transformers/models/video_llava/processing_video_llava.py b/src/transformers/models/video_llava/processing_video_llava.py index a06913d7acf760..bd6f91270965bb 100644 --- a/src/transformers/models/video_llava/processing_video_llava.py +++ b/src/transformers/models/video_llava/processing_video_llava.py @@ -145,24 +145,28 @@ def __call__( elif not isinstance(text, list) and not isinstance(text[0], str): raise ValueError("Invalid input text. Please provide a string, or a list of strings") - if encoded_images is not None and self.patch_size is None or self.vision_feature_select_strategy is None: - prompt_strings = text + prompt_strings = text + if encoded_images is not None and (self.patch_size is None or self.vision_feature_select_strategy is None): logger.warning_once( "Expanding inputs for image tokens in Video-LLaVa should be done in processing. " "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " "Using processors without these attributes in the config is deprecated and will throw an error in v4.44." 
) + # Replace the image/video tokens with the expanded token sequence elif encoded_images is not None: - # Replace the image token with the expanded image token sequence - if "pixel_values" in encoded_images: - height, width = get_image_size(to_numpy_array(encoded_images.get("pixel_values")[0])) + if "pixel_values_images" in encoded_images.keys(): + height, width = get_image_size(to_numpy_array(encoded_images.get("pixel_values_images")[0])) num_frames = 1 - else: + + if "pixel_values_videos" in encoded_images.keys(): one_video = to_numpy_array(encoded_images.get("pixel_values_videos")[0]) height, width = get_image_size(one_video[0]) num_frames = one_video.shape[0] # frame dim is always after batch dim + num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1 + num_video_tokens = num_image_tokens * num_frames + num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1 num_video_tokens = num_image_tokens * num_frames if self.vision_feature_select_strategy == "default": diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index e036d6fb766744..5367b1e088d2aa 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -471,6 +471,7 @@ def forward( inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( image_features, inputs_embeds, input_ids, attention_mask, labels ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 @@ -500,6 +501,9 @@ def forward( attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[ + -target_length: + ] # TODO: @raushan retain only the new behavior after v4.47 else: @@ -579,9 +583,7 @@ def prepare_inputs_for_generation( **kwargs, ) - if legacy_processing: - model_inputs["pixel_values"] = pixel_values - elif cache_position[0] == 0: + if legacy_processing or cache_position[0] == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values"] = pixel_values diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 360bbde29c18b5..2fed802b5a2fb3 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -302,7 +302,7 @@ def test_small_model_integration_test_llama_single(self): inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) output = model.generate(**inputs, max_new_tokens=900, do_sample=False) - EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. 
Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Lastly, be respectful of the environment and other visitors, as the pier is a shared space where people can enjoy the view, relax, or engage in recreational activities." # fmt: skip + EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip self.assertEqual( processor.decode(output[0], skip_special_tokens=True), @@ -353,7 +353,10 @@ def test_small_model_integration_test_batch(self): output = model.generate(**inputs, max_new_tokens=20) - EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring along', 'USER: \nWhat is this?\nASSISTANT: Cats'] # fmt: skip + EXPECTED_DECODED_TEXT = [ + 'USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, there are a few things to be cautious about and items to bring.', + 'USER: \nWhat is this?\nASSISTANT: Cats' + ] # fmt: skip self.assertEqual( self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT, @@ -393,7 +396,7 @@ def test_small_model_integration_test_llama_batched_regression(self): @require_torch @require_vision def test_batched_generation(self): - model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf").to(torch_device) + model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True) processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") @@ -415,9 +418,9 @@ def test_batched_generation(self): model = model.eval() EXPECTED_OUTPUT = [ - "\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog in one and a ll", - "\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding", - "\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama", + "\n \nUSER: What's the the difference of two images?\nASSISTANT: The difference between the two images is that one shows a dog standing on a grassy field, while", + "\nUSER: Describe the image.\nASSISTANT: The image features a brown and white dog sitting on a sidewalk. The dog is holding a small", + "\nUSER: Describe the image.\nASSISTANT: The image features a lone llama standing on a grassy hill. 
The llama is the", ] generate_ids = model.generate(**inputs, max_new_tokens=20) @@ -451,26 +454,23 @@ def test_llava_index_error_bug(self): def test_llava_merge_inputs_error_bug(self): # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore model_id = "llava-hf/llava-1.5-7b-hf" - model = LlavaForConditionalGeneration.from_pretrained( - model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True - ).to(torch_device) + model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) # Simulate some user inputs pixel_values = torch.randn( - (2, 3, 336, 336), + (1, 3, 336, 336), dtype=torch.float, device=torch_device, ) input_ids = torch.tensor( [ [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900], - [1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900], ], dtype=torch.long, device=torch_device, ) attention_mask = torch.tensor( - [[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.long, device=torch_device, ) @@ -515,6 +515,31 @@ def test_generation_no_images(self): # Make sure that `generate` works _ = model.generate(**inputs, max_new_tokens=20) + @slow + @require_bitsandbytes + def test_generation_siglip_backbone(self): + model_id = "llava-hf/llava-interleave-qwen-0.5b-hf" + model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype="float16", device_map=torch_device) + processor = AutoProcessor.from_pretrained(model_id) + + # check processing with expansion of inputs (w/o expansion should work with any backbone) + processor.vision_feature_select_strategy = "default" + processor.patch_size = 14 + + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor( + text="<|im_start|>user\n\nWhat are these?<|im_end|>\n<|im_start|>assistant", + images=raw_image, + return_tensors="pt", + ).to(torch_device, torch.float16) + + # Make sure that `generate` works + output = model.generate(**inputs, max_new_tokens=30) + + EXPECTED_DECODED_TEXT = "user\n\nWhat are these?\nassistant The image shows two cats, one on the left and one on the right. They appear to be resting or sleeping on a pink blanket. 
The cat" + self.assertTrue(processor.batch_decode(output, skip_special_tokens=True)[0] == EXPECTED_DECODED_TEXT) + @slow @require_bitsandbytes def test_expansion_in_processing(self): diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index c665631c40331d..3120db216ea4bb 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -363,11 +363,7 @@ def test_small_model_integration_test(self): output = model(**inputs) expected_slice = torch.tensor( - [ - [-4.7695, -4.5664, -0.2786], - [-10.6250, -10.8906, -2.5254], - [-6.7383, -7.2461, -0.6787], - ], + [[-4.7695, -4.5664, -0.2788], [-10.6172, -10.8828, -2.5273], [-6.7383, -7.2422, -0.6694]], dtype=torch.float32, device=torch_device, ) @@ -471,16 +467,16 @@ def test_small_model_integration_test_batch_different_resolutions(self): output = model(**inputs) expected_slice = torch.tensor( - [[-0.0308, -0.0313, -0.0314], [-0.3064, -0.3013, -0.2986], [-0.1226, -0.1246, -0.1210]], + [[-0.1287, -0.1294, -0.1284], [-0.2744, -0.2698, -0.2671], [-0.1071, -0.1091, -0.1056]], dtype=torch.float32, device=torch_device, ) assert torch.allclose(output.logits[0, -3:, -3:], expected_slice, atol=1e-3) - assert torch.allclose(output.loss, torch.tensor(6.8619, device=torch_device)) + assert torch.allclose(output.loss, torch.tensor(7.0206, device=torch_device), atol=1e-3) # verify generation output = model.generate(**inputs, max_new_tokens=50) - EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows a forested area with a misty or foggy atmosphere. In the foreground, there is a grassy field with a few deer grazing. The deer are partially obscured by the fog, and the trees in the background' # fmt: skip + EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image shows two deer, likely fawns, in a grassy area with trees in the background. The setting appears to be a forest or woodland, and the photo is taken during what seems to be either dawn or dusk, given' # fmt: skip self.assertEqual( self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT, @@ -534,38 +530,66 @@ def test_padding_side_when_merging_inputs(self): # model is in eval mode by default so we should get pad on the left side # we can check the first hidden-states (aka inputs embeds) - # the first element was lo-res image and we expect the first 1414 tokens to be all pads - output_eval = model(**inputs_batched, output_hidden_states=True) - self.assertTrue((output_eval.hidden_states[0][0, :1414, ...] == 0).all().item()) - - # otherwise padding is on the right side, so it's last 1414 tokens - self.processor.padding_side = "right" - inputs_batched = self.processor( - [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True - ).to(torch_device) - - model.train() + # the first element was lo-res image and we expect the first 732 tokens to be all pads with torch.no_grad(): - output_train = model(**inputs_batched, output_hidden_states=True) - self.assertTrue((output_train.hidden_states[0][0, -1414:, ...] == 0).all().item()) + output_eval = model(**inputs_batched, output_hidden_states=True) + self.assertTrue((output_eval.hidden_states[0][0, :732, ...] 
== 0).all().item()) with self.assertLogs("transformers", level="WARNING") as logs: model.padding_side = "left" model.train() - model(**inputs_batched, output_hidden_states=True) + with torch.no_grad(): + model(**inputs_batched, output_hidden_states=True) - self.assertIn( - "Padding side is set to 'left' but the model is in training mode. For training", logs.output[0] - ) + self.assertIn("Padding side is set to 'left' but the model is in training mode. For training", logs) with self.assertLogs("transformers", level="WARNING") as logs: model.padding_side = "right" model.eval() - model(**inputs_batched, output_hidden_states=True) + with torch.no_grad(): + model(**inputs_batched, output_hidden_states=True) - self.assertIn( - "Padding side is set to 'right' but the model is in inference mode. For correct", logs.output[0] - ) + self.assertIn("Padding side is set to 'right' but the model is in inference mode. For correct", logs) + + @slow + @require_bitsandbytes + def test_expansion_in_processing_multiimage(self): + model_id = "llava-hf/llava-v1.6-mistral-7b-hf" + model = LlavaNextForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True) + processor = AutoProcessor.from_pretrained(model_id) + + prompt = "USER: \nDescribe the similarity between the two images:\nASSISTANT:" + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + deer_image = Image.open( + requests.get( + "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e", + stream=True, + ).raw + ) + + # check processing with expansion of inputs + processor.vision_feature_select_strategy = "default" + processor.patch_size = 14 + inputs_expanded = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to( + torch_device, torch.float16 + ) + self.assertTrue(inputs_expanded.input_ids.shape[-1] == 3969) + + # check processing without expansion of inputs (legacy behavior) + processor.vision_feature_select_strategy = None + processor.patch_size = None + inputs = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to( + torch_device, torch.float16 + ) + self.assertTrue(inputs.input_ids.shape[-1] == 23) + + # generate exactly 20 tokens + output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20) + output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20) + + # check that both inputs are handled correctly and generate the same output + self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) @slow @require_bitsandbytes diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index 38b1782b75d6ac..35df4085df0563 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -18,6 +18,7 @@ import unittest import numpy as np +import requests from huggingface_hub import hf_hub_download from transformers import ( @@ -363,29 +364,6 @@ def test_small_model_integration_test(self): ) inputs = self.processor(self.prompt_video, videos=self.video, return_tensors="pt") - expected_input_ids = [ - 1, - 3148, - 1001, - 29901, - 29871, - 32000, - 13, - 11008, - 338, - 445, - 4863, - 2090, - 1460, - 29973, - 319, - 1799, - 9047, - 13566, - 29901, - ] - self.assertListEqual(expected_input_ids, inputs.input_ids[0].tolist()) - # verify single forward pass 
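(Editorial note, not part of the patch: the hard-coded expected_input_ids list removed above is brittle because, after the processor refactor, the tokenized prompt length depends on whether patch_size and vision_feature_select_strategy are set on the processor, i.e. whether placeholder expansion happens in processing or falls back to the legacy path. A less brittle assertion is sketched below; it only counts the special placeholder tokens instead of pinning the full id sequence. This is an illustrative sketch under the assumption that the placeholder id is available, e.g. from the model config's video_token_index; it is not code taken from this PR.)

# Sketch only: a length-agnostic check that holds on both the legacy path and
# the expand-in-processing path (adjust expected_count accordingly).
import torch

def assert_placeholder_count(input_ids: torch.Tensor, placeholder_id: int, expected_count: int) -> None:
    """Assert that the prompt contains exactly `expected_count` placeholder tokens."""
    count = int((input_ids == placeholder_id).sum().item())
    assert count == expected_count, f"expected {expected_count} placeholder tokens, got {count}"

# Usage idea: with legacy processing a single-video prompt keeps one <video>
# placeholder, so one would call
#   assert_placeholder_count(inputs.input_ids[0], video_token_id, 1)
# whereas with expansion enabled the expected count equals the number of video
# feature positions produced by the vision tower.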
inputs = inputs.to(torch_device) with torch.no_grad(): @@ -393,7 +371,7 @@ def test_small_model_integration_test(self): # verify generation output = model.generate(**inputs, do_sample=False, max_new_tokens=40) - EXPECTED_DECODED_TEXT = 'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and exaggerated reactions of the child to the book. The child appears to be reading a book, but instead of a calm and focused reading experience' # fmt: skip + EXPECTED_DECODED_TEXT = 'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a book while another child is attempting to read the same book. The child who is reading the book seems' # fmt: skip self.assertEqual( self.processor.decode(output[0], skip_special_tokens=True), @@ -416,7 +394,10 @@ def test_small_model_integration_test_batch(self): output = model.generate(**inputs, do_sample=False, max_new_tokens=20) - EXPECTED_DECODED_TEXT = ['USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and exaggerated reactions of the child to the', 'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and exaggerated reactions of the child to the'] # fmt: skip + EXPECTED_DECODED_TEXT = [ + 'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a', + 'USER: \nWhy is this video funny? ASSISTANT: The humor in this video comes from the unexpected and somewhat comical situation of a young child reading a' + ] # fmt: skip self.assertEqual( self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT, @@ -447,7 +428,7 @@ def test_small_model_integration_test_batch_different_vision_types(self): # verify generation output = model.generate(**inputs, do_sample=False, max_new_tokens=50) - EXPECTED_DECODED_TEXT = 'USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a benchmark test for a machine learning model. It shows the performance of various models on a task, with the x-axis representing the number of parameters (measured in millions) and the y' # fmt: skip + EXPECTED_DECODED_TEXT = 'USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a graphical representation of a machine learning model\'s performance on a task, likely related to natural language processing or text understanding. It shows a scatter plot with two axes, one labeled "BLIP-2"' # fmt: skip self.assertEqual(self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT) @slow @@ -493,41 +474,25 @@ def test_padding_side_when_merging_inputs(self): # model is in eval mode by default so we should get pad on the left side # we can check the first hidden-states (aka inputs embeds) # the first element was lo-res image and we expect the first 1482 tokens to be all pads - output_eval = model(**inputs_batched, output_hidden_states=True) - self.assertTrue((output_eval.hidden_states[0][0, :1482, ...] 
== 0).all().item()) - - # otherwise padding is on the right side, so it's last 1482 tokens - self.processor.padding_side = "right" - inputs_batched = self.processor( - [self.prompt_video, self.prompt_image], - images=[self.image], - videos=[self.video], - return_tensors="pt", - padding=True, - ).to(torch_device) - - model.train() with torch.no_grad(): - output_train = model(**inputs_batched, output_hidden_states=True) - self.assertTrue((output_train.hidden_states[0][0, -1482:, ...] == 0).all().item()) + output_eval = model(**inputs_batched, output_hidden_states=True) + self.assertTrue((output_eval.hidden_states[0][0, :1482, ...] == 0).all().item()) with self.assertLogs("transformers", level="WARNING") as logs: model.padding_side = "left" model.train() - model(**inputs_batched, output_hidden_states=True) + with torch.no_grad(): + model(**inputs_batched, output_hidden_states=True) - self.assertIn( - "Padding side is set to 'left' but the model is in training mode. For training", logs.output[0] - ) + self.assertIn("Padding side is set to 'left' but the model is in training mode. For training", logs) with self.assertLogs("transformers", level="WARNING") as logs: model.padding_side = "right" model.eval() - model(**inputs_batched, output_hidden_states=True) + with torch.no_grad(): + model(**inputs_batched, output_hidden_states=True) - self.assertIn( - "Padding side is set to 'right' but the model is in inference mode. For correct", logs.output[0] - ) + self.assertIn("Padding side is set to 'right' but the model is in inference mode. For correct", logs) @slow @require_bitsandbytes @@ -556,3 +521,73 @@ def test_expansion_in_processing(self): # check that both inputs are handled correctly and generate the same output self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) + + @slow + @require_bitsandbytes + def test_expansion_in_processing_images(self): + model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf" + model = LlavaNextVideoForConditionalGeneration.from_pretrained( + "llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True + ) + processor = AutoProcessor.from_pretrained(model_id) + + # check processing with expansion of inputs + processor.vision_feature_select_strategy = "default" + processor.patch_size = 14 + inputs_expanded = processor(self.prompt_image, images=[self.image], return_tensors="pt").to(torch_device) + self.assertTrue(inputs_expanded.input_ids.shape[-1] == 2652) + + # check processing without expansion of inputs (legacy behavior) + processor.vision_feature_select_strategy = None + processor.patch_size = None + inputs = processor(self.prompt_image, images=[self.image], return_tensors="pt").to(torch_device) + self.assertTrue(inputs.input_ids.shape[-1] == 19) + + # generate exactly 20 tokens + output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20) + output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20) + + # check that both inputs are handled correctly and generate the same output + self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) + + @slow + @require_bitsandbytes + def test_expansion_in_processing_multiimage(self): + model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf" + model = LlavaNextVideoForConditionalGeneration.from_pretrained( + "llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True + ) + processor = AutoProcessor.from_pretrained(model_id) + + prompt = "USER: \nDescribe the similarity between the two images:\nASSISTANT:" + image_file = 
"http://images.cocodataset.org/val2017/000000039769.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + deer_image = Image.open( + requests.get( + "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e", + stream=True, + ).raw + ) + + # check processing with expansion of inputs + processor.vision_feature_select_strategy = "default" + processor.patch_size = 14 + inputs_expanded = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to( + torch_device, torch.float16 + ) + self.assertTrue(inputs_expanded.input_ids.shape[-1] == 3968) + + # check processing without expansion of inputs (legacy behavior) + processor.vision_feature_select_strategy = None + processor.patch_size = None + inputs = processor(text=prompt, images=[raw_image, deer_image], return_tensors="pt").to( + torch_device, torch.float16 + ) + self.assertTrue(inputs.input_ids.shape[-1] == 22) + + # generate exactly 20 tokens + output = model.generate(**inputs, min_new_tokens=20, max_new_tokens=20) + output_expanded = model.generate(**inputs_expanded, min_new_tokens=20, max_new_tokens=20) + + # check that both inputs are handled correctly and generate the same output + self.assertListEqual(output_expanded[:, -20:].tolist(), output[:, -20:].tolist()) diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py index 29fa3b71589aa7..a8b2229a02f5f2 100644 --- a/tests/models/video_llava/test_modeling_video_llava.py +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -383,18 +383,19 @@ def test_small_model_integration_test(self): # Let' s make sure we test the preprocessing to replace what is used model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", load_in_4bit=True) - prompt = "USER: