Expand inputs in processors for VLMs #30962

Merged: 44 commits, merged on Aug 13, 2024

Changes from 1 commit. Commits (44):
050657f
let it be
zucchini-nlp May 20, 2024
a67087e
draft
zucchini-nlp May 22, 2024
1e2b873
should not have changed
zucchini-nlp May 22, 2024
70145d4
add warnings
zucchini-nlp May 29, 2024
16a6787
Merge remote-tracking branch 'upstream/main' into vlm_processors
zucchini-nlp May 29, 2024
8472035
fix & add tests
zucchini-nlp May 29, 2024
13af9e8
fix tests
zucchini-nlp May 29, 2024
41d086f
inputs embeds cannot be passed with pixels
zucchini-nlp May 29, 2024
bf59ed6
more updates
zucchini-nlp Jun 7, 2024
020e7ed
paligemma ready!
zucchini-nlp Jun 10, 2024
3e0455c
minor typos
zucchini-nlp Jun 10, 2024
674f16e
update blip-2
zucchini-nlp Jun 10, 2024
42ae646
fix tests & raise error
zucchini-nlp Jun 10, 2024
b5259f2
Merge branch 'main' into vlm_processors
zucchini-nlp Jun 10, 2024
a6c50de
docstring
zucchini-nlp Jun 10, 2024
4766e2e
add blip2 test
zucchini-nlp Jun 10, 2024
d46df90
Merge branch 'main' into vlm_processors
zucchini-nlp Jun 10, 2024
f74297b
tmp
zucchini-nlp Jun 17, 2024
5fc8565
add image seq length to config
zucchini-nlp Jun 18, 2024
1b4674a
update docstring
zucchini-nlp Jun 18, 2024
c3c130b
Merge branch 'main' into vlm_processors
zucchini-nlp Jun 18, 2024
8438875
delete
zucchini-nlp Jun 18, 2024
bf9e637
fix tests
zucchini-nlp Jun 18, 2024
db1fa4f
fix blip
zucchini-nlp Jun 18, 2024
246b06a
fix paligemma
zucchini-nlp Jun 21, 2024
222bf9a
merge `main`
zucchini-nlp Jul 18, 2024
5486215
out-of-place scatter
zucchini-nlp Jul 18, 2024
78c4484
add llava-next-video
zucchini-nlp Jul 18, 2024
d60624e
Update src/transformers/models/blip_2/modeling_blip_2.py
zucchini-nlp Aug 5, 2024
1973b39
remove tmp
zucchini-nlp Aug 5, 2024
a6e380f
merge `main`
zucchini-nlp Aug 5, 2024
8e88d8b
codestyle
zucchini-nlp Aug 5, 2024
689eed9
nits
zucchini-nlp Aug 6, 2024
28e8054
more nits
zucchini-nlp Aug 6, 2024
637e514
remove overriding in tests
zucchini-nlp Aug 6, 2024
be939d8
comprehension when merging video
zucchini-nlp Aug 6, 2024
232eb7c
fix-copies
zucchini-nlp Aug 6, 2024
385a617
revert changes for embeds test
zucchini-nlp Aug 6, 2024
4831a7e
fix tests after making comprehension
zucchini-nlp Aug 6, 2024
85fbff9
Update src/transformers/models/blip_2/processing_blip_2.py
zucchini-nlp Aug 8, 2024
119178f
Update src/transformers/models/blip_2/processing_blip_2.py
zucchini-nlp Aug 8, 2024
2451911
more updates
zucchini-nlp Aug 8, 2024
414031e
fix tests
zucchini-nlp Aug 8, 2024
8cfad20
Merge remote-tracking branch 'upstream/main' into vlm_processors
zucchini-nlp Aug 9, 2024
Viewing commit: merge main (zucchini-nlp committed Aug 5, 2024)
commit a6e380fe495b4350c1a70f95f5292ece1326fe10
26 changes: 13 additions & 13 deletions src/transformers/models/llava/processing_llava.py
@@ -128,20 +128,20 @@ def __call__(
         if images is not None:
             image_inputs = self.image_processor(images, return_tensors=return_tensors)
         else:
-            pixel_values = None
+            image_inputs = {}
 
         if isinstance(text, str):
             text = [text]
         elif not isinstance(text, list) and not isinstance(text[0], str):
             raise ValueError("Invalid input text. Please provide a string, or a list of strings")
 
         # try to expand inputs in processing if we have the necessary parts
-        if (
-            pixel_values is not None
-            and self.patch_size is not None
+        if image_inputs.get("pixel_values") is not None:
+            if (self.patch_size is not None
             and self.vision_feature_select_strategy is not None
         ):
             # Replace the image token with the expanded image token sequence
+            pixel_values = image_inputs["pixel_values"]
             height, width = get_image_size(to_numpy_array(pixel_values[0]))
             num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
             if self.vision_feature_select_strategy == "default":
@@ -151,14 +151,14 @@ def __call__(
             for sample in text:
                 sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
                 prompt_strings.append(sample)
-        elif pixel_values is not None:
-            prompt_strings = text
-            logger.warning_once(
-                "Expanding inputs for image tokens in LLaVa should be done in processing. "
-                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-                "Using processors without these attributes in the config is deprecated and will throw an error in v4.44."
-            )
+        else:
+            prompt_strings = text
+            logger.warning_once(
+                "Expanding inputs for image tokens in LLaVa should be done in processing. "
+                "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+                "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+                "Using processors without these attributes in the config is deprecated and will throw an error in v4.44."
+            )
 
         text_inputs = self.tokenizer(
             prompt_strings,
@@ -167,7 +167,7 @@ def __call__(
             truncation=truncation,
             max_length=max_length,
         )
-        return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
+        return BatchFeature(data={**text_inputs, **image_inputs})
 
     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
     def batch_decode(self, *args, **kwargs):
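For orientation, the expansion branch above reduces to a small piece of arithmetic. The snippet below is a minimal, self-contained sketch of it (not the PR's code), using hypothetical values for the patch size, processed resolution, and prompt; the `-= 1` for the "default" strategy reflects the assumption that this strategy drops the CLS feature. As the deprecation warning in the diff says, older processor configs can opt in by setting `processor.patch_size` and `processor.vision_feature_select_strategy` directly.

```python
# Hedged sketch (not the PR's code): how the processor-side expansion sizes the
# image placeholder sequence. All values below are illustrative.
patch_size = 14                        # hypothetical vision patch size
height, width = 336, 336               # hypothetical processed image resolution
vision_feature_select_strategy = "default"
image_token = "<image>"

num_image_tokens = (height // patch_size) * (width // patch_size) + 1  # patches + CLS
if vision_feature_select_strategy == "default":
    num_image_tokens -= 1              # assumed: "default" drops the CLS feature

prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
expanded = prompt.replace(image_token, image_token * num_image_tokens)

print(num_image_tokens)                # 576 with these example values
print(expanded.count(image_token))     # one placeholder token per visual feature
```

With one placeholder token per visual feature already in the prompt, the tokenizer output has the final sequence length, which is the point of moving the expansion from the model's forward pass into the processor.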
13 changes: 13 additions & 0 deletions tests/models/llava/test_modeling_llava.py
@@ -501,6 +501,19 @@ def test_tokenizer_integration(self):
         self.assertEqual(slow_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
         self.assertEqual(fast_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
 
+    @slow
+    @require_bitsandbytes
+    def test_generation_no_images(self):
+        model_id = "llava-hf/llava-1.5-7b-hf"
+        model = LlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
+        processor = AutoProcessor.from_pretrained(model_id)
+
+        # Prepare inputs with no images
+        inputs = processor("Hello, I am", return_tensors="pt").to(torch_device)
+
+        # Make sure that `generate` works
+        _ = model.generate(**inputs, max_new_tokens=20)
+
     @slow
     @require_bitsandbytes
     def test_expansion_in_processing(self):
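A quick sketch of why the text-only path exercised by this test works after the change above: when no images are passed, the processor builds an empty `image_inputs` dict, so unpacking it into the returned `BatchFeature` adds no `pixel_values` key at all, instead of carrying a `pixel_values=None` entry. The snippet below is illustrative only; the token ids are made up.

```python
# Hedged sketch: merging text features with an empty image dict simply omits
# the pixel_values key (it does not insert pixel_values=None).
text_inputs = {"input_ids": [[1, 15043, 29892, 306, 626]], "attention_mask": [[1, 1, 1, 1, 1]]}  # made-up ids
image_inputs = {}  # no images were passed to the processor

batch = {**text_inputs, **image_inputs}
print("pixel_values" in batch)  # False, so generate() never receives a stray pixel_values entry
```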
35 changes: 34 additions & 1 deletion tests/models/llava_next/test_modeling_llava_next.py
@@ -516,7 +516,40 @@ def test_small_model_integration_test_batch_matches_single(self):
 
     @slow
     @require_bitsandbytes
-    def test_expansion_in_processing(self):
+    def test_padding_side_when_merging_inputs(self):
+        model = LlavaNextForConditionalGeneration.from_pretrained(
+            "llava-hf/llava-v1.6-mistral-7b-hf",
+            load_in_4bit=True,
+        )
+
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        lowres_url = "https://4.img-dpreview.com/files/p/TS560x560~forums/56876524/03975b28741443319e9a94615e35667e"
+        cats_image = Image.open(requests.get(url, stream=True).raw)
+        lowres_img = Image.open(requests.get(lowres_url, stream=True).raw)
+
+        inputs_batched = self.processor(
+            [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
+        ).to(torch_device)
+
+        # model is in eval mode by default so we should get pad on the left side
+        # we can check the first hidden-states (aka inputs embeds)
+        # the first element was lo-res image and we expect the first 1414 tokens to be all pads
+        output_eval = model(**inputs_batched, output_hidden_states=True)
+        self.assertTrue((output_eval.hidden_states[0][0, :1414, ...] == 0).all().item())
+
+        # otherwise padding is on the right side, so it's last 1414 tokens
+        self.processor.padding_side = "right"
+        inputs_batched = self.processor(
+            [self.prompt, self.prompt], images=[lowres_img, cats_image], return_tensors="pt", padding=True
+        ).to(torch_device)
+
+        model.train()
+        with torch.no_grad():
+            output_train = model(**inputs_batched, output_hidden_states=True)
+        self.assertTrue((output_train.hidden_states[0][0, -1414:, ...] == 0).all().item())
+
+    @slow
+    @require_bitsandbytes
+    def test_expansion_in_processing(self):
         model_id = "llava-hf/llava-v1.6-mistral-7b-hf"
         model = LlavaNextForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
         processor = AutoProcessor.from_pretrained(model_id)
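As background for the padding-side tests in this file and the next one, here is a small, self-contained illustration (not from this PR) of where pad positions end up for left versus right padding; the token ids and pad id are made up. The tests rely on the same idea one level up: pad positions correspond to all-zero rows of the first hidden state, so with left padding they are the first N positions and with right padding the last N.

```python
# Hedged illustration: left padding puts pads at the start of each sequence,
# right padding puts them at the end. Token ids below are made up.
pad_id = 0
batch = [[5, 6, 7], [8, 9, 10, 11, 12]]  # ragged batch of token ids
max_len = max(len(seq) for seq in batch)

def pad(seq, side):
    pads = [pad_id] * (max_len - len(seq))
    return pads + seq if side == "left" else seq + pads

left = [pad(seq, "left") for seq in batch]    # [[0, 0, 5, 6, 7], [8, 9, 10, 11, 12]]
right = [pad(seq, "right") for seq in batch]  # [[5, 6, 7, 0, 0], [8, 9, 10, 11, 12]]
print(left[0], right[0])
```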
36 changes: 36 additions & 0 deletions tests/models/llava_next_video/test_modeling_llava_next_video.py
@@ -452,6 +452,42 @@ def test_small_model_integration_test_batch_matches_single(self):
             self.processor.decode(output_single[0], skip_special_tokens=True),
         )
 
+    @slow
+    @require_bitsandbytes
+    def test_padding_side_when_merging_inputs(self):
+        model = LlavaNextVideoForConditionalGeneration.from_pretrained(
+            "llava-hf/LLaVA-NeXT-Video-7B-hf", load_in_4bit=True
+        )
+
+        inputs_batched = self.processor(
+            [self.prompt_video, self.prompt_image],
+            images=[self.image],
+            videos=[self.video],
+            return_tensors="pt",
+            padding=True,
+        ).to(torch_device)
+
+        # model is in eval mode by default so we should get pad on the left side
+        # we can check the first hidden-states (aka inputs embeds)
+        # the first element was lo-res image and we expect the first 1482 tokens to be all pads
+        output_eval = model(**inputs_batched, output_hidden_states=True)
+        self.assertTrue((output_eval.hidden_states[0][0, :1482, ...] == 0).all().item())
+
+        # otherwise padding is on the right side, so it's last 1482 tokens
+        self.processor.padding_side = "right"
+        inputs_batched = self.processor(
+            [self.prompt_video, self.prompt_image],
+            images=[self.image],
+            videos=[self.video],
+            return_tensors="pt",
+            padding=True,
+        ).to(torch_device)
+
+        model.train()
+        with torch.no_grad():
+            output_train = model(**inputs_batched, output_hidden_states=True)
+        self.assertTrue((output_train.hidden_states[0][0, -1482:, ...] == 0).all().item())
+
     @slow
     @require_bitsandbytes
     def test_expansion_in_processing(self):