@@ -176,6 +176,7 @@ def forward(
         image_idx,
         past_key_values,
         comp_ctx_lengths: Optional[List[int]] = None,
+        batch_index: Optional[torch.LongTensor] = None,
     ):
         inputs_embeds = self.model.get_input_embeddings()(input_ids)
         vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
@@ -190,6 +191,7 @@ def forward(
             position_ids=position_ids,
             past_key_values=past_key_values,
             comp_ctx_lengths=comp_ctx_lengths,
+            batch_index=batch_index,
         )

         # Cast to int32 to avoid ONNXRT issue
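The new batch_index input is the per-request row index used under continuous batching: each in-flight sequence owns one slot of a KV cache sized by the full batch, and its freshly computed keys and values are scattered into that slot. This forward pass simply threads the tensor through to the language model; the sketch below (plain PyTorch, hypothetical shapes and names, not this model's cache code) illustrates the kind of scatter the index enables.

import torch

# Hypothetical illustration of a batch_index-driven KV update, assuming a cache of
# shape (full_batch_size, num_kv_heads, ctx_len, head_dim) and a (batch_size, 1) index.
full_batch_size, num_kv_heads, ctx_len, head_dim = 4, 8, 128, 64
cache_k = torch.zeros(full_batch_size, num_kv_heads, ctx_len, head_dim)

batch_index = torch.tensor([[0], [2], [3]])    # three active requests, one cache row each
position_ids = torch.tensor([[5], [17], [9]])  # current write position per request
new_k = torch.randn(3, num_kv_heads, 1, head_dim)

rows, cols = batch_index.view(-1), position_ids.view(-1)
cache_k[rows, :, cols, :] = new_k.squeeze(2)   # each request writes only into its own row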
@@ -250,7 +252,7 @@ def forward(

         return logits, pixel_values, image_idx, outputs.past_key_values

-    def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs):
+    def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False, **kwargs):
         inputs_shapes = {}
         inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
         height = self.config.vision_config.image_size
@@ -290,10 +292,14 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl
             .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
         )
         lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64)
+
+        bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+        fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
+
         # Add data for KV
         kv_cache_shape = get_padding_shape_from_config(
-            config=self.language_model.config,
-            batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            config=self.model.config.text_config,
+            batch_size=fbs if continuous_batching else bs,
             seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )

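The only change to the dummy KV tensors is their leading dimension: with continuous batching the cache placeholders are sized for ONNX_EXPORT_EXAMPLE_FBS slots rather than the export batch size. A rough sketch, under the assumption (not shown in this diff) that get_padding_shape_from_config yields a (batch, num_kv_heads, seq_len, head_dim) shape:

# Sketch only; the real shape comes from get_padding_shape_from_config, which this
# assumes to return (batch, num_kv_heads, seq_len, head_dim). Values are illustrative.
def dummy_kv_shape(num_kv_heads, head_dim, seq_len, bs, fbs, continuous_batching):
    batch = fbs if continuous_batching else bs
    return (batch, num_kv_heads, seq_len, head_dim)

print(dummy_kv_shape(8, 128, 32, bs=1, fbs=4, continuous_batching=True))   # (4, 8, 32, 128)
print(dummy_kv_shape(8, 128, 32, bs=1, fbs=4, continuous_batching=False))  # (1, 8, 32, 128)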
@@ -304,6 +310,8 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl

         if comp_ctx_lengths is not None:
             lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long)
+        if continuous_batching:
+            lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)

         inputs = {}
         if kv_offload:
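With continuous_batching=True the only extra language input is the dummy batch_index, a (bs, 1) column of row indices. For a hypothetical export batch of 2 it would be:

import torch

bs = 2  # hypothetical export batch size
batch_index = torch.arange(bs).view(bs, 1)
print(batch_index)        # tensor([[0], [1]])
print(batch_index.shape)  # torch.Size([2, 1])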
@@ -324,6 +332,9 @@ def get_specializations(
         comp_ctx_lengths_prefill: Optional[List[int]] = None,
         comp_ctx_lengths_decode: Optional[List[int]] = None,
         kv_offload: bool = False,
+        continuous_batching: bool = False,
+        kv_cache_batch_size: Optional[int] = None,
+        full_batch_size: Optional[int] = None,
         **compiler_options,
     ):
         if img_size is None and hasattr(self.config.vision_config, "image_size"):
@@ -352,46 +363,65 @@ def get_specializations(
             lang = []

             for i in range(0, len(comp_ctx_lengths_prefill)):
-                lang.append(
-                    {
-                        "batch_size": batch_size,
-                        "seq_len": prefill_seq_len,
-                        "ctx_len": ctx_len,
-                        "comp_ctx_lengths": comp_ctx_lengths_prefill[i],
-                        "image_size": img_size,
-                        "vision_size": vision_size,
-                    }
-                )
-
-            # Remaining elements use comp_ctx_lengths[1:] in a loop
-            for i in range(0, len(comp_ctx_lengths_decode)):
-                lang.append(
-                    {
-                        "batch_size": batch_size,
-                        "seq_len": "1",
-                        "ctx_len": ctx_len,
-                        "comp_ctx_lengths": comp_ctx_lengths_decode[i],
-                        "image_size": img_size,
-                        "vision_size": vision_size,
-                    }
-                )
-        else:
-            lang = [
-                {
-                    "batch_size": batch_size,
+                lang_prefill = {
+                    "batch_size": 1 if continuous_batching else batch_size,
                     "seq_len": prefill_seq_len,
                     "ctx_len": ctx_len,
+                    "comp_ctx_lengths": comp_ctx_lengths_prefill[i],
                     "image_size": img_size,
                     "vision_size": vision_size,
-                },
-                {
-                    "batch_size": batch_size,
+                }
+                if continuous_batching:
+                    lang_prefill["full_batch_size"] = kv_cache_batch_size
+                else:
+                    lang_prefill["batch_size"] = kv_cache_batch_size
+                if full_batch_size:
+                    lang_prefill["full_batch_exec_size"] = full_batch_size
+                lang.append(lang_prefill)
+
+            # Remaining elements use comp_ctx_lengths[1:] in a loop
+            for i in range(0, len(comp_ctx_lengths_decode)):
+                lang_decode = {
+                    "batch_size": full_batch_size if continuous_batching else batch_size,
                     "seq_len": "1",
                     "ctx_len": ctx_len,
+                    "comp_ctx_lengths": comp_ctx_lengths_decode[i],
                     "image_size": img_size,
                     "vision_size": vision_size,
-                },
-            ]
+                }
+
+                if continuous_batching:
+                    lang_decode["full_batch_size"] = kv_cache_batch_size
+                else:
+                    lang_decode["batch_size"] = kv_cache_batch_size
+                lang.append(lang_decode)
+        else:
+            lang_prefill = {
+                "batch_size": 1 if continuous_batching else batch_size,
+                "seq_len": prefill_seq_len,
+                "ctx_len": ctx_len,
+                "image_size": img_size,
+                "vision_size": vision_size,
+            }
+            if continuous_batching:
+                lang_prefill["full_batch_size"] = kv_cache_batch_size
+            else:
+                lang_prefill["batch_size"] = kv_cache_batch_size
+            if full_batch_size:
+                lang_prefill["full_batch_exec_size"] = full_batch_size
+
+            lang_decode = {
+                "batch_size": full_batch_size if continuous_batching else batch_size,
+                "seq_len": "1",
+                "ctx_len": ctx_len,
+                "image_size": img_size,
+                "vision_size": vision_size,
+            }
+
+            if continuous_batching:
+                lang_decode["full_batch_size"] = kv_cache_batch_size
+            else:
+                lang_decode["batch_size"] = kv_cache_batch_size

         specializations = {}

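To make the specialization changes concrete, the snippet below reproduces just the non-CCL branch for a hypothetical call (prefill_seq_len=128, ctx_len=4096, img_size=448, vision_size=1024, continuous_batching=True, kv_cache_batch_size=8, full_batch_size=8). It assumes the two dicts end up collected as lang = [lang_prefill, lang_decode], mirroring the lang.append(...) calls in the CCL branch.

def language_specializations(batch_size, prefill_seq_len, ctx_len, img_size, vision_size,
                             continuous_batching, kv_cache_batch_size, full_batch_size):
    # Mirrors the non-CCL branch in the diff above (sketch; list collection assumed).
    lang_prefill = {
        "batch_size": 1 if continuous_batching else batch_size,
        "seq_len": prefill_seq_len,
        "ctx_len": ctx_len,
        "image_size": img_size,
        "vision_size": vision_size,
    }
    if continuous_batching:
        lang_prefill["full_batch_size"] = kv_cache_batch_size
    else:
        lang_prefill["batch_size"] = kv_cache_batch_size
    if full_batch_size:
        lang_prefill["full_batch_exec_size"] = full_batch_size

    lang_decode = {
        "batch_size": full_batch_size if continuous_batching else batch_size,
        "seq_len": "1",
        "ctx_len": ctx_len,
        "image_size": img_size,
        "vision_size": vision_size,
    }
    if continuous_batching:
        lang_decode["full_batch_size"] = kv_cache_batch_size
    else:
        lang_decode["batch_size"] = kv_cache_batch_size
    return [lang_prefill, lang_decode]


for spec in language_specializations(
    batch_size=1, prefill_seq_len=128, ctx_len=4096, img_size=448, vision_size=1024,
    continuous_batching=True, kv_cache_batch_size=8, full_batch_size=8,
):
    print(spec)
# prefill: {'batch_size': 1, 'seq_len': 128, 'ctx_len': 4096, 'image_size': 448,
#           'vision_size': 1024, 'full_batch_size': 8, 'full_batch_exec_size': 8}
# decode:  {'batch_size': 8, 'seq_len': '1', 'ctx_len': 4096, 'image_size': 448,
#           'vision_size': 1024, 'full_batch_size': 8}

Note how prefill is pinned to a batch of 1 with full_batch_size taken from kv_cache_batch_size, while decode runs full_batch_size sequences per step against the same shared cache.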
@@ -404,7 +434,7 @@ def get_specializations(
             lang[1].pop("vision_size")
         return lang, compiler_options

-    def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False):
+    def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False):
         # Define dynamic axes
         num_layers = self.config.text_config.num_hidden_layers

@@ -417,9 +447,18 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv
             "vision_embeds": {0: "vision_size"},
         }

+        if continuous_batching:
+            lang_dynamic_axes["batch_index"] = {0: "batch_size"}
+
         for i in range(num_layers):
-            lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
-            lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}
+            lang_dynamic_axes[f"past_key.{i}"] = {
+                0: "full_batch_size" if continuous_batching else "batch_size",
+                2: "ctx_len",
+            }
+            lang_dynamic_axes[f"past_value.{i}"] = {
+                0: "full_batch_size" if continuous_batching else "batch_size",
+                2: "ctx_len",
+            }

         if comp_ctx_lengths is not None:
             lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"}
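The net effect on the exported graph is that the KV inputs switch their leading dynamic axis to a full_batch_size symbol under continuous batching, while batch_index itself keeps the regular batch_size axis. For a single layer the resulting entries look like this (sketch):

def kv_dynamic_axes(layer, continuous_batching):
    # Mirrors the per-layer loop above for one layer (sketch).
    kv_batch_axis = "full_batch_size" if continuous_batching else "batch_size"
    return {
        f"past_key.{layer}": {0: kv_batch_axis, 2: "ctx_len"},
        f"past_value.{layer}": {0: kv_batch_axis, 2: "ctx_len"},
    }

print(kv_dynamic_axes(0, continuous_batching=True))
# {'past_key.0': {0: 'full_batch_size', 2: 'ctx_len'},
#  'past_value.0': {0: 'full_batch_size', 2: 'ctx_len'}}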