@@ -18,13 +18,13 @@ class QEffInternVLModel(nn.Module):
1818 def get_specializations (
1919 self , batch_size : int , prefill_seq_len : int , ctx_len : int , img_size : int , ** compiler_options
2020 ):
21- # TODO: check if this should be named num_crops or something else
22- num_crops = compiler_options .get ("num_crops " , None )
23- if num_crops is None :
21+ # TODO: check if this should be named num_patches or something else
22+ num_patches = compiler_options .get ("num_patches " , None )
23+ if num_patches is None :
2424 logger .warning (
25- "User should pass `num_crops ` to compile API to fix the dynamic axes `pixel_values`, you can get more info by calling get_inputs_info function!, Since its not found setting its value to 13"
25+ "User should pass `num_patches ` to compile API to fix the dynamic axes `pixel_values`, you can get more info by calling get_inputs_info function!, Since its not found setting its value to 13"
2626 )
27- num_crops = 13
27+ num_patches = 13
2828
2929 prefill_seq_len = prefill_seq_len if prefill_seq_len else 3840 # 4096-256
3030 ctx_len = ctx_len if ctx_len else 4096
@@ -39,14 +39,14 @@ def get_specializations(
3939 "batch_size" : batch_size ,
4040 "seq_len" : prefill_seq_len ,
4141 "ctx_len" : ctx_len ,
42- "num_crops " : num_crops ,
42+ "num_patches " : num_patches ,
4343 "img_size" : img_size ,
4444 },
4545 {
4646 "batch_size" : batch_size ,
4747 "seq_len" : "1" ,
4848 "ctx_len" : ctx_len ,
49- "num_crops " : num_crops ,
49+ "num_patches " : num_patches ,
5050 "img_size" : img_size ,
5151 },
5252 ]
@@ -58,7 +58,7 @@ def get_onnx_dynamic_axes(
5858 dynamic_axes = {}
5959 dynamic_axes ["input_ids" ] = {0 : "batch_size" , 1 : "seq_len" }
6060 dynamic_axes ["position_ids" ] = {0 : "batch_size" , 1 : "seq_len" }
61- dynamic_axes ["pixel_values" ] = {0 : "num_crops " , 2 : "img_size" , 3 : "img_size" }
61+ dynamic_axes ["pixel_values" ] = {0 : "num_patches " , 2 : "img_size" , 3 : "img_size" }
6262
6363 pkv_dynamic_axes = {0 : "batch_size" , 2 : "ctx_len" }
6464 for i in range (self .language_model .config .num_hidden_layers ):
@@ -79,12 +79,12 @@ def get_output_names(
7979 def get_dummy_inputs (self , kv_offload : bool = False ):
8080 if kv_offload :
8181 raise ValueError ("kv_offload method not supported for InternVL yet!" )
82- NUM_CROPS = 13
82+ num_patches = 13
8383 C = 3
8484 if vis_cfg := getattr (self .config , "vision_config" , None ):
85- img_size = getattr (vis_cfg , "image_size" , 336 )
85+ img_size = getattr (vis_cfg , "image_size" , 448 )
8686 else :
87- img_size = 336
87+ img_size = 448
8888
8989 # Define shapes
9090 inputs_shapes = {}
@@ -93,7 +93,7 @@ def get_dummy_inputs(self, kv_offload: bool = False):
9393 constants .ONNX_EXPORT_EXAMPLE_BATCH_SIZE ,
9494 constants .ONNX_EXPORT_EXAMPLE_SEQ_LEN ,
9595 )
96- inputs_shapes ["pixel_values" ] = (NUM_CROPS , C , img_size , img_size )
96+ inputs_shapes ["pixel_values" ] = (num_patches , C , img_size , img_size )
9797
9898 # Define inputs
9999 inputs = {}
@@ -143,7 +143,7 @@ def get_inputs_info(self):
143143 return [
144144 IOInfo (name = "input_ids" , datatype = torch .int64 , shape = ("batch_size" , "seq_len" )),
145145 IOInfo (name = "attention_mask" , datatype = torch .int64 , shape = ("batch_size" , "seq_len" )),
146- IOInfo (name = "pixel_values" , datatype = torch .float32 , shape = ("num_crops " , 3 , "img_size" , "img_size" )),
146+ IOInfo (name = "pixel_values" , datatype = torch .float32 , shape = ("num_patches " , 3 , "img_size" , "img_size" )),
147147 ]
148148
149149