@@ -25,7 +25,7 @@
 
 from .blip import BlipVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, init_vllm_registered_model,
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
 
 # We use this internally as placeholders since there is no image token
@@ -565,25 +565,23 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            if not isinstance(pixel_values, torch.Tensor):
+            if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
-            # Remove the N dimension until multiple images are supported.
-            pixel_values = pixel_values.squeeze(1)
+            pixel_values = flatten_bn(pixel_values, concat=True)
 
             return Blip2ImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
             )
 
         if image_embeds is not None:
-            if not isinstance(image_embeds, torch.Tensor):
+            if not isinstance(image_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
 
-            # Remove the N dimension until multiple images are supported.
-            image_embeds = image_embeds.squeeze(1)
+            image_embeds = flatten_bn(image_embeds, concat=True)
 
             return Blip2ImageEmbeddingInputs(
                 type="image_embeds",
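
For context on the change: the old code assumed exactly one image per prompt (a batched tensor of shape (B, 1, ...)) and stripped that dimension with squeeze(1); the new code accepts either a batched tensor or a list of per-prompt tensors and collapses the batch/image dimensions with the flatten_bn helper imported above. The sketch below is an illustrative stand-in for that behavior, not the actual vLLM implementation; the name flatten_bn_sketch and the example shapes are assumptions made for illustration.

from typing import Union

import torch


def flatten_bn_sketch(
    x: Union[torch.Tensor, list[torch.Tensor]],
    *,
    concat: bool = False,
) -> Union[torch.Tensor, list[torch.Tensor]]:
    """Illustrative stand-in for the flatten_bn utility (assumed behavior).

    Collapses the leading batch (B) and per-prompt image (N) dimensions:
    - a batched tensor of shape (B, N, ...) becomes (B * N, ...);
    - a list of per-prompt tensors of shape (N_i, ...) is concatenated
      into a single (sum N_i, ...) tensor when concat=True.
    """
    if isinstance(x, torch.Tensor):
        # Merge the first two dimensions of a uniformly batched input.
        return x.flatten(0, 1)
    if concat:
        # Variable number of images per prompt: stack everything along dim 0.
        return torch.cat(x)
    # Without concat, return the individual per-image tensors as a flat list.
    return [t for item in x for t in item]


# With one image per prompt this matches the old squeeze(1) path,
# but it also handles prompts that carry different numbers of images.
single = torch.randn(2, 1, 3, 224, 224)              # (B=2, N=1, C, H, W)
multi = [torch.randn(1, 3, 224, 224),                # prompt 0: one image
         torch.randn(3, 3, 224, 224)]                # prompt 1: three images
print(flatten_bn_sketch(single).shape)               # torch.Size([2, 3, 224, 224])
print(flatten_bn_sketch(multi, concat=True).shape)   # torch.Size([4, 3, 224, 224])

This is also why the isinstance checks in the diff are widened to accept lists: a multi-image request arrives as a list of per-prompt tensors rather than a single uniformly shaped batch.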