@@ -181,6 +181,24 @@ def __init__(
181181 self .error_message = None
182182 self .error_code = None
183183
184+ def _convert_mm_input_types (self , d : dict ) -> list :
185+ if (
186+ "multimodal_inputs" in d
187+ and isinstance (d ["multimodal_inputs" ], dict )
188+ and "mm_positions" in d ["multimodal_inputs" ]
189+ and isinstance (d ["multimodal_inputs" ]["mm_positions" ], list )
190+ and len (d ["multimodal_inputs" ]["mm_positions" ]) > 0
191+ and not isinstance (d ["multimodal_inputs" ]["mm_positions" ][0 ], ImagePosition )
192+ ):
193+ # if mm_positions is of type dict, convert to ImagePosition
194+ try :
195+ for i , mm_pos in enumerate (d ["multimodal_inputs" ]["mm_positions" ]):
196+ d ["multimodal_inputs" ]["mm_positions" ][i ] = ImagePosition (** mm_pos )
197+ except Exception as e :
198+ data_processor_logger .error (
199+ f"Convert mm_positions to ImagePosition error: { e } , { str (traceback .format_exc ())} "
200+ )
201+
184202 @classmethod
185203 def from_dict (cls , d : dict ):
186204 data_processor_logger .debug (f"{ d } " )
@@ -191,13 +209,12 @@ def from_dict(cls, d: dict):
191209 else :
192210 sampling_params = SamplingParams .from_dict (d )
193211 if (
194- "multimodal_inputs" in d
195- and isinstance (d ["multimodal_inputs" ], dict )
196- and "mm_positions" in d ["multimodal_inputs" ]
197- and isinstance (d ["multimodal_inputs" ]["mm_positions" ], list )
212+ isinstance (d .get ("multimodal_inputs" ), dict )
213+ and isinstance (d ["multimodal_inputs" ].get ("mm_positions" ), list )
198214 and len (d ["multimodal_inputs" ]["mm_positions" ]) > 0
199215 and not isinstance (d ["multimodal_inputs" ]["mm_positions" ][0 ], ImagePosition )
200216 ):
217+ # if mm_positions is not of type ImagePosition, convert to ImagePosition
201218 try :
202219 for i , mm_pos in enumerate (d ["multimodal_inputs" ]["mm_positions" ]):
203220 d ["multimodal_inputs" ]["mm_positions" ][i ] = ImagePosition (** mm_pos )
0 commit comments