@@ -3,8 +3,8 @@
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
-                    TypedDict, TypeVar, Union, cast)
+from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
+                    TypeVar, Union, cast)
 
 import torch
 import torch.nn as nn
@@ -39,8 +39,7 @@
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .pixtral import (PixtralHFVisionModel,
-                      get_pixtral_hf_image_feature_grid_size)
+from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -49,31 +48,42 @@
 
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
+    pixel_values: torch.Tensor
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
 
     Note that `height` or `width` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """
 
-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+
+class PixtralHFImagePixelInputs(TypedDict):
+    type: Literal["pixel_values_pixtral"]
+    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, height, width)`
+
+    Note that `height` or `width` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """
+
+    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image features correspond
     to patch tokens.
 
     Shape: `(batch_size, num_crops, num_patch)`
     """
 
-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
 
     Shape: `(batch_size, num_embeds)`
     """
 
-    num_crops: torch.Tensor
+    num_crops: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_images)`"""
 
 
@@ -85,27 +95,9 @@ class LlavaImageEmbeddingInputs(TypedDict):
     `hidden_size` must match the hidden size of language model backbone.
     """
 
-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    A boolean mask indicating which image features correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_crops, num_patch)`
-    """
-
-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    A boolean mask indicating which image embeddings correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_embeds)`
-    """
-
-    num_crops: torch.Tensor
-    """Shape: `(batch_size, num_images)`"""
-
 
-LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
+LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
+                         LlavaImageEmbeddingInputs]
 
 
 class LlavaMultiModalProjector(nn.Module):
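The `type` field is the discriminant of the `LlavaImageInputs` union: each variant carries a distinct `Literal` tag, so both runtime branches and static type checkers can tell the inputs apart without probing for optional keys. A minimal sketch of constructing the new variant (shapes and values here are illustrative, not taken from a real model run):

import torch

# Hypothetical single image split into 2 crops of 16 patches each.
example = PixtralHFImagePixelInputs(
    type="pixel_values_pixtral",
    pixel_values=[torch.randn(3, 64, 128), torch.randn(3, 64, 64)],
    feat_is_patch=torch.ones(1, 2, 16, dtype=torch.bool),
    embed_is_patch=torch.ones(1, 32, dtype=torch.bool),
    num_crops=torch.tensor([[2]]),
)
assert example["type"] == "pixel_values_pixtral"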
@@ -357,13 +349,15 @@ def _call_hf_processor(
         ]
 
         hf_config = self.info.get_hf_config()
+        vision_config = hf_config.vision_config
+        assert isinstance(vision_config, PixtralVisionConfig)
+        encoder_info = PixtralHFEncoderInfo(vision_config)
 
         tile_sizes = [
-            get_pixtral_hf_image_feature_grid_size(
-                hf_config.vision_config,
+            encoder_info.get_patch_grid_size(
                 image_width=pixel_value.shape[-1],
-                image_height=pixel_value.shape[-2])
-            for pixel_value in processed_outputs["pixel_values"]
+                image_height=pixel_value.shape[-2],
+            ) for pixel_value in processed_outputs["pixel_values"]
         ]
         num_crops = torch.tensor([(ncols + 1) * nrows
                                   for ncols, nrows in tile_sizes])
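As a sanity check on the `(ncols + 1) * nrows` expression above: Pixtral-HF emits one break token after each row of image patches, so a patch grid with `ncols` columns and `nrows` rows expands to `(ncols + 1) * nrows` feature slots per image. A hedged sketch, assuming the processor has already resized the image to multiples of the patch size (the concrete numbers are made up):

patch_size = 16                       # assumed vision_config.patch_size
image_width, image_height = 512, 256  # post-processing size, illustrative
ncols = image_width // patch_size     # 32 patch columns
nrows = image_height // patch_size    # 16 patch rows
# Each row contributes ncols patch tokens plus one row-break token.
tokens_per_image = (ncols + 1) * nrows  # (32 + 1) * 16 = 528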
@@ -411,13 +405,13 @@ def _get_prompt_updates(
 
         vision_config = hf_config.vision_config
         assert isinstance(vision_config, PixtralVisionConfig)
+        encoder_info = PixtralHFEncoderInfo(vision_config)
 
         def get_replacement(item_idx: int):
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)
 
-            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
-                vision_config,
+            ncols, nrows = encoder_info.get_patch_grid_size(
                 image_width=image_size.width,
                 image_height=image_size.height,
             )
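For context, `get_replacement` uses this grid to expand the image placeholder in the prompt: each of the `nrows` rows becomes `ncols` image tokens followed by a break token, and the trailing break is swapped for an end token. A sketch under that assumption, with hypothetical token ids:

def build_replacement(ncols: int, nrows: int, image_id: int,
                      break_id: int, end_id: int) -> list[int]:
    # One row = ncols patch tokens plus a row-break token.
    tokens = ([image_id] * ncols + [break_id]) * nrows
    tokens[-1] = end_id  # the final break becomes the image-end token
    return tokens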
@@ -512,7 +506,7 @@ def init_vision_tower_for_llava(
     *,
     require_post_norm: Optional[bool] = None,
     prefix: str = "",
-):
+) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
     vision_config = hf_config.vision_config
 
     # Initialize the vision tower only up to the deepest required feature layer
@@ -627,57 +621,52 @@ def _parse_and_validate_image_input(
         if pixel_values is None and image_embeds is None:
             return None
 
-        feat_is_patch = kwargs.pop("feat_is_patch", None)
-        if feat_is_patch is not None and not isinstance(
-                feat_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of feat_is_patch. "
-                             f"Got type: {type(feat_is_patch)}")
-
-        embed_is_patch = kwargs.pop("embed_is_patch", None)
-        if embed_is_patch is not None and not isinstance(
-                embed_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of embed_is_patch. "
-                             f"Got type: {type(embed_is_patch)}")
-
-        num_crops = kwargs.pop("num_crops", None)
-        if num_crops is not None and not isinstance(num_crops, torch.Tensor):
-            raise ValueError("Incorrect type of num_crops. "
-                             f"Got type: {type(num_crops)}")
-
         if pixel_values is not None:
             if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
             if self.config.vision_config.model_type == "pixtral":
-                return LlavaImagePixelInputs(
-                    type="pixel_values",
-                    data=flatten_bn(pixel_values),
+                feat_is_patch = kwargs.pop("feat_is_patch")
+                if not isinstance(feat_is_patch, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of feat_is_patch. "
+                                     f"Got type: {type(feat_is_patch)}")
+
+                embed_is_patch = kwargs.pop("embed_is_patch")
+                if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of embed_is_patch. "
+                                     f"Got type: {type(embed_is_patch)}")
+
+                num_crops = kwargs.pop("num_crops")
+                if not isinstance(num_crops, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of num_crops. "
+                                     f"Got type: {type(num_crops)}")
+
+                return PixtralHFImagePixelInputs(
+                    type="pixel_values_pixtral",
+                    pixel_values=flatten_bn(pixel_values),
                     feat_is_patch=feat_is_patch,
                     embed_is_patch=embed_is_patch,
                     num_crops=num_crops,
                 )
 
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(
+                pixel_values=self._validate_pixel_values(
                     flatten_bn(pixel_values, concat=True)),
-                feat_is_patch=feat_is_patch,
-                embed_is_patch=embed_is_patch,
-                num_crops=num_crops,
             )
 
         if image_embeds is not None:
             if not isinstance(image_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
 
+            if self.config.vision_config.model_type == "pixtral":
+                raise ValueError("Pixtral-HF does not support image_embeds.")
+
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
                 data=flatten_bn(image_embeds, concat=True),
-                feat_is_patch=feat_is_patch,
-                embed_is_patch=embed_is_patch,
-                num_crops=num_crops,
             )
 
         raise AssertionError("This line should be unreachable.")
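Two behavioral notes on the hunk above: the Pixtral-specific kwargs are now popped without a default, so a missing key fails fast with KeyError instead of silently becoming None, and the checks run only on the pixtral branch, so other vision towers no longer consume those kwargs. The repeated pop-and-validate pattern could be captured by a helper like the following (hypothetical, not part of the diff):

def _pop_required(kwargs: dict, name: str):
    value = kwargs.pop(name)  # raises KeyError if the kwarg is absent
    if not isinstance(value, (torch.Tensor, list)):
        raise ValueError(f"Incorrect type of {name}. "
                         f"Got type: {type(value)}")
    return value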
@@ -696,7 +685,7 @@ def _image_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                             PixtralHFVisionModel],
-        pixel_values: torch.Tensor,
+        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
     ) -> torch.Tensor:
 
         # NOTE: we skip the step to select the vision feature layer since
@@ -708,17 +697,20 @@ def _image_pixels_to_features(
             strategy=self.config.vision_feature_select_strategy,
         )
 
-    def _process_image_pixels(self,
-                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
+    def _process_image_pixels(
+        self,
+        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
+    ) -> torch.Tensor:
         assert self.vision_tower is not None
 
-        pixel_values = inputs["data"]
+        pixel_values = inputs["pixel_values"]
 
         return self._image_pixels_to_features(self.vision_tower, pixel_values)
 
-    def _process_image_input(self,
-                             image_input: LlavaImageInputs) -> torch.Tensor:
-
+    def _process_image_input(
+        self,
+        image_input: LlavaImageInputs,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
@@ -783,11 +775,11 @@ def get_multimodal_embeddings(
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
+
         vision_embeddings = self._process_image_input(image_input)
 
-        if kwargs.get("v0_path", False) or \
-            image_input.get("feat_is_patch") is None or \
-            image_input.get("embed_is_patch") is None:
+        if (kwargs.get("v0_path", False)
+                or image_input["type"] != "pixel_values_pixtral"):
             # The path is used for pixtral (V0 only) and llava (V0/V1)
             return vision_embeddings
 
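The rewritten condition relies on the `type` tag rather than probing for `feat_is_patch`/`embed_is_patch`, which is equivalent now that only `PixtralHFImagePixelInputs` carries those masks. Reduced to a standalone predicate (illustrative only):

def takes_direct_path(v0_path: bool, input_type: str) -> bool:
    # V0 always returns the raw vision embeddings; V1 does so for
    # every input type except the Pixtral-HF pixel inputs.
    return v0_path or input_type != "pixel_values_pixtral"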