@@ -600,6 +600,7 @@ def compile(
600600 num_speculative_tokens : Optional [int ] = None ,
601601 enable_qnn : bool = False ,
602602 qnn_config : Optional [str ] = None ,
603+ compile_for : Optional [str ] = None ,
603604 ** compiler_options ,
604605 ) -> str :
605606 if (
@@ -615,6 +616,9 @@ def compile(
615616 f"enable_qnn={ enable_qnn } , qnn_config={ qnn_config } "
616617 )
617618
619+ if compile_for not in {"vision" , "lang" , None }:
620+ raise ValueError (f"Expected 'compile_for' to be one of 'vision', 'lang', or None but got: { compile_for } " )
621+
618622 output_names = self .model .get_output_names (kv_offload = True )
619623
620624 specializations , compiler_options = self .model .get_specializations (
@@ -642,41 +646,49 @@ def compile(
642646 ):
643647 self .export ()
644648
645- self .vision_model ._compile (
646- compile_dir ,
647- compile_only = True ,
648- specializations = specializations ["vision" ],
649- convert_to_fp16 = True ,
650- mxfp6_matmul = mxfp6_matmul ,
651- mdp_ts_num_devices = num_devices ,
652- aic_num_cores = num_cores ,
653- custom_io = custom_io_vision ,
654- ** compiler_options ,
655- )
649+ if compile_for is None or compile_for .lower () == "vision" :
650+ vision_qpc_path = self .vision_model ._compile (
651+ compile_dir ,
652+ compile_only = True ,
653+ specializations = specializations ["vision" ],
654+ convert_to_fp16 = True ,
655+ mxfp6_matmul = mxfp6_matmul ,
656+ mdp_ts_num_devices = num_devices ,
657+ aic_num_cores = num_cores ,
658+ custom_io = custom_io_vision ,
659+ ** compiler_options ,
660+ )
656661
657- custom_io_lang = {}
658- # Inputs
659- for output_name in output_names ["lang" ]:
660- if output_name .endswith ("_RetainedState" ):
661- custom_io_lang [output_name [: - len ("_RetainedState" )]] = kv_cache_dtype
662+ if compile_for == "vision" :
663+ return vision_qpc_path
662664
663- # outputs
664- for output_name in output_names ["lang" ]:
665- if output_name .endswith ("_RetainedState" ):
666- custom_io_lang [output_name ] = kv_cache_dtype
665+ if compile_for is None or compile_for .lower () == "lang" :
666+ custom_io_lang = {}
667+ # Inputs
668+ for output_name in output_names ["lang" ]:
669+ if output_name .endswith ("_RetainedState" ):
670+ custom_io_lang [output_name [: - len ("_RetainedState" )]] = kv_cache_dtype
671+
672+ # outputs
673+ for output_name in output_names ["lang" ]:
674+ if output_name .endswith ("_RetainedState" ):
675+ custom_io_lang [output_name ] = kv_cache_dtype
676+
677+ lang_qpc_path = self .lang_model ._compile (
678+ compile_dir ,
679+ compile_only = True ,
680+ retained_state = True ,
681+ specializations = specializations ["lang" ],
682+ convert_to_fp16 = True ,
683+ mxfp6_matmul = mxfp6_matmul ,
684+ mdp_ts_num_devices = num_devices ,
685+ aic_num_cores = num_cores ,
686+ custom_io = custom_io_lang ,
687+ ** compiler_options ,
688+ )
689+ if compile_for == "lang" :
690+ return lang_qpc_path
667691
668- self .lang_model ._compile (
669- compile_dir ,
670- compile_only = True ,
671- retained_state = True ,
672- specializations = specializations ["lang" ],
673- convert_to_fp16 = True ,
674- mxfp6_matmul = mxfp6_matmul ,
675- mdp_ts_num_devices = num_devices ,
676- aic_num_cores = num_cores ,
677- custom_io = custom_io_lang ,
678- ** compiler_options ,
679- )
680692 return self .qpc_path
681693
682694 def generate (
@@ -711,6 +723,15 @@ def kv_offload_generate(
711723 device_ids : List [int ] = None ,
712724 generation_len : int = None ,
713725 ):
726+ if not self .vision_model .qpc_path or not self .lang_model .qpc_path :
727+ raise TypeError ("Please run compile API for vision and language model first!" )
728+
729+ if not self .vision_model .qpc_path :
730+ raise TypeError ("Please run compile API for vision model first!" )
731+
732+ if not self .lang_model .qpc_path :
733+ raise TypeError ("Please run compile API for language model first!" )
734+
714735 lang_session = QAICInferenceSession (self .lang_model .qpc_path , device_ids , activate = False )
715736
716737 vision_session = QAICInferenceSession (self .vision_model .qpc_path , device_ids )
0 commit comments