@@ -33,6 +33,7 @@
 from torch import nn
 from torch.nn.init import trunc_normal_
 from transformers import PretrainedConfig
+from typing_extensions import NotRequired

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -52,6 +53,7 @@
 from vllm.model_executor.models.qwen2 import Qwen2Model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.image import cached_get_image_processor
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
@@ -64,6 +66,17 @@
 }


+class MiniCPMVImageInput(TypedDict):
+    """Input mapper input with auxiliary data for computing image bounds."""
+    image: Image.Image
+
+    # Image-bound token ids as 0-dim scalar tensors.
+    im_start_id: torch.Tensor
+    im_end_id: torch.Tensor
+    slice_start_id: NotRequired[torch.Tensor]
+    slice_end_id: NotRequired[torch.Tensor]
+
+
 class MiniCPMVImagePixelInputs(TypedDict):
     pixel_values: List[torch.Tensor]
     """
@@ -88,8 +101,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
     """


-MiniCPMVImageInputs = MiniCPMVImagePixelInputs
-
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)


@@ -234,6 +245,25 @@ def forward(self, x: torch.Tensor,
         return x


+def _build_image_input(ctx: InputContext,
+                       image: Image.Image) -> MiniCPMVImageInput:
+    tokenizer = cached_get_tokenizer(
+        ctx.model_config.tokenizer,
+        trust_remote_code=ctx.model_config.trust_remote_code)
+    if hasattr(tokenizer, "slice_start_id"):
+        return MiniCPMVImageInput(
+            image=image,
+            im_start_id=torch.tensor(tokenizer.im_start_id),
+            im_end_id=torch.tensor(tokenizer.im_end_id),
+            slice_start_id=torch.tensor(tokenizer.slice_start_id),
+            slice_end_id=torch.tensor(tokenizer.slice_end_id))
+    else:
+        return MiniCPMVImageInput(image=image,
+                                  im_start_id=torch.tensor(
+                                      tokenizer.im_start_id),
+                                  im_end_id=torch.tensor(tokenizer.im_end_id))
+
+
 def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
     version_float = getattr(config, "version", None)

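A hedged usage sketch of the helper above (`ctx` is assumed to be an `InputContext` built from a MiniCPM-V checkpoint; the file name is illustrative):

```python
from PIL import Image

# Wrap a PIL image together with the tokenizer's special-token ids so
# the input mapper can forward them to the model later.
image_input = _build_image_input(ctx, Image.open("demo.jpg"))

# Older MiniCPM-V tokenizers define no slice tokens, in which case the
# slice_* keys are simply absent from the returned dict.
has_slices = "slice_start_id" in image_input
```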
@@ -257,10 +287,13 @@ def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
     return SequenceData.from_token_counts((0, seq_len))


-def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int):
+def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig,
+                             num_images: int):
     width = height = hf_config.image_size
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image if num_images == 1 else [image] * num_images}
+    image = _build_image_input(ctx,
+                               image=Image.new("RGB", (width, height),
+                                               color=0))
+    return {"image": [image] if num_images == 1 else [image] * num_images}


 def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
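A quick sanity sketch of the new dummy-data shape (assuming `ctx` and `hf_config` come from a MiniCPM-V checkpoint):

```python
# Each dummy entry is now a MiniCPMVImageInput dict, not a bare PIL image.
mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images=2)
assert len(mm_data["image"]) == 2
assert mm_data["image"][0]["image"].size == (hf_config.image_size,) * 2
```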
@@ -269,7 +302,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
     num_images = mm_counts["image"]

     seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images)
-    mm_data = dummy_image_for_minicpmv(hf_config, num_images)
+    mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images)

     return seq_data, mm_data

@@ -280,8 +313,9 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs
     model_config = ctx.model_config
     version = get_version_by_config(model_config.hf_config)
-    tokenizer = cached_get_tokenizer(model_config.tokenizer,
-                                     trust_remote_code=True)
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
     image_processor = cached_get_image_processor(model_config.tokenizer)

     def get_placeholder(image_size: Tuple[int, int], num_image: int):
@@ -317,6 +351,10 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int):
     new_prompt = "".join(new_prompt_chunks)
     new_token_ids = tokenizer.encode(new_prompt)

+    multi_modal_data["image"] = [
+        _build_image_input(ctx, image) for image in images
+    ]
+
     llm_inputs = LLMInputs(
         prompt_token_ids=new_token_ids,
         prompt=new_prompt,
@@ -325,6 +363,32 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int):
     return llm_inputs


+def input_mapper_for_minicpmv(ctx: InputContext, data: object):
+    model_config = ctx.model_config
+
+    image_processor = cached_get_image_processor(
+        model_config.model, trust_remote_code=model_config.trust_remote_code)
+    if image_processor is None:
+        raise RuntimeError("No HuggingFace processor is available "
+                           "to process the image object")
+
+    if not isinstance(data, list):
+        raise ValueError(
+            f"Image input must be list of MiniCPMVImageInput, got {data}")
+    batch_data = image_processor \
+        .preprocess([img["image"] for img in data], return_tensors="pt") \
+        .data
+
+    if len(data) > 0:
+        batch_data["im_start_id"] = data[0]["im_start_id"]
+        batch_data["im_end_id"] = data[0]["im_end_id"]
+        if "slice_start_id" in data[0]:
+            batch_data["slice_start_id"] = data[0]["slice_start_id"]
+            batch_data["slice_end_id"] = data[0]["slice_end_id"]
+
+    return MultiModalInputs(batch_data)
+
+
 class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
     """
     The abstract class of MiniCPMV can only be inherited, but cannot be
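For orientation, a hedged sketch of what the mapper above produces (`ctx` and `data` are assumed to come from the processor path; key names follow the code above):

```python
# `data` is the list the input processor stored under
# multi_modal_data["image"]: one MiniCPMVImageInput per image.
inputs = input_mapper_for_minicpmv(ctx, data)

# MultiModalInputs behaves like a dict of batched tensors: the HF
# processor contributes e.g. pixel_values/tgt_sizes, and the special
# token ids are forwarded so the model can locate image spans later.
print(inputs["im_start_id"], inputs["im_end_id"])
```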
@@ -365,7 +429,7 @@ def __init__(
     def get_embedding(
         self,
         input_ids: torch.Tensor,
-        image_inputs: Optional[MiniCPMVImageInputs],
+        image_inputs: Optional[MiniCPMVImagePixelInputs],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
         if hasattr(self.config, "scale_emb"):
@@ -393,14 +457,20 @@ def get_embedding(

         return vlm_embedding, vision_hidden_states

-    def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor:
-        tokenizer = cached_get_tokenizer(self.config._name_or_path,
-                                         trust_remote_code=True)
-        start_cond = input_ids == tokenizer.im_start_id
-        end_cond = input_ids == tokenizer.im_end_id
-        if hasattr(tokenizer, "slice_start_id"):
-            start_cond |= (input_ids == tokenizer.slice_start_id)
-            end_cond |= (input_ids == tokenizer.slice_end_id)
+    def _get_image_bounds(
+            self,
+            input_ids: torch.Tensor,
+            im_start_id: torch.Tensor,
+            im_end_id: torch.Tensor,
+            slice_start_id: Optional[torch.Tensor] = None,
+            slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        # All the images in the batch should share the same special image
+        # bound token ids.
+        start_cond = input_ids == im_start_id[0]
+        end_cond = input_ids == im_end_id[0]
+        if slice_start_id is not None:
+            start_cond |= (input_ids == slice_start_id[0])
+            end_cond |= (input_ids == slice_end_id[0])

         image_start_tokens, = torch.where(start_cond)
         image_start_tokens += 1
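The bound computation itself is plain tensor masking; a self-contained sketch with made-up token ids (5 for the start token, 6 for the end token):

```python
import torch

input_ids = torch.tensor([1, 5, 9, 9, 6, 2])  # 5 = start, 6 = end (made up)
starts, = torch.where(input_ids == 5)
starts += 1  # first token *inside* the image span
ends, = torch.where(input_ids == 6)
print(torch.hstack([starts[:, None], ends[:, None]]))  # tensor([[2, 4]])
```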
@@ -419,7 +489,7 @@ def _parse_and_validate_inputs(
         self,
         input_ids: torch.Tensor,
         **kwargs: object,
-    ) -> Optional[MiniCPMVImageInputs]:
+    ) -> Optional[MiniCPMVImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", [])
         tgt_sizes = kwargs.pop("tgt_sizes", [])

@@ -456,8 +526,17 @@ def _parse_and_validate_inputs(
         if len(pixel_values_flat) == 0:
             return None

-        return MiniCPMVImageInputs(
-            image_bounds=self._get_image_bounds(input_ids),
+        im_start_id = kwargs.pop("im_start_id", None)
+        im_end_id = kwargs.pop("im_end_id", None)
+        slice_start_id = kwargs.pop("slice_start_id", None)
+        slice_end_id = kwargs.pop("slice_end_id", None)
+        if im_start_id is None:
+            return None
+
+        return MiniCPMVImagePixelInputs(
+            image_bounds=self._get_image_bounds(input_ids, im_start_id,
+                                                im_end_id, slice_start_id,
+                                                slice_end_id),
             pixel_values=pixel_values_flat,
             tgt_sizes=torch.stack(tgt_sizes_flat),
         )
@@ -564,8 +643,8 @@ def get_vision_embedding(
     ) -> torch.Tensor:
         raise NotImplementedError

-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         raise NotImplementedError

     def is_default_weight_loading(self, name: str) -> bool:
@@ -654,8 +733,8 @@ def get_vision_embedding(
             res.append(self.resampler(vision_embedding, tgt_size))
         return torch.vstack(res)

-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         pixel_values = data["pixel_values"]

         return self.get_vision_embedding(pixel_values)
@@ -713,8 +792,8 @@ def get_vision_embedding(
         vision_embedding = self.resampler(vision_embedding, tgt_sizes)
         return vision_embedding

-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         pixel_values = data["pixel_values"]
         tgt_sizes = data["tgt_sizes"]

@@ -807,8 +886,8 @@ def get_vision_embedding(
         ).last_hidden_state
         return vision_embedding

-    def get_vision_hidden_states(self,
-                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
         pixel_values = data["pixel_values"]
         tgt_sizes = data["tgt_sizes"]

@@ -851,7 +930,7 @@ def is_default_weight_loading(self, name: str) -> bool:
 }


-@MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_minicpmv)
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
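With the custom mapper registered, inference flows through the normal multimodal path; a hedged end-to-end sketch (checkpoint name, prompt format, and file name are illustrative):

```python
from PIL import Image
from vllm import LLM

llm = LLM(model="openbmb/MiniCPM-V-2_6", trust_remote_code=True)
outputs = llm.generate({
    "prompt": "(<image>./</image>)\nWhat is in this picture?",
    "multi_modal_data": {"image": Image.open("demo.jpg")},
})
print(outputs[0].outputs[0].text)
```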