@@ -30,7 +30,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 
 IMAGE_TOKEN = "<image>"
 VIDEO_TOKEN = "<video>"
@@ -70,13 +70,15 @@ def __init__(
         visual_vocab_size: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         super().__init__()
         self.config = config
         self.vit = self._init_backbone(
             config=config,
             quant_config=quant_config,
             prefix=f"{prefix}.vit",
+            use_data_parallel=use_data_parallel,
         )
         # reserved tokens for INDICATOR_IDS
         head_dim = visual_vocab_size - len(INDICATOR_IDS)
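
Note: the new `use_data_parallel` flag follows vLLM's data-parallel vision-encoder pattern, where each rank keeps a full copy of the ViT and encodes a slice of the images instead of tensor-sharding the ViT weights. A minimal, vLLM-agnostic sketch of that idea (function name and gather step are illustrative, not vLLM's actual API):

```python
import torch
import torch.nn as nn


def encode_images_data_parallel(vit: nn.Module, pixel_values: torch.Tensor,
                                rank: int, world_size: int) -> torch.Tensor:
    # Each rank holds the full ViT and encodes an even slice of the image
    # batch; in a real setup the per-rank outputs are all-gathered afterwards.
    shard = pixel_values.chunk(world_size, dim=0)[rank]
    return vit(shard)
```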
@@ -93,39 +95,42 @@ def _init_backbone(
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_data_parallel: bool = False,
     ):
         model_type = config.model_type
         if model_type == "siglip2_navit":
-            return Siglip2NavitModel(config=config, )
+            return Siglip2NavitModel(config=config,
+                                     quant_config=quant_config,
+                                     prefix=prefix,
+                                     use_data_parallel=use_data_parallel)
         raise ValueError(
             f"Unsupported visual tokenizer model_type: {model_type}")
 
     @property
-    def dtype(self):
+    def dtype(self) -> torch.dtype:
         return next(self.head.parameters()).dtype
 
     @property
-    def device(self):
+    def device(self) -> torch.device:
         return next(self.head.parameters()).device
 
-    def tokenize(self, logits):
+    def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
         tokens = torch.softmax(logits, dim=-1,
                                dtype=torch.float32).to(logits.dtype)
         return tokens
 
-    def encode(self, pixel_values, grid_thws):
-        features = self.vit(pixel_values,
-                            grid_thws,
-                            output_hidden_states=True,
-                            return_dict=True)
+    def encode(self, pixel_values: torch.Tensor,
+               grid_thws: torch.Tensor) -> torch.Tensor:
+        features = self.vit(pixel_values, grid_thws)
         # refer to qwen2.5-vl patchmerger
         seq_len, _ = features.shape
         features = features.reshape(seq_len // (self.config.hidden_stride**2),
                                     -1)
 
         return features
 
-    def forward(self, pixel_values, grid_thws) -> torch.Tensor:
+    def forward(self, pixel_values: torch.Tensor,
+                grid_thws: torch.Tensor) -> torch.Tensor:
         features = self.encode(pixel_values, grid_thws)
         logits = self.head(features)
         tokens = self.tokenize(logits)
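
For context on the methods typed above: `tokenize` turns ViT logits into soft probability distributions over the visual vocabulary (softmax computed in float32 for numerical stability, then cast back to the input dtype), and `encode` folds `hidden_stride**2` spatially adjacent patch features into one row, as in the Qwen2.5-VL patch merger the comment references. A standalone sketch with assumed toy shapes:

```python
import torch

seq_len, visual_vocab_size, hidden_stride = 8, 16, 2
logits = torch.randn(seq_len, visual_vocab_size, dtype=torch.bfloat16)

# tokenize(): softmax in float32, cast back -- each row sums to ~1.
tokens = torch.softmax(logits, dim=-1, dtype=torch.float32).to(logits.dtype)

# encode()'s patch-merger reshape: hidden_stride**2 patches merge per row.
features = torch.randn(seq_len, 32)
merged = features.reshape(seq_len // hidden_stride**2, -1)
print(merged.shape)  # torch.Size([2, 128])
```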
@@ -395,7 +400,7 @@ def get_replacement_ovis(item_idx, modality: str):
 @MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor,
                                         info=Ovis2_5ProcessingInfo,
                                         dummy_inputs=Ovis2_5DummyInputsBuilder)
-class Ovis2_5(nn.Module, SupportsMultiModal):
+class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -421,9 +426,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         text_model_type = self.config.get_text_config().model_type
         self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
 
-        # TODO(Isotr0py): PP support
-        # self.make_empty_intermediate_tensors = (
-        #     self.language_model.make_empty_intermediate_tensors)
+        self.make_empty_intermediate_tensors = (
+            self.get_language_model().make_empty_intermediate_tensors)
 
     def _parse_and_validate_visual_input(
         self, is_video,
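
The delegation above is the usual vLLM pipeline-parallel hookup: on non-first PP ranks the engine feeds placeholder intermediate tensors (hidden states plus residual) instead of token embeddings, and the multimodal wrapper simply reuses its language model's factory for allocating them. A rough, illustrative sketch of the shape of such a factory (hypothetical hidden size and return type, not vLLM's exact signature):

```python
import torch


def make_empty_intermediate_tensors(
        batch_size: int, dtype: torch.dtype,
        device: torch.device) -> dict[str, torch.Tensor]:
    # Hypothetical hidden size; real factories capture it from the model
    # config so the placeholders match the decoder's hidden dimension.
    hidden_size = 4096
    return {
        key: torch.zeros(batch_size, hidden_size, dtype=dtype, device=device)
        for key in ("hidden_states", "residual")
    }
```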
@@ -567,4 +571,4 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loader.load_weights(weights)
 
     def get_language_model(self) -> torch.nn.Module:
-        return self.llm
+        return self.llm