uniformize processor Mllama (#33876)
* uniformize processor Mllama

* nit syntax

* nit
yonigozlan authored Oct 2, 2024
1 parent 62e8c75 commit d7950bf
Showing 2 changed files with 59 additions and 30 deletions.
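In short, this commit brings `MllamaProcessor.__call__` in line with the uniform processor signature (`images`, `text`, `audio`, `videos`, `**kwargs`) and switches its return type from `BatchEncoding` to `BatchFeature`, matching other multimodal processors. A minimal sketch of the uniformized call, assuming the test checkpoint used in this commit's tests is available (Mllama ignores the `audio` and `videos` arguments; they exist only for signature uniformity):

```python
from PIL import Image

from transformers import MllamaProcessor

# Assumption: the hf-internal-testing checkpoint from this commit's tests;
# any Mllama checkpoint with a saved processor works the same way.
processor = MllamaProcessor.from_pretrained("hf-internal-testing/mllama-11b")

image = Image.new("RGB", (224, 224))
text = "<|image|>This is a test sentence."

# Uniformized signature: images, text, audio, videos, **kwargs.
inputs = processor(images=image, text=text, return_tensors="np")

print(type(inputs).__name__)  # BatchFeature (previously BatchEncoding)
print(sorted(inputs.keys()))  # input_ids, attention_mask, pixel_values, ...
```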
11 changes: 6 additions & 5 deletions src/transformers/models/mllama/processing_mllama.py
@@ -23,7 +23,6 @@
 from ...image_utils import ImageInput
 from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
 from ...tokenization_utils_base import (
-    BatchEncoding,
     PreTokenizedInput,
     TextInput,
 )
@@ -226,8 +225,10 @@ def __call__(
         self,
         images: Optional[ImageInput] = None,
         text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+        audio=None,
+        videos=None,
         **kwargs: Unpack[MllamaProcessorKwargs],
-    ) -> BatchEncoding:
+    ) -> BatchFeature:
         """
         Main method to prepare text(s) and image(s) to be fed as input to the model. This method forwards the `text`
         arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
@@ -250,7 +251,7 @@ def __call__(
                 - `'np'`: Return NumPy `np.ndarray` objects.
                 - `'jax'`: Return JAX `jnp.ndarray` objects.
         Returns:
-            [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
 
             - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
             - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
@@ -323,9 +324,9 @@ def __call__(
         data["cross_attention_mask"] = cross_attention_mask
 
         return_tensors = common_kwargs.pop("return_tensors", None)
-        batch_encoding = BatchFeature(data=data, tensor_type=return_tensors)
+        batch_feature = BatchFeature(data=data, tensor_type=return_tensors)
 
-        return batch_encoding
+        return batch_feature
 
     def batch_decode(self, *args, **kwargs):
         """
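The renaming above is mechanical, but the pattern is worth spelling out: the processor gathers its outputs into a plain `data` dict and wraps them in a `BatchFeature`, whose `tensor_type` argument performs the conversion requested via `return_tensors`. A small sketch of that pattern in isolation, with toy token ids standing in for real processor output:

```python
from transformers import BatchFeature

# Toy stand-ins for the processor's collected outputs.
data = {
    "input_ids": [[128000, 2028, 374, 264, 1296, 11914, 13, 128001]],
    "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1]],
}

# tensor_type drives conversion, exactly like return_tensors in __call__.
batch_feature = BatchFeature(data=data, tensor_type="np")
print(type(batch_feature["input_ids"]))  # <class 'numpy.ndarray'>
```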
78 changes: 53 additions & 25 deletions tests/models/mllama/test_processor_mllama.py
@@ -13,32 +13,44 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import shutil
+import tempfile
 import unittest
+from typing import Optional
 
 import numpy as np
 
 from transformers import MllamaProcessor
 from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_vision_available
 
+from ...test_processing_common import ProcessorTesterMixin
+
 
 if is_vision_available():
     from PIL import Image
 
 
 @require_torch
 @require_vision
-class MllamaProcessorTest(unittest.TestCase):
+class MllamaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = MllamaProcessor
 
     def setUp(self):
-        self.checkpoint = "hf-internal-testing/mllama-11b"  # TODO: change
-        self.processor = MllamaProcessor.from_pretrained(self.checkpoint)
+        self.checkpoint = "hf-internal-testing/mllama-11b"
+        processor = MllamaProcessor.from_pretrained(self.checkpoint)
         self.image1 = Image.new("RGB", (224, 220))
         self.image2 = Image.new("RGB", (512, 128))
-        self.image_token = self.processor.image_token
-        self.image_token_id = self.processor.image_token_id
-        self.pad_token_id = self.processor.tokenizer.pad_token_id
-        self.bos_token = self.processor.bos_token
-        self.bos_token_id = self.processor.tokenizer.bos_token_id
+        self.image_token = processor.image_token
+        self.image_token_id = processor.image_token_id
+        self.pad_token_id = processor.tokenizer.pad_token_id
+        self.bos_token = processor.bos_token
+        self.bos_token_id = processor.tokenizer.bos_token_id
+        self.tmpdirname = tempfile.mkdtemp()
+        processor.save_pretrained(self.tmpdirname)
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
 
     def test_apply_chat_template(self):
         # Message contains content which a mix of lists with images and image urls and string
@@ -64,8 +76,8 @@ def test_apply_chat_template(self):
                 ],
             },
         ]
-
-        rendered = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        processor = MllamaProcessor.from_pretrained(self.tmpdirname)
+        rendered = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
 
         expected_rendered = (
             "<|begin_of_text|>"
@@ -96,7 +108,7 @@ def test_apply_chat_template(self):
                 ],
             },
         ]
-        input_ids = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
+        input_ids = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
         expected_ids = [
             128000,  # <|begin_of_text|>
             128006,  # <|start_header_id|>
@@ -142,15 +154,15 @@ def test_apply_chat_template(self):
             }
         ]
 
-        rendered = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        rendered = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
         expected_rendered = (
             "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
             "Describe this image in two sentences<|image|> Test sentence <|image|>ok\n<|eot_id|>"
             "<|start_header_id|>assistant<|end_header_id|>\n\n"
         )
         self.assertEqual(rendered, expected_rendered)
 
-        input_ids = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
+        input_ids = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
         # fmt: off
         expected_ids = [
             128000, 128006, 882, 128007, 271, 75885, 420, 2217, 304, 1403, 23719, 128256,
@@ -176,18 +188,19 @@ def test_apply_chat_template(self):
             }
         ]
 
-        rendered_list = self.processor.apply_chat_template(messages_list, add_generation_prompt=True, tokenize=False)
-        rendered_str = self.processor.apply_chat_template(messages_str, add_generation_prompt=True, tokenize=False)
+        rendered_list = processor.apply_chat_template(messages_list, add_generation_prompt=True, tokenize=False)
+        rendered_str = processor.apply_chat_template(messages_str, add_generation_prompt=True, tokenize=False)
         self.assertEqual(rendered_list, rendered_str)
 
     def test_process_interleaved_images_prompts_image_splitting(self):
+        processor = MllamaProcessor.from_pretrained(self.tmpdirname)
         # Test that a single image is processed correctly
-        inputs = self.processor(images=self.image2, size={"width": 224, "height": 224})
+        inputs = processor(images=self.image2, size={"width": 224, "height": 224})
         self.assertEqual(inputs["pixel_values"].shape, (1, 1, 4, 3, 224, 224))
 
         # Test that text is processed correctly
         text = "<|begin_of_text|>This is a test sentence.<|end_of_text|>"
-        inputs = self.processor(text=text)
+        inputs = processor(text=text)
         expected_ids = [128000, 2028, 374, 264, 1296, 11914, 13, 128001]
         self.assertEqual(inputs["input_ids"][0], expected_ids)
         self.assertEqual(inputs["attention_mask"][0], [1] * len(expected_ids))
@@ -197,7 +210,7 @@ def test_process_interleaved_images_prompts_image_splitting(self):
         image_str = "<|image|>"
         text_str = "This is a test sentence."
         text = image_str + text_str
-        inputs = self.processor(
+        inputs = processor(
             text=text,
             images=self.image1,
             size={"width": 128, "height": 128},
@@ -225,7 +238,7 @@ def test_process_interleaved_images_prompts_image_splitting(self):
         ]
         # fmt: on
         images = [[self.image1], [self.image1, self.image2]]
-        inputs = self.processor(text=text, images=images, padding=True, size={"width": 256, "height": 256})
+        inputs = processor(text=text, images=images, padding=True, size={"width": 256, "height": 256})
 
         self.assertEqual(inputs["pixel_values"].shape, (2, 2, 4, 3, 256, 256))
         for input_ids_i, attention_mask_i, expected_ids_i in zip(inputs["input_ids"], inputs["attention_mask"], expected_ids):
@@ -264,34 +277,49 @@ def test_process_interleaved_images_prompts_image_error(self):
             "This is a test sentence.",
             "In this other sentence we try some good things",
         ]
-        inputs = self.processor(text=text, images=None, padding=True)
+        processor = MllamaProcessor.from_pretrained(self.tmpdirname)
+        inputs = processor(text=text, images=None, padding=True)
         self.assertIsNotNone(inputs["input_ids"])
 
         text = [
             "This is a test sentence.<|image|>",
             "In this other sentence we try some good things",
         ]
         with self.assertRaises(ValueError):
-            self.processor(text=text, images=None, padding=True)
+            processor(text=text, images=None, padding=True)
 
         images = [[self.image1], []]
         with self.assertRaises(ValueError):
-            self.processor(text=text, images=images, padding=True)
+            processor(text=text, images=images, padding=True)
 
         text = [
             "This is a test sentence.<|image|>",
             "In this other sentence we try some good things<|image|>",
        ]
         with self.assertRaises(ValueError):
-            self.processor(text=text, images=None, padding=True)
+            processor(text=text, images=None, padding=True)
 
         text = [
             "This is a test sentence.<|image|>",
             "In this other sentence we try some good things<|image|>",
         ]
         images = [[self.image1], [self.image2]]
-        inputs = self.processor(text=text, images=images, padding=True)
+        inputs = processor(text=text, images=images, padding=True)
 
         images = [[self.image1, self.image2], []]
         with self.assertRaises(ValueError):
-            self.processor(text=text, images=None, padding=True)
+            processor(text=text, images=None, padding=True)
 
+    # Override as MllamaProcessor needs image tokens in prompts
+    def prepare_text_inputs(self, batch_size: Optional[int] = None):
+        if batch_size is None:
+            return "lower newer <|image|>"
+
+        if batch_size < 1:
+            raise ValueError("batch_size must be greater than 0")
+
+        if batch_size == 1:
+            return ["lower newer <|image|>"]
+        return ["lower newer <|image|>", "<|image|> upper older longer string"] + ["<|image|> lower newer"] * (
+            batch_size - 2
+        )

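The `prepare_text_inputs` override matters because `MllamaProcessor` validates that the number of images matches the number of `<|image|>` tokens in the prompt, so the mixin's default image-free prompts would fail once images are supplied. A hedged sketch of that behavior, assuming the same test checkpoint as above:

```python
from PIL import Image

from transformers import MllamaProcessor

processor = MllamaProcessor.from_pretrained("hf-internal-testing/mllama-11b")
image = Image.new("RGB", (224, 224))

# A prompt without <|image|> paired with an image raises, which is why the
# mixin's default prompts cannot be used as-is for Mllama.
try:
    processor(text="lower newer", images=image)
except ValueError as err:
    print(f"ValueError: {err}")

# The override's prompt carries the image token, so the call succeeds.
inputs = processor(text="lower newer <|image|>", images=image)
print(inputs["pixel_values"].shape)
```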