diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 345dc848b28005..f3c4e91d849c6b 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -512,6 +512,19 @@ def _merge_input_ids_with_image_features( image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index + if self.training and self.padding_side == "left": + logger.warning_once( + "Padding side is set to 'left' but the model is in training mode. For training " + "it is recommended to set `model.padding_side='right'` and `processor.tokenizer.padding_side='right'`. " + "If that's intended, ignore this warning" + ) + if not self.training and self.padding_side == "right": + logger.warning_once( + "Padding side is set to 'right' but the model is in inference mode. For correct " + "generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. " + "If that's intended, ignore this warning" + ) + with torch.no_grad(): # ! 
in llava 1.6, number of patches is variable num_images = feature_lens.size(0) @@ -522,18 +535,14 @@ def _merge_input_ids_with_image_features( _left_padding = torch.any(attention_mask[:, 0] == 0) _right_padding = torch.any(attention_mask[:, -1] == 0) - left_padding = True if not self.training else False - if batch_size > 1 and not self.training: - if _left_padding and not _right_padding: - left_padding = True - elif not _left_padding and _right_padding: - left_padding = False - elif not _left_padding and not _right_padding: - # both side is 1, so cannot tell - left_padding = self.padding_side == "left" - else: - # invalid attention_mask + left_padding = self.padding_side == "left" + if batch_size > 1: + if _left_padding and _right_padding: raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") + elif _right_padding and left_padding: + left_padding = False + elif _left_padding and not left_padding: + left_padding = True # Whether to turn off right padding # 1. 
Create a mask to know where special image tokens are diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 2910438310e7f1..c635fb37bf955a 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -454,6 +454,7 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m self.vocab_size = model_embeds.num_embeddings return model_embeds + # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration._merge_input_ids_with_image_features def _merge_input_ids_with_image_features( self, image_features, @@ -557,6 +558,19 @@ def _merge_input_ids_with_image_features( image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index + if self.training and self.padding_side == "left": + logger.warning_once( + "Padding side is set to 'left' but the model is in training mode. For training " + "it is recommended to set `model.padding_side='right'` and `processor.tokenizer.padding_side='right'`. " + "If that's intended, ignore this warning" + ) + if not self.training and self.padding_side == "right": + logger.warning_once( + "Padding side is set to 'right' but the model is in inference mode. For correct " + "generation results, please set `model.padding_side='left'` and `processor.tokenizer.padding_side='left'`. " + "If that's intended, ignore this warning" + ) + with torch.no_grad(): # ! 
in llava 1.6, number of patches is variable num_images = feature_lens.size(0) @@ -567,18 +581,14 @@ def _merge_input_ids_with_image_features( _left_padding = torch.any(attention_mask[:, 0] == 0) _right_padding = torch.any(attention_mask[:, -1] == 0) - left_padding = True if not self.training else False - if batch_size > 1 and not self.training: - if _left_padding and not _right_padding: - left_padding = True - elif not _left_padding and _right_padding: - left_padding = False - elif not _left_padding and not _right_padding: - # both side is 1, so cannot tell - left_padding = self.padding_side == "left" - else: - # invalid attention_mask + left_padding = self.padding_side == "left" + if batch_size > 1: + if _left_padding and _right_padding: raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") + elif _right_padding and left_padding: + left_padding = False + elif _left_padding and not left_padding: + left_padding = True # Whether to turn off right padding # 1. Create a mask to know where special image tokens are diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index c08f5973d4211b..c665631c40331d 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -549,6 +549,24 @@ def test_padding_side_when_merging_inputs(self): output_train = model(**inputs_batched, output_hidden_states=True) self.assertTrue((output_train.hidden_states[0][0, -1414:, ...] == 0).all().item()) + with self.assertLogs("transformers", level="WARNING") as logs: + model.padding_side = "left" + model.train() + model(**inputs_batched, output_hidden_states=True) + + self.assertIn( + "Padding side is set to 'left' but the model is in training mode. 
For training", logs.output[0] + ) + + with self.assertLogs("transformers", level="WARNING") as logs: + model.padding_side = "right" + model.eval() + model(**inputs_batched, output_hidden_states=True) + + self.assertIn( + "Padding side is set to 'right' but the model is in inference mode. For correct", logs.output[0] + ) + @slow @require_bitsandbytes def test_expansion_in_processing(self): diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index 6bdb178ad0a718..38b1782b75d6ac 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Testing suite for the PyTorch Llava-NeXT model.""" +"""Testing suite for the PyTorch Llava-NeXT-Video model.""" import gc import unittest @@ -511,6 +511,24 @@ def test_padding_side_when_merging_inputs(self): output_train = model(**inputs_batched, output_hidden_states=True) self.assertTrue((output_train.hidden_states[0][0, -1482:, ...] == 0).all().item()) + with self.assertLogs("transformers", level="WARNING") as logs: + model.padding_side = "left" + model.train() + model(**inputs_batched, output_hidden_states=True) + + self.assertIn( + "Padding side is set to 'left' but the model is in training mode. For training", logs.output[0] + ) + + with self.assertLogs("transformers", level="WARNING") as logs: + model.padding_side = "right" + model.eval() + model(**inputs_batched, output_hidden_states=True) + + self.assertIn( + "Padding side is set to 'right' but the model is in inference mode. For correct", logs.output[0] + ) + @slow @require_bitsandbytes def test_expansion_in_processing(self):