Commit be0e174

shubhagr-quic authored and eplatero97 committed
QNN Compilation path Support in QEFFBaseModel class. (quic#374)
This change adds support for the QNN compilation path to any model class derived from QEFFBaseModel.

---------

Signed-off-by: Shubham Agrawal <quic_shubhagr@quicinc.com>
Signed-off-by: eplatero <quic_eplatero@quicinc.com>
1 parent ec7bae9 commit be0e174
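
For reference, a minimal usage sketch of the path this commit enables, through the public compile API; the model name and option values below are illustrative assumptions, not part of the commit:

from QEfficient import QEFFAutoModelForCausalLM

# Any model class derived from QEFFBaseModel can now route compilation
# through QNN instead of qaic-exec by passing enable_qnn.
model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")  # illustrative model
qpc_path = model.compile(
    num_cores=16,
    mxfp6_matmul=False,
    enable_qnn=True,
    qnn_config="qnn_config.json",  # optional path to a QNN parameters file
)
print(qpc_path)  # path of the compiled QPC package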

File tree

3 files changed: +98 −82 lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 16 additions & 0 deletions
@@ -254,6 +254,22 @@ def _compile(
         qpc_path = compile_dir / "qpc"
         if not onnx_path.is_file():
             raise FileNotFoundError(f"ONNX file not found at: {onnx_path}")
+
+        if enable_qnn:
+            self.qpc_path = qnn_compile(
+                onnx_path=onnx_path,
+                qpc_base_path=compile_dir,
+                specializations=specializations,
+                custom_io=custom_io,
+                device_group=list(range(mdp_ts_num_devices)),
+                num_cores=compiler_options.get("aic_num_cores", 16),
+                mxfp6=compiler_options.get("mxfp6_matmul", False),
+                mxint8=mxint8_kv_cache,
+                qnn_config=qnn_config,
+            )
+
+            return self.qpc_path
+
         command = constants.COMPILER + [f"-m={onnx_path}"]
         if mdp_ts_json_path := compiler_options.pop("mdp_ts_json_path", None):
             mdp_ts_num_devices = None
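
A quick sketch of how the new branch derives its qnn_compile() arguments from the _compile() inputs, per the hunk above; the device count and empty options dict are illustrative assumptions:

# Assume 4 multi-device-partitioning devices and no explicit compiler options.
mdp_ts_num_devices = 4
compiler_options = {}

device_group = list(range(mdp_ts_num_devices))         # -> [0, 1, 2, 3]
num_cores = compiler_options.get("aic_num_cores", 16)  # -> 16 (default)
mxfp6 = compiler_options.get("mxfp6_matmul", False)    # -> False (default)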

QEfficient/transformers/models/modeling_auto.py

Lines changed: 57 additions & 79 deletions
@@ -605,6 +605,8 @@ def compile(
         num_speculative_tokens: Optional[int] = None,
         enable_qnn: bool = False,
         qnn_config: Optional[str] = None,
+        skip_vision: Optional[bool] = False,
+        skip_lang: Optional[bool] = False,
         **compiler_options,
     ) -> str:
         if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]):
@@ -646,17 +648,18 @@ def compile(
         ):
             self.export()
 
-        self.vision_model._compile(
-            compile_dir,
-            compile_only=True,
-            specializations=specializations["vision"],
-            convert_to_fp16=True,
-            mxfp6_matmul=mxfp6_matmul,
-            mdp_ts_num_devices=num_devices,
-            aic_num_cores=num_cores,
-            custom_io=custom_io_vision,
-            **compiler_options,
-        )
+        if not skip_vision:
+            self.vision_model._compile(
+                compile_dir,
+                compile_only=True,
+                specializations=specializations["vision"],
+                convert_to_fp16=True,
+                mxfp6_matmul=mxfp6_matmul,
+                mdp_ts_num_devices=num_devices,
+                aic_num_cores=num_cores,
+                custom_io=custom_io_vision,
+                **compiler_options,
+            )
 
         custom_io_lang = {}
         # Inputs
@@ -669,18 +672,19 @@ def compile(
             if output_name.endswith("_RetainedState"):
                 custom_io_lang[output_name] = kv_cache_dtype
 
-        self.lang_model._compile(
-            compile_dir,
-            compile_only=True,
-            retained_state=True,
-            specializations=specializations["lang"],
-            convert_to_fp16=True,
-            mxfp6_matmul=mxfp6_matmul,
-            mdp_ts_num_devices=num_devices,
-            aic_num_cores=num_cores,
-            custom_io=custom_io_lang,
-            **compiler_options,
-        )
+        if not skip_lang:
+            self.lang_model._compile(
+                compile_dir,
+                compile_only=True,
+                retained_state=True,
+                specializations=specializations["lang"],
+                convert_to_fp16=True,
+                mxfp6_matmul=mxfp6_matmul,
+                mdp_ts_num_devices=num_devices,
+                aic_num_cores=num_cores,
+                custom_io=custom_io_lang,
+                **compiler_options,
+            )
         return self.qpc_path
 
     def generate(
@@ -1539,6 +1542,7 @@ def compile(
         num_speculative_tokens: Optional[int] = None,
         enable_qnn: bool = False,
         qnn_config: Optional[str] = None,
+        prefill_only: Optional[bool] = None,
         **compiler_options,
     ) -> str:
         """
@@ -1562,6 +1566,8 @@ def compile(
15621566
:aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
15631567
:enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
15641568
:qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
1569+
:prefill_only (bool): if ``True`` compile for prefill only and if ``False`` compile for decode only. Defaults to None, which compiles for both ``prefill and ``decode``.
1570+
:compiler_options (dict, optional): Any other options that the `qaic-exec` takes. ``Defaults to None``.
15651571
15661572
Returns:
15671573
:str: Path of the compiled ``qpc`` package.
@@ -1583,48 +1589,33 @@ def compile(
                 "enable `continuous_batching=True` in `from_pretrained`."
             )
 
-        kv_cache_batch_size = (
-            kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
-        )
-        # Define prefill specialization
-        prefill_specialization = {
-            # Prefill is always run with single BS for continuous batching.
-            "batch_size": 1 if self.continuous_batching else batch_size,
-            "seq_len": prefill_seq_len,
-            "ctx_len": ctx_len,
-            # TODO: should be renamed to kv_cache_batch_size in specialization too
-        }
-        prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
-        if self.continuous_batching:
-            prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
-        else:
-            prefill_specialization.update({"batch_size": kv_cache_batch_size})
-        prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
-        specializations = [
-            prefill_specialization,
-        ]
+        # Infer kv_cache_batch_size if not provided
+        kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
 
-        # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
-        if prefill_seq_len != 1 or self.continuous_batching:
-            decode_specialization = {
-                "batch_size": full_batch_size if self.continuous_batching else batch_size,
-                "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
-                "ctx_len": ctx_len,
-            }
-            if self.continuous_batching:
-                decode_specialization.update({"full_batch_size": kv_cache_batch_size})
-            else:
-                decode_specialization.update({"batch_size": kv_cache_batch_size})
-            decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
-            specializations.append(decode_specialization)
+        # --- Specializations ---
+        specializations = []
+
+        if prefill_only is None or prefill_only or prefill_seq_len == 1:
+            specializations.append(
+                self.build_prefill_specialization(
+                    prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
+                )
+            )
+        if prefill_only is None or not prefill_only:
+            decode_spec = self.build_decode_specialization(
+                prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
+            )
+            if decode_spec:
+                specializations.append(decode_spec)
 
+        # --- Compilation ---
         if enable_qnn:
             if compiler_options:
-                logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
+                logger.warning("Extra arguments to QNN compilation are ignored. Use `qnn_config.json`.")
 
             qpc_path = self._qnn_compile(
-                onnx_path,
-                compile_dir,
+                onnx_path=onnx_path,
+                compile_dir=compile_dir,
                 specializations=specializations,
                 prefill_seq_len=prefill_seq_len,
                 ctx_len=ctx_len,
@@ -1638,17 +1629,17 @@ def compile(
                 kv_cache_batch_size=kv_cache_batch_size,
             )
         else:
-            # Custom IO
-            custom_io = {}
             kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
+            custom_io = {}
+
             for suffix in ["", "_RetainedState"]:
                 for i in range(self.num_layers):
                     for kv in ["key", "value"]:
                         custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
 
             qpc_path = self._compile(
-                onnx_path,
-                compile_dir,
+                onnx_path=onnx_path,
+                compile_dir=compile_dir,
                 compile_only=True,
                 retained_state=True,
                 specializations=specializations,
@@ -1660,6 +1651,7 @@ def compile(
                 aic_num_cores=num_cores,
                 **compiler_options,
             )
+
         return qpc_path
 
     # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
@@ -1867,22 +1859,8 @@ def compile(
         if num_speculative_tokens:
             logger.warning("Speculative decoding is not yet enabled for AutoModelForSpeechSeq2Seq")
 
-        output_names = self.model.get_output_names()
-
-        kv_cache_dtype = "float16"
-        custom_io = {}
-
-        custom_io["input_features"] = kv_cache_dtype
-
-        # Slice output_names to get input names
-        for output_name in output_names:
-            if output_name.endswith("_RetainedState"):
-                custom_io[output_name[: -len("_RetainedState")]] = kv_cache_dtype
-
-        # Get output names
-        for output_name in output_names:
-            if output_name.endswith("_RetainedState"):
-                custom_io[output_name] = kv_cache_dtype
+        if enable_qnn or qnn_config:
+            logger.warning("QNN compile is not yet enabled for AutoModelForSpeechSeq2Seq")
 
         return self._compile(
             onnx_path,
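
The new prefill_only flag splits compilation by specialization; below is a hedged sketch of building separate prefill-only and decode-only QPCs (the model name and shape values are illustrative assumptions, not from the commit):

from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")  # illustrative model

# prefill_only=None (the default) compiles both specializations into one QPC;
# True keeps only the prefill specialization, False keeps only decode.
prefill_qpc = model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16, prefill_only=True)
decode_qpc = model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16, prefill_only=False)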

tests/transformers/models/test_causal_lm_models.py

Lines changed: 25 additions & 3 deletions
@@ -51,6 +51,13 @@
     "Snowflake/Llama-3.1-SwiftKV-8B-Instruct",  # SwiftKV model
 ]
 
+test_models_qnn = [
+    "mistralai/Mixtral-8x7B-Instruct-v0.1",
+    "meta-llama/Llama-3.2-1B",
+    "unsloth/gemma-2b",
+    "ibm-granite/granite-guardian-3.1-2b",
+]
+
 spd_test_models = [
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     "Qwen/Qwen2-0.5B",
@@ -88,6 +95,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     ctx_len: int = Constants.CTX_LEN,
     n_layer: int = 1,
     num_speculative_tokens: Optional[int] = None,
+    prefill_only: Optional[bool] = None,
 ):
     """
     Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
@@ -145,6 +153,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         mxfp6=False,
         aic_enable_depth_first=False,
         num_speculative_tokens=num_speculative_tokens,
+        prefill_only=prefill_only,
     )
     exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0][
@@ -193,6 +202,8 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         aic_enable_depth_first=False,
         full_batch_size=full_batch_size,
         num_speculative_tokens=num_speculative_tokens,
+        enable_qnn=enable_qnn,
+        qnn_config=qnn_config,
     )
     exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)
 
@@ -361,10 +372,12 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_qnn(model_name):
     if model_name == "microsoft/Phi-3-mini-4k-instruct":
         n_layer = 2  # test only 2 layer models
     else:
-        n_layer = 1
+        n_layer = 2
+
+    check_non_hf_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer)
 
-    qnn_config_json_path = os.path.join(os.getcwd(), "qnn_config.json")
-    create_json(qnn_config_json_path, QnnConstants.QNN_SAMPLE_CONFIG)
+
+@pytest.mark.on_qaic
 @pytest.mark.parametrize("model_name", spd_test_models)
 def test_causal_tlm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
     """
@@ -391,3 +404,12 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1():
     prompt_len = 1
 
     check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, prompt_len=prompt_len)
+
+
+@pytest.mark.on_qaic
+def test_prefiill_only_pytorch_vs_kv_vs_ort_vs_ai100():
+    model_name = "gpt2"
+    n_layer = 1
+    check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, n_layer=n_layer, prefill_only=True)
+
+    check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, n_layer=n_layer, prefill_only=False)
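
The new test exercises both sides of the specialization selection added in modeling_auto.py; below is a self-contained sketch of that selection logic (the helper name is hypothetical, for illustration only):

def select_specializations(prefill_only, prefill_seq_len):
    # Mirrors the branch in compile(): None -> both, True -> prefill only,
    # False -> decode only; prefill_seq_len == 1 also forces the prefill entry.
    chosen = []
    if prefill_only is None or prefill_only or prefill_seq_len == 1:
        chosen.append("prefill")
    if prefill_only is None or not prefill_only:
        chosen.append("decode")
    return chosen

assert select_specializations(None, 128) == ["prefill", "decode"]
assert select_specializations(True, 128) == ["prefill"]
assert select_specializations(False, 128) == ["decode"]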
