@@ -4,7 +4,7 @@
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
-                    TypedDict, TypeVar, Union)
+                    TypedDict, TypeVar, Union, cast)

import torch
import torch.nn as nn
@@ -35,6 +35,7 @@
                                        PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
+from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves

from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
@@ -56,6 +57,25 @@ class LlavaImagePixelInputs(TypedDict):
    in which case the data is passed as a list instead of a batched tensor.
    """

+    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    A boolean mask indicating which image features correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_crops, num_patch)`
+    """
+
+    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_embeds)`
+    """
+
+    num_crops: torch.Tensor
+    """Shape: `(batch_size, num_images)`"""
+

class LlavaImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
@@ -65,6 +85,25 @@ class LlavaImageEmbeddingInputs(TypedDict):
    `hidden_size` must match the hidden size of language model backbone.
    """

+    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    A boolean mask indicating which image features correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_crops, num_patch)`
+    """
+
+    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_embeds)`
+    """
+
+    num_crops: torch.Tensor
+    """Shape: `(batch_size, num_images)`"""
+

LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
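To make the new mask fields concrete, here is an illustrative snippet (not part of this change, with made-up sizes) showing how a boolean mask such as `embed_is_patch` is meant to be consumed: it selects the rows of an embedding tensor that correspond to actual image patches.

```python
import torch

num_embeds, hidden_size = 8, 4          # toy sizes for a single image
embeds = torch.randn(num_embeds, hidden_size)

# 3 patch slots followed by one non-patch slot (e.g. an image-break token),
# repeated for each row of the feature grid.
embed_is_patch = torch.tensor([True, True, True, False,
                               True, True, True, False])

patch_embeds = embeds[embed_is_patch]   # keeps only the patch rows
assert patch_embeds.shape == (6, hidden_size)
```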
@@ -317,14 +356,40 @@ def _call_hf_processor(
                for p, (h, w) in zip(pixel_values, image_sizes)
            ]

+            hf_config = self.info.get_hf_config()
+
+            tile_sizes = [
+                get_pixtral_hf_image_feature_grid_size(
+                    hf_config.vision_config,
+                    image_width=pixel_value.shape[-1],
+                    image_height=pixel_value.shape[-2])
+                for pixel_value in processed_outputs["pixel_values"]
+            ]
+            num_crops = torch.tensor([(ncols + 1) * nrows
+                                      for ncols, nrows in tile_sizes])
+            # Each image may result in masks of different sizes, so we need to
+            # flatten the list and later use `num_crops` to get per-image masks.
+            embed_is_patch = torch.tensor(
+                flatten_2d_lists([([True] * ncols + [False]) * nrows
+                                  for ncols, nrows in tile_sizes]))
+            processed_outputs["num_crops"] = num_crops
+            processed_outputs["embed_is_patch"] = embed_is_patch
+            processed_outputs["feat_is_patch"] = embed_is_patch
+
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
+        num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
        return dict(
+            feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_crops),
+            embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_crops),
+            num_crops=MultiModalFieldConfig.batched("image"),
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )
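For intuition about the expressions above (a worked example mirroring the diff, with an assumed grid size, not part of the change): each row of an image's feature grid contributes `ncols` patch positions plus one non-patch position, so one image produces `(ncols + 1) * nrows` embedding slots, and concatenated per-image masks can be split back apart using `num_crops`.

```python
import torch

ncols, nrows = 3, 2                                   # assumed 3x2 feature grid

num_crops = (ncols + 1) * nrows                       # 8 embedding slots
embed_is_patch = ([True] * ncols + [False]) * nrows   # row-wise mask pattern
# embed_is_patch == [True, True, True, False, True, True, True, False]

# With two such images, the masks are stored flat and recovered per image by
# splitting on `num_crops`, which is roughly what `flat_from_sizes` expresses.
flat_mask = torch.tensor(embed_is_patch * 2)
per_image = flat_mask.split([num_crops, num_crops])
assert all(mask.shape[0] == num_crops for mask in per_image)
```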
@@ -562,6 +627,23 @@ def _parse_and_validate_image_input(
        if pixel_values is None and image_embeds is None:
            return None

+        feat_is_patch = kwargs.pop("feat_is_patch", None)
+        if feat_is_patch is not None and not isinstance(
+                feat_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of feat_is_patch. "
+                             f"Got type: {type(feat_is_patch)}")
+
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
+        if embed_is_patch is not None and not isinstance(
+                embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of embed_is_patch. "
+                             f"Got type: {type(embed_is_patch)}")
+
+        num_crops = kwargs.pop("num_crops", None)
+        if num_crops is not None and not isinstance(num_crops, torch.Tensor):
+            raise ValueError("Incorrect type of num_crops. "
+                             f"Got type: {type(num_crops)}")
+
        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
@@ -571,12 +653,18 @@ def _parse_and_validate_image_input(
                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=flatten_bn(pixel_values),
+                    feat_is_patch=feat_is_patch,
+                    embed_is_patch=embed_is_patch,
+                    num_crops=num_crops,
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
+                feat_is_patch=feat_is_patch,
+                embed_is_patch=embed_is_patch,
+                num_crops=num_crops,
            )

        if image_embeds is not None:
@@ -587,6 +675,9 @@ def _parse_and_validate_image_input(
            return LlavaImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
+                feat_is_patch=feat_is_patch,
+                embed_is_patch=embed_is_patch,
+                num_crops=num_crops,
            )

        raise AssertionError("This line should be unreachable.")
@@ -633,16 +724,74 @@ def _process_image_input(self,

        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)
-        return self.multi_modal_projector(image_features)

-    def get_multimodal_embeddings(
-            self, **kwargs
-    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
+        if isinstance(image_features, torch.Tensor):
+            return self.multi_modal_projector(image_features)
+
+        feature_sizes = [
+            image_feature.shape[0] for image_feature in image_features
+        ]
+
+        image_embeds = self.multi_modal_projector(torch.cat(image_features))
+        image_embeds = torch.split(image_embeds, feature_sizes)
+        return image_embeds
+
+    def _get_mm_embeds(
+        self,
+        features: torch.Tensor,  # Shape: (num_crop, num_patch, d)
+        feat_is_patch: torch.Tensor,  # Shape: (num_crop, num_patch)
+        num_crops: torch.Tensor,  # Shape: (num_images,)
+        embed_is_patch: torch.Tensor,  # Shape: (num_embeds,)
+    ) -> list[torch.Tensor]:
+        """Scatter the patch features into a contiguous tensor that corresponds
+        to the embedding tokens defined by the multimodal processor.
+
+        Mostly copied from `Molmo._get_mm_embeds`. See the FIXME comment below.
+        """
+
+        # Insert columns of nan values according to `feat_is_patch`. Ideally
+        # this should be done in `_process_image_input`, but that method is
+        # used by both the V0 and V1 paths, so it is safer to put the logic
+        # here.
+        # FIXME: Move this logic to `_process_image_input` when V0 is
+        # deprecated. Merge this function with `Molmo._get_mm_embeds`.
+        feat_is_patch = feat_is_patch.view(-1)
+        embed_is_patch = embed_is_patch.view(-1)
+        expanded_embedding = torch.full(
+            (sum(num_crops), *features.shape[1:]),
+            torch.nan,
+            dtype=features.dtype).to(features.device)
+        expanded_embedding[feat_is_patch] = features
+
+        num_crops_per_image = num_crops.tolist()
+        feats_per_image = expanded_embedding.split(num_crops_per_image)
+        f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)
+
+        embed_dim = expanded_embedding.shape[-1]
+        num_embeds = embed_is_patch.shape[0]
+
+        embeds_in_batch = list[torch.Tensor]()
+        for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
+            embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
+            embeds[embed_is_patch] = feats[f_is_patch]
+            embeds_in_batch.append(embeds)
+
+        return embeds_in_batch
+
+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
+        if kwargs.get("v0_path", False):
+            return vision_embeddings
+        else:
+            nested_emb = [
+                self._get_mm_embeds(*args) for args in zip(
+                    vision_embeddings, image_input["feat_is_patch"],
+                    image_input["num_crops"], image_input["embed_is_patch"])
+            ]
+            return flatten_2d_lists(nested_emb)

    def get_input_embeddings(
        self,
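The key idea in `_get_mm_embeds` is to use NaN as a placeholder: patch features are scattered into a buffer with one row per embedding slot, leaving non-patch slots as NaN so a later step can recover the patch rows with a simple mask. A self-contained sketch of that idea with toy sizes (plain PyTorch, not the PR's code):

```python
import torch

embed_dim = 4
# One image with 8 embedding slots, of which 6 correspond to patches.
embed_is_patch = torch.tensor([True, True, True, False,
                               True, True, True, False])
patch_feats = torch.randn(int(embed_is_patch.sum()), embed_dim)

# Scatter the patch features into the full-length buffer, NaN elsewhere.
embeds = patch_feats.new_full((embed_is_patch.shape[0], embed_dim), torch.nan)
embeds[embed_is_patch] = patch_feats

# Downstream consumers can drop the placeholder rows again.
recovered = embeds[~embeds.isnan().any(dim=-1)]
assert torch.equal(recovered, patch_feats)
```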
@@ -651,8 +800,15 @@ def get_input_embeddings(
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
+            # Extract the patch tokens
+            patch_embeddings = json_map_leaves(
+                lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
+                cast(JSONTree[torch.Tensor], multimodal_embeddings),
+            )
+
            inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
+                input_ids, inputs_embeds, cast(NestedTensors,
+                                               patch_embeddings),
                self.config.image_token_index)
        return inputs_embeds
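For reference, the lambda passed to `json_map_leaves` above removes the NaN placeholder rows and reshapes the remaining values back to `(num_patches, hidden_size)`. A minimal illustration with assumed sizes:

```python
import torch

hidden_size = 4
embeds = torch.full((8, hidden_size), torch.nan)
embeds[[0, 1, 2, 4, 5, 6]] = torch.randn(6, hidden_size)  # 6 patch rows

# Same expression as in the lambda: elementwise NaN mask, then reshape.
patch_embeddings = embeds[~embeds.isnan()].view(-1, *embeds.shape[1:])
assert patch_embeddings.shape == (6, hidden_size)
```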
@@ -705,6 +861,7 @@ def forward(
        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
+            kwargs.update({"v0_path": True})
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)