Skip to content

Commit 574a6df

Browse files
authored
Support for dual compilation on VLMs (#361)
Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>
1 parent b88b758 commit 574a6df

File tree

1 file changed: +53 additions, -32 deletions

QEfficient/transformers/models/modeling_auto.py

Lines changed: 53 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -600,6 +600,7 @@ def compile(
600600
num_speculative_tokens: Optional[int] = None,
601601
enable_qnn: bool = False,
602602
qnn_config: Optional[str] = None,
603+
compile_for: Optional[str] = None,
603604
**compiler_options,
604605
) -> str:
605606
if (
@@ -615,6 +616,9 @@ def compile(
615616
f"enable_qnn={enable_qnn}, qnn_config={qnn_config}"
616617
)
617618

619+
if compile_for not in {"vision", "lang", None}:
620+
raise ValueError(f"Expected 'compile_for' to be one of 'vision', 'lang', or None but got: {compile_for}")
621+
618622
output_names = self.model.get_output_names(kv_offload=True)
619623

620624
specializations, compiler_options = self.model.get_specializations(
@@ -642,41 +646,49 @@ def compile(
642646
):
643647
self.export()
644648

645-
self.vision_model._compile(
646-
compile_dir,
647-
compile_only=True,
648-
specializations=specializations["vision"],
649-
convert_to_fp16=True,
650-
mxfp6_matmul=mxfp6_matmul,
651-
mdp_ts_num_devices=num_devices,
652-
aic_num_cores=num_cores,
653-
custom_io=custom_io_vision,
654-
**compiler_options,
655-
)
649+
if compile_for is None or compile_for.lower() == "vision":
650+
vision_qpc_path = self.vision_model._compile(
651+
compile_dir,
652+
compile_only=True,
653+
specializations=specializations["vision"],
654+
convert_to_fp16=True,
655+
mxfp6_matmul=mxfp6_matmul,
656+
mdp_ts_num_devices=num_devices,
657+
aic_num_cores=num_cores,
658+
custom_io=custom_io_vision,
659+
**compiler_options,
660+
)
656661

657-
custom_io_lang = {}
658-
# Inputs
659-
for output_name in output_names["lang"]:
660-
if output_name.endswith("_RetainedState"):
661-
custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype
662+
if compile_for == "vision":
663+
return vision_qpc_path
662664

663-
# outputs
664-
for output_name in output_names["lang"]:
665-
if output_name.endswith("_RetainedState"):
666-
custom_io_lang[output_name] = kv_cache_dtype
665+
if compile_for is None or compile_for.lower() == "lang":
666+
custom_io_lang = {}
667+
# Inputs
668+
for output_name in output_names["lang"]:
669+
if output_name.endswith("_RetainedState"):
670+
custom_io_lang[output_name[: -len("_RetainedState")]] = kv_cache_dtype
671+
672+
# outputs
673+
for output_name in output_names["lang"]:
674+
if output_name.endswith("_RetainedState"):
675+
custom_io_lang[output_name] = kv_cache_dtype
676+
677+
lang_qpc_path = self.lang_model._compile(
678+
compile_dir,
679+
compile_only=True,
680+
retained_state=True,
681+
specializations=specializations["lang"],
682+
convert_to_fp16=True,
683+
mxfp6_matmul=mxfp6_matmul,
684+
mdp_ts_num_devices=num_devices,
685+
aic_num_cores=num_cores,
686+
custom_io=custom_io_lang,
687+
**compiler_options,
688+
)
689+
if compile_for == "lang":
690+
return lang_qpc_path
667691

668-
self.lang_model._compile(
669-
compile_dir,
670-
compile_only=True,
671-
retained_state=True,
672-
specializations=specializations["lang"],
673-
convert_to_fp16=True,
674-
mxfp6_matmul=mxfp6_matmul,
675-
mdp_ts_num_devices=num_devices,
676-
aic_num_cores=num_cores,
677-
custom_io=custom_io_lang,
678-
**compiler_options,
679-
)
680692
return self.qpc_path
681693

682694
def generate(
@@ -711,6 +723,15 @@ def kv_offload_generate(
711723
device_ids: List[int] = None,
712724
generation_len: int = None,
713725
):
726+
if not self.vision_model.qpc_path or not self.lang_model.qpc_path:
727+
raise TypeError("Please run compile API for vision and language model first!")
728+
729+
if not self.vision_model.qpc_path:
730+
raise TypeError("Please run compile API for vision model first!")
731+
732+
if not self.lang_model.qpc_path:
733+
raise TypeError("Please run compile API for language model first!")
734+
714735
lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False)
715736

716737
vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids)

0 commit comments

Comments
 (0)