
Commit 22d35ec

shubhagr-quic authored and eplatero97 committed
QNN Compilation path Support in QEFFBaseModel class. (quic#374)
This change adds support for the QNN compilation path to any model class derived from QEFFBaseModel.

---------

Signed-off-by: Shubham Agrawal <quic_shubhagr@quicinc.com>
Signed-off-by: eplatero <quic_eplatero@quicinc.com>
1 parent 46ccf25 commit 22d35ec
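For context, a minimal usage sketch of what this enables (illustrative only; the model card and config path are placeholders, not part of the commit):

    from QEfficient import QEFFAutoModelForCausalLM

    # Any model class derived from QEFFBaseModel can now opt into the QNN
    # compilation path from its public compile() API.
    model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")
    qpc_path = model.compile(
        num_cores=16,                  # forwarded to QNN as aic_num_cores
        mxfp6_matmul=False,            # forwarded to QNN as the mxfp6 flag
        enable_qnn=True,               # select the QNN compilation path
        qnn_config="qnn_config.json",  # optional QNN parameters file
    )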

File tree

4 files changed: +137 −74 lines changed


QEfficient/base/modeling_qeff.py

Lines changed: 16 additions & 0 deletions
@@ -254,6 +254,22 @@ def _compile(
         qpc_path = compile_dir / "qpc"
         if not onnx_path.is_file():
             raise FileNotFoundError(f"ONNX file not found at: {onnx_path}")
+
+        if enable_qnn:
+            self.qpc_path = qnn_compile(
+                onnx_path=onnx_path,
+                qpc_base_path=compile_dir,
+                specializations=specializations,
+                custom_io=custom_io,
+                device_group=list(range(mdp_ts_num_devices)),
+                num_cores=compiler_options.get("aic_num_cores", 16),
+                mxfp6=compiler_options.get("mxfp6_matmul", False),
+                mxint8=mxint8_kv_cache,
+                qnn_config=qnn_config,
+            )
+
+            return self.qpc_path
+
         command = constants.COMPILER + [f"-m={onnx_path}"]
         if mdp_ts_json_path := compiler_options.pop("mdp_ts_json_path", None):
             mdp_ts_num_devices = None

QEfficient/transformers/models/modeling_auto.py

Lines changed: 80 additions & 64 deletions
@@ -605,6 +605,8 @@ def compile(
         num_speculative_tokens: Optional[int] = None,
         enable_qnn: bool = False,
         qnn_config: Optional[str] = None,
+        skip_vision: Optional[bool] = False,
+        skip_lang: Optional[bool] = False,
         **compiler_options,
     ) -> str:
         if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]):
@@ -643,17 +645,18 @@ def compile(
         ):
             self.export()

-        self.vision_model._compile(
-            compile_dir,
-            compile_only=True,
-            specializations=specializations["vision"],
-            convert_to_fp16=True,
-            mxfp6_matmul=mxfp6_matmul,
-            mdp_ts_num_devices=num_devices,
-            aic_num_cores=num_cores,
-            custom_io=custom_io_vision,
-            **compiler_options,
-        )
+        if not skip_vision:
+            self.vision_model._compile(
+                compile_dir,
+                compile_only=True,
+                specializations=specializations["vision"],
+                convert_to_fp16=True,
+                mxfp6_matmul=mxfp6_matmul,
+                mdp_ts_num_devices=num_devices,
+                aic_num_cores=num_cores,
+                custom_io=custom_io_vision,
+                **compiler_options,
+            )

         if not skip_lang:
             custom_io_lang = {}
@@ -667,18 +670,18 @@ def compile(
                 if output_name.endswith("_RetainedState"):
                     custom_io_lang[output_name] = kv_cache_dtype

-                self.lang_model._compile(
-                    compile_dir,
-                    compile_only=True,
-                    retained_state=True,
-                    specializations=specializations["lang"],
-                    convert_to_fp16=True,
-                    mxfp6_matmul=mxfp6_matmul,
-                    mdp_ts_num_devices=num_devices,
-                    aic_num_cores=num_cores,
-                    custom_io=custom_io_lang,
-                    **compiler_options,
-                )
+            self.lang_model._compile(
+                compile_dir,
+                compile_only=True,
+                retained_state=True,
+                specializations=specializations["lang"],
+                convert_to_fp16=True,
+                mxfp6_matmul=mxfp6_matmul,
+                mdp_ts_num_devices=num_devices,
+                aic_num_cores=num_cores,
+                custom_io=custom_io_lang,
+                **compiler_options,
+            )
         return self.qpc_path

     def generate(
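A hedged sketch of the new flags (the flag names come from the diff; the surrounding call is illustrative and omits other arguments): each half of a dual-QPC multimodal model can now be compiled independently.

    # Recompile only the language model, e.g. after changing its settings,
    # without rebuilding the vision QPC.
    model.compile(num_cores=16, skip_vision=True)

    # Conversely, build only the vision encoder QPC.
    model.compile(num_cores=16, skip_lang=True)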
@@ -1534,6 +1537,7 @@ def compile(
         num_speculative_tokens: Optional[int] = None,
         enable_qnn: bool = False,
         qnn_config: Optional[str] = None,
+        prefill_only: Optional[bool] = None,
         **compiler_options,
     ) -> str:
         """
@@ -1557,6 +1561,8 @@ def compile(
             :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
             :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
             :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
+            :prefill_only (bool): If ``True``, compile for prefill only; if ``False``, compile for decode only. ``Defaults to None``, which compiles for both prefill and decode.
+            :compiler_options (dict, optional): Any other options that ``qaic-exec`` takes. ``Defaults to None``.

         Returns:
             :str: Path of the compiled ``qpc`` package.
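The three modes described by the new docstring entry, as a sketch (other compile arguments elided):

    model.compile(num_cores=16, prefill_only=True)   # prefill specialization only
    model.compile(num_cores=16, prefill_only=False)  # decode specialization only
    model.compile(num_cores=16)                      # default (None): both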
@@ -1588,48 +1594,33 @@ def compile(
                 "enable `continuous_batching=True` in `from_pretrained`."
             )

-        kv_cache_batch_size = (
-            kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
-        )
-        # Define prefill specialization
-        prefill_specialization = {
-            # Prefill is always run with single BS for continuous batching.
-            "batch_size": 1 if self.continuous_batching else batch_size,
-            "seq_len": prefill_seq_len,
-            "ctx_len": ctx_len,
-            # TODO: should be renamed to kv_cache_batch_size in specialization too
-        }
-        prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
-        if self.continuous_batching:
-            prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
-        else:
-            prefill_specialization.update({"batch_size": kv_cache_batch_size})
-        prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
-        specializations = [
-            prefill_specialization,
-        ]
+        # Infer kv_cache_batch_size if not provided
+        kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size

-        # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
-        if prefill_seq_len != 1 or self.continuous_batching:
-            decode_specialization = {
-                "batch_size": full_batch_size if self.continuous_batching else batch_size,
-                "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
-                "ctx_len": ctx_len,
-            }
-            if self.continuous_batching:
-                decode_specialization.update({"full_batch_size": kv_cache_batch_size})
-            else:
-                decode_specialization.update({"batch_size": kv_cache_batch_size})
-            decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
-            specializations.append(decode_specialization)
+        # --- Specializations ---
+        specializations = []
+
+        if prefill_only is None or prefill_only or prefill_seq_len == 1:
+            specializations.append(
+                self.build_prefill_specialization(
+                    prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
+                )
+            )
+        if prefill_only is None or not prefill_only:
+            decode_spec = self.build_decode_specialization(
+                prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
+            )
+            if decode_spec:
+                specializations.append(decode_spec)

+        # --- Compilation ---
         if enable_qnn:
             if compiler_options:
-                logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
+                logger.warning("Extra arguments to QNN compilation are ignored. Use `qnn_config.json`.")

             qpc_path = self._qnn_compile(
-                onnx_path,
-                compile_dir,
+                onnx_path=onnx_path,
+                compile_dir=compile_dir,
                 specializations=specializations,
                 prefill_seq_len=prefill_seq_len,
                 ctx_len=ctx_len,
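For reference, the deleted inline code built specialization dicts of the shape below (shown for the regular-batching, non-TLM case with the default kv_cache_batch_size); the new build_prefill_specialization/build_decode_specialization helpers are assumed to produce the equivalent:

    # Derived from the removed lines above; keys and values are unchanged.
    prefill_specialization = {"batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len}
    decode_specialization = {"batch_size": batch_size, "seq_len": 1, "ctx_len": ctx_len}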
@@ -1643,17 +1634,17 @@ def compile(
                 kv_cache_batch_size=kv_cache_batch_size,
             )
         else:
-            # Custom IO
-            custom_io = {}
             kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
+            custom_io = {}
+
             for suffix in ["", "_RetainedState"]:
                 for i in range(self.num_layers):
                     for kv in ["key", "value"]:
                         custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype

             qpc_path = self._compile(
-                onnx_path,
-                compile_dir,
+                onnx_path=onnx_path,
+                compile_dir=compile_dir,
                 compile_only=True,
                 retained_state=True,
                 specializations=specializations,
@@ -1665,6 +1656,7 @@ def compile(
                 aic_num_cores=num_cores,
                 **compiler_options,
             )
+
         return qpc_path

     # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
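For a concrete picture of the custom-IO table the loop above builds, assuming num_layers == 1 and mxint8_kv_cache == False:

    custom_io = {
        "past_key.0": "float16",
        "past_value.0": "float16",
        "past_key.0_RetainedState": "float16",
        "past_value.0_RetainedState": "float16",
    }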
@@ -1829,6 +1821,10 @@ def compile(
         num_devices: int = 1,
         num_cores: int = 16,  # FIXME: Make this mandatory arg
         mxfp6_matmul: bool = False,
+        mxint8_kv_cache: bool = False,
+        num_speculative_tokens: Optional[int] = None,
+        enable_qnn: bool = False,
+        qnn_config: Optional[str] = None,
         **compiler_options,
     ) -> str:
         """
@@ -1851,7 +1847,27 @@ def compile(
         Returns:
             :str: Path of the compiled ``qpc`` package.
         """
-        specializations = self.model.get_specializations(batch_size, encoder_ctx_len, decoder_ctx_len, feature_len)
+        specializations, compiler_options = self.model.get_specializations(
+            batch_size,
+            encoder_ctx_len,
+            ctx_len,
+            **compiler_options,
+        )
+
+        if full_batch_size:
+            logger.warning("Continuous batching is not yet enabled for AutoModelForSpeechSeq2Seq")
+
+        if kv_cache_batch_size:
+            logger.warning("Prefix caching is not yet enabled for AutoModelForSpeechSeq2Seq")
+
+        if mxint8_kv_cache:
+            logger.warning("mxint8 cache is not yet enabled for AutoModelForSpeechSeq2Seq")
+
+        if num_speculative_tokens:
+            logger.warning("Speculative decoding is not yet enabled for AutoModelForSpeechSeq2Seq")
+
+        if enable_qnn or qnn_config:
+            logger.warning("QNN compile is not yet enabled for AutoModelForSpeechSeq2Seq")

         return self._compile(
             onnx_path,

tests/transformers/models/test_causal_lm_models.py

Lines changed: 41 additions & 0 deletions
@@ -51,6 +51,13 @@
     "Snowflake/Llama-3.1-SwiftKV-8B-Instruct",  # SwiftKV model
 ]

+test_models_qnn = [
+    "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "meta-llama/Llama-3.2-1B",
+    "unsloth/gemma-2b",
+    "ibm-granite/granite-guardian-3.1-2b",
+]
+
 spd_test_models = [
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     "Qwen/Qwen2-0.5B",
@@ -88,6 +95,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     ctx_len: int = Constants.CTX_LEN,
     n_layer: int = 1,
     num_speculative_tokens: Optional[int] = None,
+    prefill_only: Optional[bool] = None,
 ):
     """
     Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
@@ -145,6 +153,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         mxfp6=False,
         aic_enable_depth_first=False,
         num_speculative_tokens=num_speculative_tokens,
+        prefill_only=prefill_only,
     )
     exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0]  # Because we always run for single input and single batch size
@@ -270,6 +279,29 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_qnn(model_name):
     )


+@pytest.mark.on_qaic
+@pytest.mark.qnn
+@pytest.mark.parametrize("model_name", test_models_qnn)
+def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_qnn(model_name):
+    """
+    QNN Compilation Test
+    Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+    """
+    if model_name == "microsoft/Phi-3-mini-4k-instruct":
+        n_layer = 2  # test only 2 layer models
+    else:
+        n_layer = 1
+
+    qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json")
+    create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG)
+
+    check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
+        model_name=model_name, n_layer=n_layer, enable_qnn=True, qnn_config=qnn_config_json_path
+    )
+
+
 @pytest.mark.on_qaic
 @pytest.mark.parametrize("model_name", spd_test_models)
 def test_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
@@ -298,3 +330,12 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1():
     prompt_len = 1

     check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, prompt_len=prompt_len)
+
+
+@pytest.mark.on_qaic
+def test_prefill_only_pytorch_vs_kv_vs_ort_vs_ai100():
+    model_name = "gpt2"
+    n_layer = 1
+    check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, n_layer=n_layer, prefill_only=True)
+
+    check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, n_layer=n_layer, prefill_only=False)

tests/transformers/models/test_prefix_caching.py

Lines changed: 0 additions & 10 deletions
@@ -40,16 +40,6 @@ def test_simple_prefix_caching(model_name):
 @pytest.mark.parametrize("model_name", test_models)
 def test_simple_prefix_caching_qnn(model_name):
     qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_name, continuous_batching=True)
-    qnn_config = {
-        "converter_args_extension": "",
-        "context_binary_generator_args_extension": "--log_level debug",
-        "qnn_compilation_backend": {
-            "compiler_enable_depth_first": True,
-            "compiler_printDDRStats": False,
-            "compiler_printPerfMetrics": False,
-        },
-        "SKIP_QNN_CONVERTER_STEP": False,
-    }
     qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json")
     create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG)
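The deleted literal duplicated the shared sample config; assuming QnnConstants.QNN_SAMPLE_CONFIG mirrors the removed dict, both QNN tests now feed create_json the same settings:

    QNN_SAMPLE_CONFIG = {
        "converter_args_extension": "",
        "context_binary_generator_args_extension": "--log_level debug",
        "qnn_compilation_backend": {
            "compiler_enable_depth_first": True,
            "compiler_printDDRStats": False,
            "compiler_printPerfMetrics": False,
        },
        "SKIP_QNN_CONVERTER_STEP": False,
    }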
