@@ -10,11 +10,9 @@
 import math
 from typing import List, Optional, Tuple, Union
 
-import requests
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from PIL import Image
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
@@ -1197,12 +1195,13 @@ def forward(
         return outputs
 
     def generate_input(self, processor, kv_offload):
-
-        #vision_inputs
+        # vision_inputs
         vision_inputs = {
-            "pixel_values": torch.zeros((bs, max_num_images,max_image_tiles,num_channel, image_length, image_width), dtype=torch.int64),
+            "pixel_values": torch.zeros(
+                (bs, max_num_images, max_image_tiles, num_channel, image_length, image_width), dtype=torch.int64
+            ),
             "aspect_ratio_ids": torch.ones((bs, max_num_images), dtype=torch.int64),
-            "aspect_ratio_mask": torch.ones((bs, max_num_images, max_image_tiles,1), dtype=torch.int64)
+            "aspect_ratio_mask": torch.ones((bs, max_num_images, max_image_tiles, 1), dtype=torch.int64),
         }
 
         vision_output_names = []
@@ -1220,19 +1219,19 @@ def generate_input(self, processor, kv_offload):
             },
         }
 
-        #lang_inputs
+        # lang_inputs
         lang_inputs = {
-            "input_ids": torch.zeros((bs,seq_len),dtype=torch.int64),
+            "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
             "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
-            "cross_attention_mask": torch.ones((bs, max_image_tiles),dtype=torch.int64),
-            "attention_mask": torch.ones((bs,seq_len),dtype=torch.int64)
+            "cross_attention_mask": torch.ones((bs, max_image_tiles), dtype=torch.int64),
+            "attention_mask": torch.ones((bs, seq_len), dtype=torch.int64),
         }
 
         lang_inputs["position_ids"] = torch.where(
             lang_inputs.pop("attention_mask") == 1,
             torch.arange(lang_inputs["input_ids"].shape[1]).view(1, -1),
             -1,
-            )
+        )
 
         ctx_len = Constants.CTX_LEN
         txt_cfg = self.mllama.config.get_text_config()
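
The `torch.where` lines kept as context in the hunk above implement a small but useful trick: positions covered by the attention mask keep their sequence index, while padded slots are sent to position_id -1. A minimal standalone sketch of the same pattern (the mask values below are illustrative, not taken from the model):

import torch

# Batch of 1, seq_len 5; the last two slots are padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

# Real tokens keep their index, padded slots become -1.
position_ids = torch.where(
    attention_mask == 1,
    torch.arange(attention_mask.shape[1]).view(1, -1),
    -1,
)
print(position_ids)  # tensor([[ 0,  1,  2, -1, -1]])
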
@@ -1245,7 +1244,6 @@ def generate_input(self, processor, kv_offload):
         num_patches = (vis_cfg.image_size // vis_cfg.patch_size) ** 2 + 1
         image_tokens_len = vis_cfg.max_num_tiles * num_patches
 
-
         lang_inputs["past_key_values"] = DynamicCache(num_hidden_layers)
         lang_inputs["past_key_values"].key_cache = [0] * num_hidden_layers
         lang_inputs["past_key_values"].value_cache = [0] * num_hidden_layers
@@ -1254,20 +1252,21 @@ def generate_input(self, processor, kv_offload):
             if i in cross_attention_layers:
                 idx = cross_attention_layers.index(i)
                 assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}"
-                lang_inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, image_tokens_len, head_dim)
+                lang_inputs["past_key_values"].key_cache[i] = torch.zeros(
+                    1, num_key_value_heads, image_tokens_len, head_dim
+                )
                 lang_inputs["past_key_values"].value_cache[i] = torch.zeros(
                     1, num_key_value_heads, image_tokens_len, head_dim
                 )
             else:
                 lang_inputs["past_key_values"].key_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim)
                 lang_inputs["past_key_values"].value_cache[i] = torch.zeros(1, num_key_value_heads, ctx_len, head_dim)
 
-
         lang_output_names = [
             "logits",
             *[f"past_{kv}.{i}_RetainedState" for i in range(num_hidden_layers) for kv in ["key", "value"]],
         ]
-
+
         lang_dynamic_axes = {
             "input_ids": {0: "batch_size", 1: "seq_len"},
             "position_ids": {0: "batch_size", 1: "seq_len"},
@@ -1286,10 +1285,10 @@ def generate_input(self, processor, kv_offload):
             else:
                 lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
                 lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}
-
+
         lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache()
         lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, ctx_len - 1)
-
+
         inputs = []
         output_names = []
         dynamic_axes = []
@@ -1304,5 +1303,3 @@ def generate_input(self, processor, kv_offload):
             dynamic_axes.append({**vision_dynamic_axes, **lang_dynamic_axes})
 
         return inputs, output_names, dynamic_axes
-
-
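
For orientation, a hedged sketch of how the returned triple might feed an ONNX export step. `model` and `submodels` are placeholder names and the repo's actual export pipeline is not part of this diff; `torch.onnx.export` itself is standard torch API:

import torch

# Parallel lists: the final hunk suggests one merged entry without
# kv_offload and separate vision/language entries with it.
inputs, output_names, dynamic_axes = model.generate_input(processor, kv_offload=True)

for n, (example_inputs, names, axes) in enumerate(zip(inputs, output_names, dynamic_axes)):
    torch.onnx.export(
        submodels[n],              # placeholder: the subgraph these inputs feed
        (example_inputs,),         # a trailing dict is passed to forward as kwargs
        f"subgraph_{n}.onnx",
        input_names=list(example_inputs),  # simplified; real exports flatten past_key_values
        output_names=names,
        dynamic_axes=axes,
    )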