@@ -84,7 +84,7 @@
 parser.add_argument(
     '--quant_format',
     type=str,
-    default='QOperator',
+    default='QOperator',
     choices=['QOperator', 'QDQ'],
     help="quantization format"
 )
@@ -124,8 +124,9 @@
 )
 args = parser.parse_args()
 
-# load model
+# load tokenizer and config
 tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer)
+config = LlamaConfig.from_pretrained(args.model_path)
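+# the config is reused by benchmark() and by KVDataloader to shape dummy KV-cache inputs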
 
 def tokenize_function(examples):
     example = tokenizer(examples['text'])
@@ -134,29 +135,20 @@ def tokenize_function(examples):
 def benchmark(model):
     import json
     import time
-    config = LlamaConfig.from_pretrained(args.model_path)
     sess_options = ort.SessionOptions()
     sess_options.intra_op_num_threads = args.intra_op_num_threads
-
-    if os.path.exists(os.path.join(model, "decoder_with_past_model.onnx")):
-        sessions = ORTModelForCausalLM.load_model(  # pylint: disable=E1123
-            os.path.join(model, "decoder_model.onnx"),
-            os.path.join(model, "decoder_with_past_model.onnx"),
-            session_options=sess_options)
-        model = ORTModelForCausalLM(sessions[0],  # pylint: disable=E1121
-                                    config,
-                                    model,
-                                    sessions[1],
-                                    use_cache=True)
-    else:
-        sessions = ORTModelForCausalLM.load_model(  # pylint: disable=E1123
-            os.path.join(model, "decoder_model.onnx"),
-            session_options=sess_options)
-        model = ORTModelForCausalLM(sessions[0],  # pylint: disable=E1121
-                                    config,
-                                    model,
-                                    use_cache=False,
-                                    use_io_binding=False)
+
+    session = ORTModelForCausalLM.load_model(  # pylint: disable=E1123
+        os.path.join(model, "model.onnx"),
+        session_options=sess_options)
+    inputs_names = session.get_inputs()
+    key_value_input_names = [key.name for key in inputs_names if (".key" in key.name) or (".value" in key.name)]
+    use_cache = len(key_value_input_names) > 0
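+    # a merged decoder exported with cache support exposes past_key_values.*.key/.value
+    # inputs, so their presence is used to infer whether the KV cache can be enabled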
+
+    model = ORTModelForCausalLM(session,  # pylint: disable=E1121
+                                config,
+                                use_cache=use_cache,
+                                use_io_binding=use_cache)
 
     input_tokens = '32'
     max_new_tokens = 32
@@ -192,7 +184,7 @@ def benchmark(model):
     print(args)
     throughput = (num_iter - num_warmup) / total_time
     print("Throughput: {} samples/s".format(throughput))
-
+
 
 def replace_architectures(json_path):
     # replace 'LLaMAForCausalLM' with 'LlamaForCausalLM' in the config's architectures field
@@ -201,7 +193,7 @@ def replace_architectures(json_path):
     with open(json_path, "r") as file:
         data = json.load(file)
     data["architectures"] = ["LlamaForCausalLM"]
-
+
     with open(json_path, 'w') as file:
         json.dump(data, file, indent=4)
 
@@ -234,6 +226,7 @@ def eval_func(model):
 
     return eval_acc
 
+
 class KVDataloader:
     def __init__(self, model_path, pad_max=196, batch_size=1, sub_folder='train'):
         self.pad_max = pad_max
@@ -247,10 +240,11 @@ def __init__(self, model_path, pad_max=196, batch_size=1, sub_folder='train'):
             shuffle=False,
             collate_fn=self.collate_batch,
         )
-        self.sess = None
-        if not model_path.endswith('decoder_model.onnx'):
-            self.sess = ort.InferenceSession(os.path.join(os.path.dirname(model_path), 'decoder_model.onnx'))
-
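+        # inspect the graph inputs once so __iter__ knows whether calibration
+        # must feed dummy KV-cache tensors alongside the prompt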
+        session = ort.InferenceSession(model_path)
+        inputs_names = [input.name for input in session.get_inputs()]
+        self.key_value_input_names = [key for key in inputs_names if (".key" in key) or (".value" in key)]
+        self.use_cache = len(self.key_value_input_names) > 0
+        self.session = session if self.use_cache else None
 
     def collate_batch(self, batch):
 
@@ -269,23 +263,26 @@ def collate_batch(self, batch):
             attention_mask_padded.append(attention_mask)
         return (torch.vstack(input_ids_padded), torch.vstack(attention_mask_padded)), torch.tensor(last_ind)
 
-
     def __iter__(self):
         try:
             for (input_ids, attention_mask), last_ind in self.dataloader:
-                if self.sess is None:
-                    yield {'input_ids': input_ids[:, :-1].detach().cpu().numpy().astype('int64'),
-                           'attention_mask': attention_mask[:, :-1].detach().cpu().numpy().astype('int64')}, last_ind.detach().cpu().numpy()
-                else:
-                    outputs = self.sess.run(None, {'input_ids': input_ids[:, :-1].detach().cpu().numpy().astype('int64'),
-                                                   'attention_mask': attention_mask[:, :-1].detach().cpu().numpy().astype('int64')})
-                    ort_input = {}
-                    ort_input['input_ids'] = input_ids[:, -1].unsqueeze(0).detach().cpu().numpy().astype('int64')
-                    for i in range(int((len(outputs) - 1) / 2)):
-                        ort_input['past_key_values.{}.key'.format(i)] = outputs[i * 2 + 1]
-                        ort_input['past_key_values.{}.value'.format(i)] = outputs[i * 2 + 2]
-                    ort_input['attention_mask'] = np.zeros([self.batch_size, ort_input['past_key_values.0.key'].shape[2] + 1], dtype='int64')
-                    yield ort_input, last_ind.detach().cpu().numpy()
+                ort_input = {}
+                ort_input["input_ids"] = input_ids[:, :-1].detach().cpu().numpy().astype("int64")
+                ort_input["attention_mask"] = attention_mask[:, :-1].detach().cpu().numpy().astype("int64")
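+                # derive position_ids from the attention mask: the cumulative sum gives
+                # 0-based positions, and padded slots are clamped to 1 (as in the HF modeling code)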
+                position_ids = attention_mask.long().cumsum(-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                ort_input["position_ids"] = position_ids[:, :-1].detach().cpu().numpy().astype("int64")
+                if self.use_cache:
+                    # create dummy past_key_values for the decoder
+                    num_key_value_heads = config.num_key_value_heads
+                    embed_size_per_head = config.hidden_size // config.num_attention_heads
+                    shape = (self.batch_size, num_key_value_heads, 0, embed_size_per_head)
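+                    # a zero-length sequence dimension yields an empty cache for the prefill step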
+                    key_or_value = np.zeros(shape, dtype=np.float32)
+                    for key_value_input_name in self.key_value_input_names:
+                        ort_input[key_value_input_name] = key_or_value
+
+                yield ort_input, last_ind.detach().cpu().numpy()
+
         except StopIteration:
             return
 
@@ -294,43 +291,38 @@ def __iter__(self):
     set_workspace(args.workspace)
 
     if args.benchmark:
-        if args.mode == 'performance':
+        if args.mode == 'performance':
             benchmark(args.model_path)
         elif args.mode == 'accuracy':
             eval_func(args.model_path)
 
     if args.tune:
         from neural_compressor import quantization, PostTrainingQuantConfig
+
+        model_name = "model.onnx"  # requires optimum >= 1.14.0
+        model_path = os.path.join(args.model_path, model_name)
+
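+        # optimum >= 1.14.0 exports a single merged decoder (model.onnx) instead of the
+        # former decoder_model.onnx / decoder_with_past_model.onnx pair, so one fit() run suffices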
         if args.layer_wise:
             # layer-wise quantization for ONNX models is still under development and only supports W8A8 quantization now
-            config = PostTrainingQuantConfig(
+            ptq_config = PostTrainingQuantConfig(
                 calibration_sampling_size=[8],
                 recipes={'optypes_to_exclude_output_quant': ['MatMul'],
-                         'layer_wise_quant': True},
+                         'layer_wise_quant': True,
+                         'graph_optimization_level': 'ENABLE_EXTENDED'},
                 op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
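+                # the negative-lookahead regex matches every op type other than
+                # MatMul/Gather/Conv and keeps those ops in fp32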
-            for model in ['decoder_model.onnx']:
-                # only test decoder_model
-                q_model = quantization.fit(
-                    os.path.join(args.model_path, model),
-                    config,
-                    calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
-                q_model.save(os.path.join(args.output_model, model))
-
-            tokenizer.save_pretrained(args.output_model)
-
         else:
-            config = PostTrainingQuantConfig(
+            ptq_config = PostTrainingQuantConfig(
                 calibration_sampling_size=[8],
                 recipes={'optypes_to_exclude_output_quant': ['MatMul'],
-                         'smooth_quant': True,
-                         'smooth_quant_args': {'alpha': args.smooth_quant_alpha},
-                         },
+                         'smooth_quant': True,
+                         'smooth_quant_args': {'alpha': args.smooth_quant_alpha},
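+                         # SmoothQuant shifts activation outliers into the weights;
+                         # alpha balances activation vs. weight quantization difficulty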
+                         'graph_optimization_level': 'ENABLE_EXTENDED'},
                 op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
-        for model in ['decoder_model.onnx', 'decoder_with_past_model.onnx']:
-            q_model = quantization.fit(
-                os.path.join(args.model_path, model),
-                config,
-                calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
-            q_model.save(os.path.join(args.output_model, model))
-
-        tokenizer.save_pretrained(args.output_model)
+
+        q_model = quantization.fit(
+            model_path,
+            ptq_config,
+            calib_dataloader=KVDataloader(model_path, pad_max=args.pad_max, batch_size=1))
+        q_model.save(os.path.join(args.output_model, model_name))
+
+        tokenizer.save_pretrained(args.output_model)