PixtralVisionModel does not have an LM head
molbap committed Oct 29, 2024 · 1 parent 01c1ca1 · commit a72d1cb
Showing 1 changed file with 1 addition and 39 deletions.
tests/models/pixtral/test_modeling_pixtral.py
@@ -14,22 +14,16 @@
 # limitations under the License.
 """Testing suite for the PyTorch Pixtral model."""

-import gc
 import unittest

-import requests
-
 from transformers import (
-    AutoProcessor,
     PixtralVisionConfig,
     PixtralVisionModel,
     is_torch_available,
     is_vision_available,
 )
 from transformers.testing_utils import (
-    require_bitsandbytes,
     require_torch,
-    slow,
     torch_device,
 )

@@ -43,7 +37,7 @@
     is_torch_greater_or_equal_than_2_0 = False

 if is_vision_available():
-    from PIL import Image
+    pass


 class PixtralVisionModelTester:
@@ -259,35 +253,3 @@ def test_disk_offload_safetensors(self):
     @unittest.skip(reason="Not supported yet")
     def test_determinism(self):
         pass
-
-
-@require_torch
-class PixtralVisionModelIntegrationTest(unittest.TestCase):
-    def setUp(self):
-        self.processor = AutoProcessor.from_pretrained("hf-internal-testing/pixtral-12b")
-
-    def tearDown(self):
-        gc.collect()
-        torch.cuda.empty_cache()
-
-    @slow
-    @require_bitsandbytes
-    def test_small_model_integration_test(self):
-        # Let's make sure we test the preprocessing to replace what is used
-        model = PixtralVisionModel.from_pretrained("hf-internal-testing/pixtral-12b", load_in_4bit=True)
-
-        prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
-        image_file = "https://llava-vl.github.io/static/images/view.jpg"
-        raw_image = Image.open(requests.get(image_file, stream=True).raw)
-        inputs = self.processor(prompt, raw_image, return_tensors="pt")
-
-        EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722, 315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456, 1633, 28804, 13, 4816, 8048, 12738, 28747]])  # fmt: skip
-        self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
-
-        output = model.generate(**inputs, max_new_tokens=20)
-        EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly,"  # fmt: skip
-
-        self.assertEqual(
-            self.processor.decode(output[0], skip_special_tokens=True),
-            EXPECTED_DECODED_TEXT,
-        )
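For context: `PixtralVisionModel` is the bare vision encoder, so it carries no language-modeling head and cannot drive `generate()`, which is why the integration test above is removed rather than fixed. Generation with Pixtral goes through the full multimodal model instead. Below is a minimal sketch of the scenario the deleted test was exercising, rewritten against `LlavaForConditionalGeneration`; the `mistral-community/pixtral-12b` checkpoint id is an assumption (the prompt and image URL are reused from the deleted test):

```python
import requests
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Assumed public Pixtral checkpoint in Llava format; not taken from this commit.
model_id = "mistral-community/pixtral-12b"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map="auto")

# Prompt and image reused from the removed integration test.
prompt = "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
url = "https://llava-vl.github.io/static/images/view.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20)
print(processor.decode(output[0], skip_special_tokens=True))
```

The vision model itself only returns image hidden states, so any end-to-end generation check belongs in the tests for the composite model, not here.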
