@@ -81,22 +81,26 @@ def list_of_dict2dict_of_list(
     return {key: [dict_item[key] for dict_item in list_of_dicts] for key in keys}


+def is_multi_modal_key(key: str) -> bool:
+    # Any key matching: multi_modal_input*
+    return key.startswith("multi_modal_input")
+
+
 def pad_sequences_to_tensors(
     sequence_list: list[dict[str, Any]], pad_value: float = 0.0
 ) -> dict[str, Any]:
     if not sequence_list:
         return {}
-    skip_keys = {"multi_modal_input"}
     max_length = max(
         len(seq)
         for item in sequence_list
         for key, seq in item.items()
-        if key not in skip_keys
+        if not is_multi_modal_key(key)
     )
     result = {}
     for key in sequence_list[0].keys():
         padded = []
-        if key == "multi_modal_input":
+        if is_multi_modal_key(key):
             for i in range(len(sequence_list)):
                 if sequence_list[i][key]:
                     item = sequence_list[i][key][0]
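The new helper replaces the hard-coded `skip_keys` set with a prefix match, so any number of `multi_modal_input*` keys can ride along with the tensor keys. A minimal sketch of the matching behavior; the suffixed key name below is hypothetical, only the `multi_modal_input` prefix is fixed by the helper:

```python
def is_multi_modal_key(key: str) -> bool:
    return key.startswith("multi_modal_input")

# "multi_modal_input" itself and any suffixed variant match the prefix;
# ordinary tensor keys do not.
assert is_multi_modal_key("multi_modal_input")
assert is_multi_modal_key("multi_modal_input_video")  # hypothetical suffixed key
assert not is_multi_modal_key("input_ids")
assert not is_multi_modal_key("attention_mask")
```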
@@ -118,11 +122,20 @@ def pad_sequences_to_tensors(
             padded.append(padded_x)
         result[key] = torch.stack(padded)
     attention_mask = [
-        [1] * len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
+        [1]
+        * len(
+            next(iter(item[key] for key in item.keys() if not is_multi_modal_key(key)))
+        )
         + [0]
         * (
             max_length
-            - len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
+            - len(
+                next(
+                    iter(
+                        item[key] for key in item.keys() if not is_multi_modal_key(key)
+                    )
+                )
+            )
         )
         for item in sequence_list
     ]
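To see what the reflowed comprehension computes: for each item it takes the length of the first non-multimodal sequence, emits that many ones, then pads with zeros up to `max_length`, ignoring the multimodal entries entirely. A small self-contained sketch, assuming `is_multi_modal_key` from this PR and toy sequences:

```python
def is_multi_modal_key(key: str) -> bool:
    return key.startswith("multi_modal_input")

sequence_list = [
    {"input_ids": [1, 2, 3], "multi_modal_input": [{"pixel_values": None}]},
    {"input_ids": [4, 5], "multi_modal_input": [{}]},
]
max_length = 3  # longest non-multimodal sequence

# Mask lengths come from the first key that is not a multi_modal_input* key.
attention_mask = [
    [1] * len(next(iter(item[key] for key in item if not is_multi_modal_key(key))))
    + [0]
    * (
        max_length
        - len(next(iter(item[key] for key in item if not is_multi_modal_key(key))))
    )
    for item in sequence_list
]
assert attention_mask == [[1, 1, 1], [1, 1, 0]]
```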
@@ -163,31 +176,21 @@ def concat_padded_tensors(
     max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts])
     result = {}

-    has_any_multi_modal = any("multi_modal_input" in td for td in tensor_dicts)
-
-    merged_multi_modal = None
-
-    if has_any_multi_modal:
+    multimodal_keys = {
+        key for td in tensor_dicts for key in td if is_multi_modal_key(key)
+    }
+    # Merge multimodal keys
+    for mm_key in multimodal_keys:
         merged_multi_modal = []
-
-        # Merge multi-modal data maintaining per-dp correspondence
-        for tensor_dict in tensor_dicts:
-            td_batch_size = get_batch_size(tensor_dict)
-
-            if "multi_modal_input" in tensor_dict:
-                # Has multi_modal_input - extend the lists
-                multi_modal = tensor_dict["multi_modal_input"]
-            else:
-                multi_modal = [{} for _ in range(td_batch_size)]
-
-            merged_multi_modal.extend(multi_modal)
-
-        result["multi_modal_input"] = merged_multi_modal
+        for td in tensor_dicts:
+            bs = get_batch_size(td)
+            merged_multi_modal.extend(td.get(mm_key, [{} for _ in range(bs)]))
+        result[mm_key] = merged_multi_modal

     # Process each key
     for key in tensor_dicts[0].keys():
         tensors_to_concat = []
-        if key == "multi_modal_input":
+        if is_multi_modal_key(key):
             continue
         for tensor_dict in tensor_dicts:
             tensor = tensor_dict[key]
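The rewritten merge handles dicts that lack a given multimodal key by padding with empty dicts, one per sample, so positional correspondence with the concatenated tensors is preserved. A hedged sketch of just that behavior; `get_batch_size` is stubbed here to read the `attention_mask` batch dimension, which is an assumption about its contract:

```python
import torch

def is_multi_modal_key(key: str) -> bool:
    return key.startswith("multi_modal_input")

def get_batch_size(td):  # stub: assumes batch size == attention_mask.shape[0]
    return td["attention_mask"].shape[0]

tensor_dicts = [
    {
        "attention_mask": torch.ones(2, 4),
        "multi_modal_input": [{"pixel_values": 0}, {"pixel_values": 1}],
    },
    {"attention_mask": torch.ones(3, 4)},  # no multimodal data in this dict
]

multimodal_keys = {key for td in tensor_dicts for key in td if is_multi_modal_key(key)}
result = {}
for mm_key in multimodal_keys:
    merged = []
    for td in tensor_dicts:
        bs = get_batch_size(td)
        # Missing keys contribute one empty dict per sample as placeholders.
        merged.extend(td.get(mm_key, [{} for _ in range(bs)]))
    result[mm_key] = merged

assert len(result["multi_modal_input"]) == 5
assert result["multi_modal_input"][2:] == [{}, {}, {}]
```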
@@ -444,11 +447,14 @@ def split_padded_tensor_dict_into_mb_list(
         .numpy()
     )

+    # check for multimodal input data
+    multimodal_keys = {key for key in data if is_multi_modal_key(key)}
+
     # check tensor shape, split only 1d tensors with length "total_lens"
     to_split = {}
     not_to_split = {}
     for key, value in data.items():
-        if key == "multi_modal_input":
+        if key in multimodal_keys:
             continue
         if key == "position_ids" or (
             torch.is_tensor(value) and value.numel() == bs * max_seqlen
@@ -493,8 +499,8 @@ def _split(tensor):

     to_split = dict_map(to_split, lambda x: _split(x))

-    if "multi_modal_input" in data:
-        multi_modal_input = data["multi_modal_input"]
+    for key in multimodal_keys:
+        multi_modal_input = data[key]

         # Prepare the pixel_values and image_grid_thw for each group
         multi_modal_input_split = []
@@ -504,7 +510,7 @@ def _split(tensor):
             # Stack pixel_values for each group (assuming pixel_values is a list of tensors)
             multi_modal_input_split.append(group_pixel_multi_modal_input)
         # Pack the split pixel_values and image_grid_thw back into the data
-        to_split["multi_modal_input"] = multi_modal_input_split
+        to_split[key] = multi_modal_input_split
     mbs = dict_of_list2list_of_dict(to_split)

     results = []
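In `split_padded_tensor_dict_into_mb_list`, the loop now repeats the same group slicing for every `multi_modal_input*` key instead of only the literal one. A sketch of the regrouping under assumed inputs; `group_indices` and the extra video key are hypothetical names for illustration:

```python
def is_multi_modal_key(key: str) -> bool:
    return key.startswith("multi_modal_input")

data = {
    "multi_modal_input": ["img0", "img1", "img2", "img3"],
    "multi_modal_input_video": ["vid0", "vid1", "vid2", "vid3"],  # hypothetical key
}
group_indices = [[0, 2], [1, 3]]  # hypothetical microbatch membership

to_split = {}
for key in (k for k in data if is_multi_modal_key(k)):
    # One sub-list per microbatch, sliced identically for every multimodal key.
    to_split[key] = [[data[key][i] for i in group] for group in group_indices]

assert to_split["multi_modal_input"] == [["img0", "img2"], ["img1", "img3"]]
assert to_split["multi_modal_input_video"] == [["vid0", "vid2"], ["vid1", "vid3"]]
```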