
Commit 8f0c080

quic-amitraj authored and eplatero97 committed
Disaggregated serving (quic#365)
Adding support of:
1. `prefill_only`
2. `compile_for` for VLM
3. `mdp_ts_json_path`

---------

Signed-off-by: Rishin Raj <quic_rishinr@quicinc.com>
Signed-off-by: Amit Raj <quic_amitraj@quicinc.com>
Signed-off-by: Onkar Chougule <quic_ochougul@quicinc.com>
Signed-off-by: Onkar Chougule <168134249+ochougul@users.noreply.github.com>
Co-authored-by: Rishin Raj <quic_rishinr@quicinc.com>
Co-authored-by: Onkar Chougule <quic_ochougul@quicinc.com>
Co-authored-by: Onkar Chougule <168134249+ochougul@users.noreply.github.com>
Signed-off-by: eplatero <quic_eplatero@quicinc.com>
1 parent db910f3 commit 8f0c080
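The commit message lists three user-facing additions. As a rough usage sketch (hypothetical model name and shape values; `prefill_only` and `mdp_ts_json_path` are taken from the diffs below, but the exact call shape is an assumption, not verified against a released QEfficient API):

# Hypothetical sketch only: how the flags named in the commit message might be used.
# Model name, shapes, and the exact keyword set are placeholders/assumptions.
from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")

# Prefill-only QPC for the prefill worker of a disaggregated-serving setup,
# with multi-device tensor slicing described by a pre-built JSON.
prefill_qpc = model.compile(
    prefill_seq_len=128,
    ctx_len=256,
    num_cores=16,
    prefill_only=True,                 # new flag from this commit
    mdp_ts_json_path="./mdp_ts.json",  # new compiler option, popped inside _compile()
)

# Decode-only counterpart for the decode worker.
decode_qpc = model.compile(
    prefill_seq_len=128,
    ctx_len=256,
    num_cores=16,
    prefill_only=False,
)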

File tree

3 files changed: +113 -250 lines changed


QEfficient/base/modeling_qeff.py

Lines changed: 0 additions & 16 deletions
@@ -254,22 +254,6 @@ def _compile(
         qpc_path = compile_dir / "qpc"
         if not onnx_path.is_file():
             raise FileNotFoundError(f"ONNX file not found at: {onnx_path}")
-
-        if enable_qnn:
-            self.qpc_path = qnn_compile(
-                onnx_path=onnx_path,
-                qpc_base_path=compile_dir,
-                specializations=specializations,
-                custom_io=custom_io,
-                device_group=list(range(mdp_ts_num_devices)),
-                num_cores=compiler_options.get("aic_num_cores", 16),
-                mxfp6=compiler_options.get("mxfp6_matmul", False),
-                mxint8=mxint8_kv_cache,
-                qnn_config=qnn_config,
-            )
-
-            return self.qpc_path
-
         command = constants.COMPILER + [f"-m={onnx_path}"]
         if mdp_ts_json_path := compiler_options.pop("mdp_ts_json_path", None):
             mdp_ts_num_devices = None
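For illustration, a standalone sketch (not the actual `_compile` implementation) of the pop-and-override pattern in the context lines above: `mdp_ts_json_path` is pulled out of the generic compiler options so it suppresses the device-count based partitioning, while the remaining options follow the conversion rule documented elsewhere in this file (aic_num_cores=16 -> -aic-num-cores=16, convert_to_fp16=True -> -convert-to-fp16):

# Standalone sketch of the option handling shown above (assumed behaviour only).
def split_compiler_options(mdp_ts_num_devices=1, **compiler_options):
    # A user-supplied partition JSON takes precedence over the plain device count.
    if mdp_ts_json_path := compiler_options.pop("mdp_ts_json_path", None):
        mdp_ts_num_devices = None

    # Remaining options become qaic-exec style flags: foo_bar=16 -> -foo-bar=16,
    # boolean True -> bare flag.
    flags = []
    for key, value in compiler_options.items():
        option = "-" + key.replace("_", "-")
        flags.append(option if value is True else f"{option}={value}")
    return mdp_ts_json_path, mdp_ts_num_devices, flags


print(split_compiler_options(mdp_ts_num_devices=4,
                             mdp_ts_json_path="./mdp_ts.json",
                             aic_num_cores=16,
                             convert_to_fp16=True))
# ('./mdp_ts.json', None, ['-aic-num-cores=16', '-convert-to-fp16'])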

QEfficient/transformers/models/modeling_auto.py

Lines changed: 105 additions & 68 deletions
@@ -603,8 +603,8 @@ def compile(
         mxfp6_matmul: bool = False,
         mxint8_kv_cache: bool = False,
         num_speculative_tokens: Optional[int] = None,
-        skip_vision: Optional[bool] = False,
-        skip_lang: Optional[bool] = False,
+        enable_qnn: bool = False,
+        qnn_config: Optional[str] = None,
         **compiler_options,
     ) -> str:
         if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]):
@@ -643,19 +643,17 @@ def compile(
         ):
             self.export()

-        if not skip_vision:
-            self.vision_model._compile(
-                compile_dir,
-                compile_only=True,
-                specializations=specializations["vision"],
-                convert_to_fp16=True,
-                mxfp6_matmul=mxfp6_matmul,
-                mdp_ts_num_devices=num_devices,
-                aic_num_cores=num_cores,
-                custom_io=custom_io_vision,
-                mxint8_kv_cache=mxint8_kv_cache,
-                **compiler_options,
-            )
+        self.vision_model._compile(
+            compile_dir,
+            compile_only=True,
+            specializations=specializations["vision"],
+            convert_to_fp16=True,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=num_devices,
+            aic_num_cores=num_cores,
+            custom_io=custom_io_vision,
+            **compiler_options,
+        )

         if not skip_lang:
             custom_io_lang = {}
@@ -681,6 +679,7 @@ def compile(
                 custom_io=custom_io_lang,
                 **compiler_options,
             )
+        return self.qpc_path

     def generate(
         self,
@@ -1533,7 +1532,8 @@ def compile(
         mxfp6_matmul: bool = False,
         mxint8_kv_cache: bool = False,
         num_speculative_tokens: Optional[int] = None,
-        prefill_only: Optional[bool] = None,
+        enable_qnn: bool = False,
+        qnn_config: Optional[str] = None,
         **compiler_options,
     ) -> str:
         """
@@ -1555,14 +1555,8 @@ def compile(
             :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
             :mos (int, optional): Effort level to reduce on-chip memory. Defaults to -1, meaning no effort. ``Defaults to -1``.
             :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
-            :prefill_only (bool): if ``True`` compile for prefill only and if ``False`` compile for decode only. Defaults to None, which compiles for both ``prefill and ``decode``.
-            :compiler_options (dict, optional): Pass any compiler option as input. ``Defaults to None``.
-                Following flag can be passed in compiler_options to enable QNN Compilation path.
-                    :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False. if not passed.``
-                    :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None. if not passed``
-                for QAIC compilation path, any flag that is supported by ``qaic-exec`` can be passed. Params are converted to flags as below:
-                    - aic_num_cores=16 -> -aic-num-cores=16
-                    - convert_to_fp16=True -> -convert-to-fp16
+            :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
+            :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``

         Returns:
             :str: Path of the compiled ``qpc`` package.
@@ -1572,9 +1566,19 @@ def compile(
             raise TypeError("`prefill_only` must be a boolean.")

         if self.is_tlm:
-            num_speculative_tokens: int = self.check_and_get_num_speculative_tokens(
-                num_speculative_tokens, prefill_seq_len
-            )
+            # assert num_speculative_tokens cfg is acceptable if defined
+            if num_speculative_tokens is None:
+                raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.")
+            if not isinstance(num_speculative_tokens, int) and num_speculative_tokens < 2:
+                ValueError(
+                    f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}"
+                )
+            num_logits_to_keep = num_speculative_tokens + 1
+            if prefill_seq_len < num_logits_to_keep:
+                raise ValueError(
+                    f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})"
+                )
+
         if self.continuous_batching and full_batch_size is None:
             raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")

@@ -1584,50 +1588,83 @@ def compile(
                 "enable `continuous_batching=True` in `from_pretrained`."
             )

-        # Infer kv_cache_batch_size if not provided
-        kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
-
-        # --- Specializations ---
-        specializations = []
+        kv_cache_batch_size = (
+            kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
+        )
+        # Define prefill specialization
+        prefill_specialization = {
+            # Prefill is always run with single BS for continuous batching.
+            "batch_size": 1 if self.continuous_batching else batch_size,
+            "seq_len": prefill_seq_len,
+            "ctx_len": ctx_len,
+            # TODO: should be renamed to kv_cache_batch_size in specialization too
+        }
+        prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
+        if self.continuous_batching:
+            prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
+        else:
+            prefill_specialization.update({"batch_size": kv_cache_batch_size})
+        prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
+        specializations = [
+            prefill_specialization,
+        ]

-        if prefill_only is None or prefill_only or prefill_seq_len == 1:
-            specializations.append(
-                self.build_prefill_specialization(
-                    prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
-                )
+        # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
+        if prefill_seq_len != 1 or self.continuous_batching:
+            decode_specialization = {
+                "batch_size": full_batch_size if self.continuous_batching else batch_size,
+                "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
+                "ctx_len": ctx_len,
+            }
+            if self.continuous_batching:
+                decode_specialization.update({"full_batch_size": kv_cache_batch_size})
+            else:
+                decode_specialization.update({"batch_size": kv_cache_batch_size})
+            decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
+            specializations.append(decode_specialization)
+
+        if enable_qnn:
+            if compiler_options:
+                logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
+
+            qpc_path = self._qnn_compile(
+                onnx_path,
+                compile_dir,
+                specializations=specializations,
+                prefill_seq_len=prefill_seq_len,
+                ctx_len=ctx_len,
+                batch_size=batch_size,
+                full_batch_size=full_batch_size,
+                mdp_ts_num_devices=num_devices,
+                num_cores=num_cores,
+                mxfp6_matmul=mxfp6_matmul,
+                mxint8_kv_cache=mxint8_kv_cache,
+                qnn_config=qnn_config,
+                kv_cache_batch_size=kv_cache_batch_size,
             )
-        if prefill_only is None or not prefill_only:
-            decode_spec = self.build_decode_specialization(
-                prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
+        else:
+            # Custom IO
+            custom_io = {}
+            kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
+            for suffix in ["", "_RetainedState"]:
+                for i in range(self.num_layers):
+                    for kv in ["key", "value"]:
+                        custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
+
+            qpc_path = self._compile(
+                onnx_path,
+                compile_dir,
+                compile_only=True,
+                retained_state=True,
+                specializations=specializations,
+                convert_to_fp16=True,
+                mxfp6_matmul=mxfp6_matmul,
+                custom_io=custom_io,
+                mdp_ts_num_devices=num_devices,
+                num_speculative_tokens=num_speculative_tokens,
+                aic_num_cores=num_cores,
+                **compiler_options,
             )
-            if decode_spec:
-                specializations.append(decode_spec)
-
-        # --- Compilation ---
-        kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
-        custom_io = {}
-
-        for suffix in ["", "_RetainedState"]:
-            for i in range(self.num_layers):
-                for kv in ["key", "value"]:
-                    custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
-
-        qpc_path = self._compile(
-            onnx_path=onnx_path,
-            compile_dir=compile_dir,
-            compile_only=True,
-            retained_state=True,
-            specializations=specializations,
-            convert_to_fp16=True,
-            mxfp6_matmul=mxfp6_matmul,
-            custom_io=custom_io,
-            mdp_ts_num_devices=num_devices,
-            num_speculative_tokens=num_speculative_tokens,
-            aic_num_cores=num_cores,
-            mxint8_kv_cache=mxint8_kv_cache,
-            **compiler_options,
-        )
-
         return qpc_path

     # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
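To make the new specialization construction above concrete, here is a small standalone re-creation (assumed values; not the library code) of what the prefill and decode specializations look like for a plain compile with batch_size=1, prefill_seq_len=128, ctx_len=256, no continuous batching and no speculative decoding:

# Standalone re-creation of the added specialization logic with assumed values;
# illustration only, not the library code.
batch_size, prefill_seq_len, ctx_len = 1, 128, 256
continuous_batching, is_tlm = False, False
full_batch_size = None
kv_cache_batch_size = None
num_speculative_tokens = None

kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size

prefill_specialization = {
    "batch_size": 1 if continuous_batching else batch_size,
    "seq_len": prefill_seq_len,
    "ctx_len": ctx_len,
}
if continuous_batching:
    prefill_specialization["full_batch_size"] = kv_cache_batch_size
else:
    prefill_specialization["batch_size"] = kv_cache_batch_size
specializations = [prefill_specialization]

# Decode specialization is skipped only when prefill_seq_len == 1 without continuous
# batching, since it would duplicate the prefill specialization.
if prefill_seq_len != 1 or continuous_batching:
    decode_specialization = {
        "batch_size": full_batch_size if continuous_batching else batch_size,
        "seq_len": num_speculative_tokens + 1 if is_tlm else 1,
        "ctx_len": ctx_len,
    }
    if continuous_batching:
        decode_specialization["full_batch_size"] = kv_cache_batch_size
    else:
        decode_specialization["batch_size"] = kv_cache_batch_size
    specializations.append(decode_specialization)

print(specializations)
# [{'batch_size': 1, 'seq_len': 128, 'ctx_len': 256},
#  {'batch_size': 1, 'seq_len': 1, 'ctx_len': 256}]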
