Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix patch_attention_mask incorrect setting which leads to the differe… #33499

Merged
merged 4 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/transformers/models/idefics2/modeling_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,7 @@ def forward(
patch_size = self.config.vision_config.patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool()

# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
Expand Down
35 changes: 35 additions & 0 deletions tests/models/idefics2/test_modeling_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,41 @@ def test_integration_test_4bit(self):
expected_generated_text = "In this image, we see the Statue of Liberty, the Hudson River,"
self.assertEqual(generated_texts[0], expected_generated_text)

@slow
@require_bitsandbytes
def test_integration_test_4bit_batch2(self):
# Let' s make sure we test the preprocessing to replace what is used

model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b-base",
load_in_4bit=True,
)

from datasets import load_dataset

dataset = load_dataset("nielsr/docvqa_1200_examples", split="test")

text = [f"<image>{dataset[40]['query']['en']}", f"<image>{dataset[41]['query']['en']}"]
sywangyi marked this conversation as resolved.
Show resolved Hide resolved
images = [[dataset[40]["image"]], [dataset[41]["image"]]]
inputs = self.processor(text=text, images=images, padding=True, return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=64)
batched_generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)

text = f"<image>{dataset[40]['query']['en']}"
images = dataset[40]["image"]
inputs = self.processor(text=text, images=images, padding=True, return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=64)
generated_text_0 = self.processor.batch_decode(generated_ids, skip_special_tokens=True)

text = f"<image>{dataset[41]['query']['en']}"
images = dataset[41]["image"]
inputs = self.processor(text=text, images=images, padding=True, return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=64)
generated_text_1 = self.processor.batch_decode(generated_ids, skip_special_tokens=True)

self.assertEqual(batched_generated_texts[0], generated_text_0[0])
self.assertEqual(batched_generated_texts[1], generated_text_1[0])

@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
Expand Down