
LlavaNextProcessor bug in _get_unpadded_features #33261

Closed · laurentd-lunit opened this issue Sep 2, 2024 · 1 comment · Fixed by #33263
@laurentd-lunit (Contributor)

System Info

  • transformers version: 4.45.0.dev0
  • Platform: Linux-5.15.0-78-generic-x86_64-with-glibc2.35
  • Python version: 3.11.9
  • Huggingface_hub version: 0.24.6
  • Safetensors version: 0.4.4
  • Accelerate version: 0.33.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA H100 80GB HBM3

Who can help?

@zu

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

There is a typo in the following lines in LlavaNextProcessor: current_width and current_height are inverted in _get_unpadded_features, which can cause errors due to a mismatch between the image feature size computed by the processor and the size computed by the vision branch in LlavaNextForConditionalGeneration. I encountered this issue while running the following example script.
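
For reference, a minimal sketch of the inversion as I read it (variable names follow _get_unpadded_features; the exact source lines are my assumption, not a verbatim diff):

# Suspected buggy assignments: height and width factors swapped
current_width = patches_height * scale_height   # should be patches_width * scale_width
current_height = patches_width * scale_width    # should be patches_height * scale_height

# Expected assignments
current_height = patches_height * scale_height
current_width = patches_width * scale_width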

Here is a code snippet to reproduce the issue:

from transformers import LlavaNextProcessor
from transformers.models.llava_next.processing_llava_next import select_best_resolution
from transformers.models.llava_next.modeling_llava_next import unpad_image, get_anyres_image_grid_shape
import torch

POSSIBLE_RESOLUTIONS = [
    [336, 672],
    [672, 336],
    [672, 672],
    [1008, 336],
    [336, 1008],
]
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
HEIGHT = 500
WIDTH = 316
VISION_MODEL_INPUT_SIZE = 336
PATCH_SIZE = 14
PATCH_DIM = VISION_MODEL_INPUT_SIZE // PATCH_SIZE


# Reproduce pre-processing steps in the processor
height_best_resolution, width_best_resolution = select_best_resolution(
    [HEIGHT, WIDTH], POSSIBLE_RESOLUTIONS
)
scale_height = height_best_resolution // VISION_MODEL_INPUT_SIZE
scale_width = width_best_resolution // VISION_MODEL_INPUT_SIZE
patches_height = VISION_MODEL_INPUT_SIZE // PATCH_SIZE
patches_width = VISION_MODEL_INPUT_SIZE // PATCH_SIZE
unpadded_features, newline_features = processor._get_unpadded_features(
    HEIGHT, WIDTH, patches_height, patches_width, scale_height, scale_width
)
num_unpad_features_from_processor = unpadded_features


# Reproduce computation of unpadded features in the vision branch
# Equivalent to:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L676-L684
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
    (HEIGHT, WIDTH),
    POSSIBLE_RESOLUTIONS,
    VISION_MODEL_INPUT_SIZE,
)
unpad_features_from_vision = unpad_image(
    torch.randn(128, num_patch_height * PATCH_DIM, num_patch_width * PATCH_DIM),
    (HEIGHT, WIDTH),
)
num_unpad_features_from_vision = (
    unpad_features_from_vision.shape[1] * unpad_features_from_vision.shape[2]
)

# Should be equal
assert num_unpad_features_from_processor == num_unpad_features_from_vision, f"Not equal: From processor: {num_unpad_features_from_processor}, from vision {num_unpad_features_from_vision}"
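
For reference, the vision-branch count for this configuration can be traced by hand. A sketch of the arithmetic (my own trace, assuming select_best_resolution picks (672, 336) for a 500x316 image and that unpad_image truncates with int(); worth checking against the source):

# (672, 336) -> 2x1 grid of 336px tiles -> a 48x24 patch grid before unpadding
current_height, current_width = 48, 24
# original aspect ratio 316/500 = 0.632 is wider than 24/48 = 0.5, so rows get cropped
new_height = int(HEIGHT * current_width / WIDTH)   # int(500 * 24 / 316) = 37
padding = (current_height - new_height) // 2       # 5
rows_kept = current_height - 2 * padding           # 38
print(rows_kept * current_width)                   # 912 unpadded features expected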

Expected behavior

No assertion error.

@LysandreJik (Member)

cc @zucchini-nlp maybe
