# ----------------------------------------------------------------------------

import hashlib
-import logging
import sys
import warnings
from pathlib import Path
from QEfficient.base.modeling_qeff import QEFFBaseModel
from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
from QEfficient.generation.cloud_infer import QAICInferenceSession
-from QEfficient.generation.text_generation_inference import get_compilation_dims
+from QEfficient.generation.text_generation_inference import CloudAI100ExecInfoNew, PerfMetrics, get_compilation_dims
from QEfficient.transformers.models.pytorch_transforms import (
    CustomOpsTransform,
    KVCacheModuleMethodMapperTransform,
@@ -820,7 +819,7 @@ def export(
        export_dir: Optional[str] = None,
        **kwargs,
    ) -> str:
-        inputs = self.model.generate_dummy_inputs()
+        inputs = self.model.get_dummy_inputs()
        dynamic_axes = self.model.get_onnx_dynamic_axes()
        output_names = self.model.get_output_names()
        self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
@@ -843,6 +842,7 @@ def compile(
        output_names = self.model.get_output_names()

        # Get specializations from modelling file
+        # TODO: expose this via the auto class as well
        specializations = self.model.get_specializations(batch_size=batch_size, prefill_seq_len=prefill_seq_len,
                                                         ctx_len=ctx_len, img_size=img_size, **compiler_options)

@@ -873,7 +873,6 @@ def compile(
        )
        return self.qpc_path

-    @property
    def get_onnx_dynamic_axes(self):
        return self.model.get_onnx_dynamic_axes()

@@ -905,111 +904,121 @@ def generate(
        )

    def cloud_ai_100_generate(
-        self,
-        inputs: torch.Tensor,
-        device_ids: List[int],
-        enable_debug_logs: bool = False,
-        generation_len: int = None,
-        streamer: Optional[TextStreamer] = None,
-        generation_len: Optional[int] = None
-    ) -> np.ndarray:
-        qpc_session = QAICInferenceSession(
-            self.qpc_path, device_ids, enable_debug_logs=enable_debug_logs, activate=False
-        )
-
-        batch_size, ctx_len, fbs = get_compilation_dims(self.qpc_path)
-
-        # Skip inputs/outputs
-        qpc_session.skip_buffers(
-            [x for x in (qpc_session.input_names + qpc_session.output_names) if x.startswith("past_")] + ["pixel_values_RetainedState"]
-        )
-
-        # Read prompt and ctx len from session
-        batch_size = max(
-            [x[qpc_session.binding_index_map["input_ids"]][1][0] for x in qpc_session.allowed_shapes]
-            + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[0]]
-        )
-
-        prefill_seq_len = max(
-            [x[qpc_session.binding_index_map["input_ids"]][1][1] for x in qpc_session.allowed_shapes]
-            + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[1]]
-        )
+        self,
+        inputs: torch.Tensor,
+        device_ids: List[int],
+        enable_debug_logs: bool = False,
+        generation_len: int = None,
+        streamer: Optional[TextStreamer] = None,
+    ) -> np.ndarray:
+        qpc_session = QAICInferenceSession(
+            self.qpc_path, device_ids, enable_debug_logs=enable_debug_logs, activate=False
+        )

-        input_len = inputs["attention_mask"].sum(1, keepdims=True)
-        padded_len = inputs["input_ids"].shape[1]
-        num_chunks = -(padded_len // -prefill_seq_len)  # ceil divide without float
-        padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
+        batch_size, ctx_len, fbs = get_compilation_dims(self.qpc_path)

-        if generation_len is None:
-            generation_len = ctx_len - input_len.max()
+        pad_token_id = 1
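+        # Hard-coded pad id: used below both to pre-fill the generated_ids buffer and to pad the prompt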

-        assert generation_len > 0, "generation length should be greater than zero"
-        generated_ids = np.full((batch_size, generation_len + 1), -1)
+        # Skip inputs/outputs
+        qpc_session.skip_buffers(
+            [x for x in qpc_session.input_names + qpc_session.output_names if x.startswith("past_") or x.endswith("_RetainedState")]
+        )

-        # Prepare inputs for prefill
-        start = perf_counter()
+        # Read prompt and ctx len from session
+        batch_size = max(
+            [x[qpc_session.binding_index_map["input_ids"]][1][0] for x in qpc_session.allowed_shapes]
+            + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[0]]
+        )

-        inputs["position_ids"] = np.where(
-            inputs.pop("attention_mask"), np.arange(padded_len), -1
-        )  # Need to use -1 as position_ids for invalid tokens
-        inputs = dict(inputs)
+        prefill_seq_len = max(
+            [x[qpc_session.binding_index_map["input_ids"]][1][1] for x in qpc_session.allowed_shapes]
+            + [qpc_session.bindings[qpc_session.binding_index_map["input_ids"]].dims[1]]
+        )

-        # vision_session.deactivate()
-        qpc_session.activate()
+        input_len = inputs["attention_mask"].sum(1, keepdims=True)
+        input_ids_length = inputs["input_ids"].shape[1]

-        # Run prefill
-        for i in range(num_chunks):
-            chunk_inputs = inputs.copy()
-            chunk_inputs["input_ids"] = inputs["input_ids"].numpy()[:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
-            chunk_inputs['pixel_values'] = chunk_inputs['pixel_values'].numpy().astype(np.float16)
-            chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
-            outputs = qpc_session.run(chunk_inputs)
+        num_chunks = -(input_ids_length // -prefill_seq_len)  # ceil divide without float
+
+        padded_len = num_chunks * prefill_seq_len  # Convert to a multiple of prompt_len
+        if generation_len is None:
+            generation_len = ctx_len - input_len.max()
+
+        assert generation_len > 0, "generation length should be greater than zero"
+        generated_ids = np.full((batch_size, generation_len + 1), pad_token_id)
+
+        # Prepare inputs for prefill
+        prefill_start = perf_counter()
+
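+        # Pad input_ids and attention_mask up to padded_len so prefill can run in fixed-size chunks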
+        input_ids = inputs["input_ids"]
+        input_ids_size = input_ids.shape[1]
+        inputs["input_ids"] = torch.nn.functional.pad(
+            inputs["input_ids"],
+            (0, padded_len - input_ids_size),
+            "constant",
+            1,
+        )
+        inputs["attention_mask"] = torch.nn.functional.pad(
+            inputs["attention_mask"],
+            (0, padded_len - input_ids_size), "constant", 0
+        )

-        # Skip inputs/outputs again
-        qpc_session.skip_buffers(
-            ["pixel_values"]
-        )
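+        # Convert the remaining torch tensors to numpy arrays for the QAIC session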
+        for k, v in inputs.items():
+            inputs[k] = np.array(v)

-        # Get first token
-        inputs = {}
-        inputs["input_ids"] = outputs["logits"].argmax(2)
-        inputs["position_ids"] = input_len.numpy().astype(np.int64)
-        # inputs["cross_attention_mask"] = inputs["cross_attention_mask"][:, -1:, :, :]
-        generated_ids[:, 0] = inputs["input_ids"].squeeze(1)
-        # finished_sequences = inputs["input_ids"] == self.tokenizer.eos_token_id
-        if streamer:
-            streamer.put(inputs["input_ids"][0])
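+        # Cast pixel_values to fp16 for the compiled graph; padded positions get position_id -1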
+        inputs["pixel_values"] = inputs["pixel_values"].astype("float16")
+        inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)

-        # Decode loop
-        loop_start = perf_counter()
-        for num_token in range(1, generation_len):
-            outputs = qpc_session.run(inputs)
+        qpc_session.activate()

-            # Prepare inputs for next iteration
+        # Run prefill
+
+        for i in range(num_chunks):
+            chunk_inputs = inputs.copy()
+            chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
+            chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len]
+            outputs = qpc_session.run(chunk_inputs)
+
+        prefill_time = perf_counter() - prefill_start
+        # Get first token
        inputs["input_ids"] = outputs["logits"].argmax(2)
-        inputs["position_ids"] += 1
-        generated_ids[:, num_token] = inputs["input_ids"].squeeze(1)
+        inputs["position_ids"] = input_len.numpy()
+        generated_ids[:, 0] = inputs["input_ids"].squeeze(1)
        if streamer:
            streamer.put(inputs["input_ids"][0])

-        end = perf_counter()
-        if streamer:
-            streamer.end()
-
-        prefill_perf = 1 / (loop_start - start)
-        decode_perf = (num_token - 1) / (end - loop_start)
-        total_perf = num_token / (end - start)
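+        # pixel_values is no longer needed after prefill; drop it before the decode loop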
+        qpc_session.skip_buffers(["pixel_values"])
+        inputs.pop("pixel_values")
+
+        # Decode loop
+        decode_start = perf_counter()
+        for num_token in range(1, generation_len):
+            outputs = qpc_session.run(inputs)
+            # Prepare inputs for next iteration
+            inputs["input_ids"] = outputs["logits"].argmax(2)
+            inputs["position_ids"] += 1
+            generated_ids[:, num_token] = inputs["input_ids"].squeeze(1)
+            if streamer:
+                streamer.put(inputs["input_ids"][0])
+
+        decode_end = perf_counter()
+        if streamer:
+            streamer.end()

-        print("TTFT:", round(loop_start - start, 2), "s", file=sys.stderr)
-        print("E2ET:", round(end - start, 2), "s", file=sys.stderr)
-        print("Prefill:", round(prefill_perf, 2), "tok/s", file=sys.stderr)
-        print("Decode:", round(decode_perf, 2), "tok/s", file=sys.stderr)
-        print("E2E:", round(total_perf, 2), "tok/s", file=sys.stderr)
-        if batch_size > 1:
-            print("Prefill (batch):", round(prefill_perf * batch_size, 2), "tok/s", file=sys.stderr)
-            print("Decode (batch):", round(decode_perf * batch_size, 2), "tok/s", file=sys.stderr)
-            print("E2E (batch):", round(total_perf * batch_size, 2), "tok/s", file=sys.stderr)
-        return generated_ids[:, :generation_len]
+        decode_perf = (num_token - 1) / (decode_end - decode_start)
+        total_time = decode_end - prefill_start
+        total_perf = num_token / total_time
+
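+        # Package generated ids and prefill/decode timing into the structured exec-info object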
+        return CloudAI100ExecInfoNew(
+            batch_size=batch_size,
+            generated_ids=generated_ids,
+            perf_metrics=PerfMetrics(
+                prefill_time=prefill_time,
+                decode_perf=decode_perf,
+                total_perf=total_perf,
+                total_time=total_time
+            )
+        )

    @property
    def model_hash(self) -> str:
@@ -1040,9 +1049,9 @@ class QEFFAutoModelForImageTextToText:
    @classmethod
    def from_pytorch_model(cls, model: nn.Module, kv_offload=False, **kwargs):
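+        # kv_offload=True selects the two-QPC variant, otherwise the single-QPC variant is used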
        if kv_offload:
-            return QEffAutoModelForImageTextToText2QPC(model, **kwargs)
+            return _QEffAutoModelForImageTextToText2QPC(model, **kwargs)
        else:
-            return QEFFAutoModelForImageTextToText1QPC(model, **kwargs)
+            return _QEFFAutoModelForImageTextToText1QPC(model, **kwargs)


    @classmethod