Commit 25a4fa3
1 parent 6f123f7

add integration tests for qwen3vl-30a3
3 files changed: 285 additions, 3 deletions

src/transformers/models/qwen3_vl/modular_qwen3_vl.py (1 addition, 1 deletion)

@@ -1394,7 +1394,7 @@ def __call__(
             index = 0
             for i in range(len(text)):
                 while self.video_token in text[i]:
-                    metadata = video_metadata[i]
+                    metadata = video_metadata[index]
                     if metadata.fps is None:
                         logger.warning_once(
                             "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "

tests/models/qwen3_vl/test_processing_qwen3_vl.py (2 additions, 2 deletions)

@@ -293,7 +293,7 @@ def test_apply_chat_template_video_frame_sampling(self):
         self.assertListEqual(list(out_dict.keys()), ["input_ids", "attention_mask"])

         # for fast test, set the longest edge to 4096
-        processor.video_processor.size['longest_edge'] = 8192
+        processor.video_processor.size["longest_edge"] = 8192

         # Add video URL for return dict and load with `num_frames` arg
         messages[0][0]["content"][0] = {
@@ -307,7 +307,7 @@ def test_apply_chat_template_video_frame_sampling(self):
             tokenize=True,
             return_dict=True,
             num_frames=num_frames,
-            fps=None, # if pass num_frames, fps should be None
+            fps=None,  # if pass num_frames, fps should be None
         )
         self.assertTrue(self.videos_input_name in out_dict_with_video)
         self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 256)
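The second hunk only adjusts comment spacing, but the comment it touches documents the sampling contract: when a fixed num_frames is requested, fps must be left as None. A minimal illustration mirroring the test's call (values here are hypothetical, and `processor`/`messages` are assumed to be in scope as in the test):

    # Request a fixed number of sampled frames; the fps-based sampling path must stay disabled.
    out_dict_with_video = processor.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
        num_frames=8,  # fixed frame count ...
        fps=None,      # ... so fps is left unset
    )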

tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py (282 additions, 0 deletions)

@@ -17,13 +17,18 @@
 import unittest

 from transformers import (
+    AutoProcessor,
     Qwen3VLMoeConfig,
     Qwen3VLMoeForConditionalGeneration,
     Qwen3VLMoeModel,
     is_torch_available,
 )
 from transformers.testing_utils import (
+    cleanup,
+    require_flash_attn,
     require_torch,
+    require_torch_gpu,
+    slow,
     torch_device,
 )

@@ -296,3 +301,280 @@ def test_video_forward(self):
             video_grid_thw=video_grid_thw,
         )
         self.assertIsNotNone(outputs)
+
+
+@require_torch
+@unittest.skip("The checkpoint is not yet released")
+class Qwen3VLMoeIntegrationTest(unittest.TestCase):
+    def setUp(self):
+        cleanup(torch_device, gc_collect=True)
+
+        self.processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-30B-A3B-Instruct")
+        self.processor.tokenizer.padding_side = "left"
+        self.message = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
+                    },
+                    {"type": "text", "text": "What kind of dog is this?"},
+                ],
+            }
+        ]
+        self.message2 = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png",
+                    },
+                    {"type": "text", "text": "What kind of dog is this?"},
+                ],
+            }
+        ]
+
+    def tearDown(self):
+        cleanup(torch_device, gc_collect=True)
+
+    @slow
+    def test_small_model_integration_test(self):
+        model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto"
+        )
+
+        inputs = self.processor.apply_chat_template(
+            self.message, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+        )
+        expected_input_ids = [151644, 872, 198, 151652, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655, 151655]  # fmt: skip
+        assert expected_input_ids == inputs.input_ids[0].tolist()[:17]
+
+        expected_pixel_slice = torch.tensor(
+            [
+                [-0.0902, -0.0824, -0.0824],
+                [-0.2627, -0.2627, -0.2627],
+                [-0.0824, -0.0902, -0.0902],
+                [-0.0118, -0.0510, -0.1137],
+                [-0.5137, -0.5529, -0.6078],
+                [-0.6941, -0.6314, -0.5765],
+            ],
+            dtype=torch.float32,
+            device="cpu",
+        )
+        assert torch.allclose(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=3e-3)
+
+        # verify generation
+        inputs = inputs.to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        EXPECTED_DECODED_TEXT = "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a small wild cat native to the grasslands and steppes"
+        self.assertEqual(
+            self.processor.decode(output[0], skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
+
+    @slow
+    def test_small_model_integration_test_batch(self):
+        model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto"
+        )
+        batch_messages = [self.message] * 2
+        inputs = self.processor.apply_chat_template(
+            batch_messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+        ).to(torch_device)
+
+        # it should not matter whether two images are the same size or not
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+
+        EXPECTED_DECODED_TEXT = [
+            "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a small wild cat native to the grasslands and montane regions",
+            "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a small wild cat native to the grasslands and montane regions"
+        ]  # fmt: skip
+        self.assertEqual(
+            self.processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
+
+    @slow
+    def test_small_model_integration_test_with_video(self):
+        processor = AutoProcessor.from_pretrained(
+            "Qwen/Qwen3-VL-30B-A3B-Instruct", max_image_size={"longest_edge": 50176}
+        )
+        model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype=torch.float16, device_map="auto"
+        )
+        questions = ["How long is the video? Describe the it in short."]
+        video_urls = ["https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4"]
+        messages = [
+            [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "video",
+                            "video": video_url,
+                        },
+                        {"type": "text", "text": question},
+                    ],
+                }
+            ]
+            for question, video_url in zip(questions, video_urls)
+        ]
+        inputs = processor.apply_chat_template(
+            messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", padding=True
+        ).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+        EXPECTED_DECODED_TEXT = ["user\n<0.3 seconds><1.4 seconds><2.5 seconds><3.6 seconds><4.7 seconds><5.8 seconds>How long is the video? Describe the it in short.\nassistant\nThe video is 6 seconds long. It shows a man playing tennis on an indoor court. He is wearing a white shirt and black shorts. He"]  # fmt: skip
+
+        self.assertEqual(
+            processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
+
+    @slow
+    def test_small_model_integration_test_expand(self):
+        model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto"
+        )
+        inputs = self.processor.apply_chat_template(
+            self.message, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+        ).to(torch_device)
+
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False, num_beams=2, num_return_sequences=2)
+
+        EXPECTED_DECODED_TEXT = [
+            "user\nWhat kind of dog is this?\nassistant\nThe animal in the image is not a dog. It is a **Pallas's cat** (*Otocolobus manul*), also known",
+            "user\nWhat kind of dog is this?\nassistant\nThe animal in the image is not a dog. It is a **Pallas's cat** (also known as the manul), a wild f"
+        ]  # fmt: skip
+        self.assertEqual(
+            self.processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
+
+    @slow
+    def test_small_model_integration_test_batch_wo_image(self):
+        model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto"
+        )
+        message_wo_image = [
+            {"role": "user", "content": [{"type": "text", "text": "Who are you?"}]},
+        ]
+        batched_messages = [self.message, message_wo_image]
+        inputs = self.processor.apply_chat_template(
+            batched_messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+            padding=True,
+        ).to(torch_device)
+
+        # it should not matter whether two images are the same size or not
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+
+        EXPECTED_DECODED_TEXT = [
+            "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and steppes",
+            "user\nWho are you?\nassistant\nI am Qwen, a large-scale language model developed by Alibaba Cloud's Tongyi Lab. I can assist you with answering questions, creating text such"
+        ]  # fmt: skip
+        self.assertEqual(
+            self.processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
+
+    @slow
+    def test_small_model_integration_test_batch_different_resolutions(self):
+        model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen3-VL-30B-A3B-Instruct", dtype="auto", device_map="auto"
+        )
+        batched_messages = [self.message, self.message2]
+        inputs = self.processor.apply_chat_template(
+            batched_messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+            padding=True,
+        ).to(torch_device)
+
+        # it should not matter whether two images are the same size or not
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+
+        EXPECTED_DECODED_TEXT = [
+            "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and steppes",
+            "user\nWhat kind of dog is this?\nassistant\nBased on the image provided, the animals are not dogs. They are two cats.\n\nHere is a description of the animals in the image:\n\n- "
+        ]  # fmt: skip
+        self.assertEqual(
+            self.processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
+
+    @slow
+    @require_flash_attn
+    @require_torch_gpu
+    def test_small_model_integration_test_batch_flashatt2(self):
+        model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen3-VL-30B-A3B-Instruct",
+            dtype=torch.bfloat16,
+            attn_implementation="flash_attention_2",
+            device_map="auto",
+        )
+        batched_messages = [self.message, self.message2]
+        inputs = self.processor.apply_chat_template(
+            batched_messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+            padding=True,
+        ).to(torch_device)
+
+        # it should not matter whether two images are the same size or not
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+
+        EXPECTED_DECODED_TEXT = [
+            "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and montane regions",
+            "user\nWhat kind of dog is this?\nassistant\nBased on the image provided, there is no dog present. The animals in the picture are two cats.\n\nHere are some observations about the cats in the"
+        ]  # fmt: skip
+        self.assertEqual(
+            self.processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
+
+    @slow
+    @require_flash_attn
+    @require_torch_gpu
+    def test_small_model_integration_test_batch_wo_image_flashatt2(self):
+        model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen3-VL-30B-A3B-Instruct",
+            dtype=torch.bfloat16,
+            attn_implementation="flash_attention_2",
+            device_map="auto",
+        )
+        message_wo_image = [
+            {"role": "user", "content": [{"type": "text", "text": "Who are you?"}]},
+        ]
+        batched_messages = [self.message, message_wo_image]
+        inputs = self.processor.apply_chat_template(
+            batched_messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+            padding=True,
+        ).to(torch_device)
+
+        # it should not matter whether two images are the same size or not
+        output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
+
+        EXPECTED_DECODED_TEXT = [
+            "user\nWhat kind of dog is this?\nassistant\nThis is a Pallas's cat, also known as the manul. It's a wild cat species native to the grasslands and montane regions",
+            "user\nWho are you?\nassistant\nI am Qwen, a large-scale language model developed by Alibaba Cloud's Tongyi Lab. I can assist you with answering questions, creating text such"
+        ]  # fmt: skip
+
+        self.assertEqual(
+            self.processor.batch_decode(output, skip_special_tokens=True),
+            EXPECTED_DECODED_TEXT,
+        )
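Note that the new tests are all marked @slow and the whole class is skipped until the Qwen/Qwen3-VL-30B-A3B-Instruct checkpoint is released, so they do not run in the default CI suite. Once the skip is lifted, they can be exercised locally using the usual transformers slow-test switch, e.g. RUN_SLOW=1 python -m pytest tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py -k Qwen3VLMoeIntegrationTest (flash-attention variants additionally require a GPU with flash_attention_2 installed).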
