@@ -603,8 +603,8 @@ def compile(
         mxfp6_matmul: bool = False,
         mxint8_kv_cache: bool = False,
         num_speculative_tokens: Optional[int] = None,
-        skip_vision: Optional[bool] = False,
-        skip_lang: Optional[bool] = False,
+        enable_qnn: bool = False,
+        qnn_config: Optional[str] = None,
         **compiler_options,
     ) -> str:
         if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]):
@@ -643,19 +643,17 @@ def compile(
         ):
             self.export()

-        if not skip_vision:
-            self.vision_model._compile(
-                compile_dir,
-                compile_only=True,
-                specializations=specializations["vision"],
-                convert_to_fp16=True,
-                mxfp6_matmul=mxfp6_matmul,
-                mdp_ts_num_devices=num_devices,
-                aic_num_cores=num_cores,
-                custom_io=custom_io_vision,
-                mxint8_kv_cache=mxint8_kv_cache,
-                **compiler_options,
-            )
+        self.vision_model._compile(
+            compile_dir,
+            compile_only=True,
+            specializations=specializations["vision"],
+            convert_to_fp16=True,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=num_devices,
+            aic_num_cores=num_cores,
+            custom_io=custom_io_vision,
+            **compiler_options,
+        )

         if not skip_lang:
             custom_io_lang = {}
@@ -681,6 +679,7 @@ def compile(
                 custom_io=custom_io_lang,
                 **compiler_options,
             )
+        return self.qpc_path

     def generate(
         self,
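With this change the multimodal compile path always builds the vision QPC and now returns self.qpc_path after the language model is compiled. A rough usage sketch follows; the QEFFAutoModelForImageTextToText wrapper, the placeholder model card, and the kv_offload flag are assumptions drawn from the surrounding QEfficient API rather than from this diff:

    from QEfficient import QEFFAutoModelForImageTextToText

    # Placeholder model card; kv_offload=True is assumed to select the dual-QPC (vision + language) path.
    model = QEFFAutoModelForImageTextToText.from_pretrained("<image-text-model-card>", kv_offload=True)

    # compile() exports if needed, builds both QPCs, and returns the language-model QPC path.
    qpc_path = model.compile(num_cores=16, num_devices=1, mxfp6_matmul=True)
    print(qpc_path)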
@@ -1533,7 +1532,8 @@ def compile(
         mxfp6_matmul: bool = False,
         mxint8_kv_cache: bool = False,
         num_speculative_tokens: Optional[int] = None,
-        prefill_only: Optional[bool] = None,
+        enable_qnn: bool = False,
+        qnn_config: Optional[str] = None,
         **compiler_options,
     ) -> str:
         """
@@ -1555,14 +1555,8 @@ def compile(
             :num_speculative_tokens (int, optional): Number of speculative tokens to take as input for Speculative Decoding Target Language Model.
             :mos (int, optional): Effort level to reduce on-chip memory. ``Defaults to -1``, meaning no effort.
             :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
-            :prefill_only (bool): If ``True``, compile for prefill only; if ``False``, compile for decode only. Defaults to None, which compiles for both ``prefill`` and ``decode``.
-            :compiler_options (dict, optional): Pass any compiler option as input. ``Defaults to None``.
-                The following flags can be passed in compiler_options to enable the QNN compilation path:
-                    :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False`` if not passed.
-                    :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None`` if not passed.
-                For the QAIC compilation path, any flag that is supported by ``qaic-exec`` can be passed. Params are converted to flags as below:
-                    - aic_num_cores=16 -> -aic-num-cores=16
-                    - convert_to_fp16=True -> -convert-to-fp16
+            :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False``.
+            :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None``.

         Returns:
             :str: Path of the compiled ``qpc`` package.
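The restored enable_qnn and qnn_config arguments route compilation through the QNN toolchain instead of the default qaic-exec path. A minimal sketch of such a call, assuming a QEFFAutoModelForCausalLM instance named model and a local qnn_config.json (both placeholders):

    # QNN path: extra **compiler_options are ignored here; QNN parameters belong in the config file.
    qpc_path = model.compile(
        prefill_seq_len=32,
        ctx_len=128,
        num_cores=16,
        mxfp6_matmul=True,
        enable_qnn=True,
        qnn_config="qnn_config.json",  # assumed to exist in the working directory
    )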
@@ -1572,9 +1566,19 @@ def compile(
             raise TypeError("`prefill_only` must be a boolean.")

         if self.is_tlm:
-            num_speculative_tokens: int = self.check_and_get_num_speculative_tokens(
-                num_speculative_tokens, prefill_seq_len
-            )
+            # assert num_speculative_tokens cfg is acceptable if defined
+            if num_speculative_tokens is None:
+                raise TypeError("missing required argument `num_speculative_tokens` as `is_tlm` is True.")
+            if not isinstance(num_speculative_tokens, int) or num_speculative_tokens < 2:
+                raise ValueError(
+                    f"`num_speculative_tokens` arg should be an integer greater than 1, got {num_speculative_tokens}"
+                )
+            num_logits_to_keep = num_speculative_tokens + 1
+            if prefill_seq_len < num_logits_to_keep:
+                raise ValueError(
+                    f"sequence length ({prefill_seq_len}) must be at least `num_speculative_tokens+1` ({num_logits_to_keep})"
+                )
+
         if self.continuous_batching and full_batch_size is None:
             raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")

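For target language models (is_tlm=True) the restored inline checks require num_speculative_tokens to be an integer of at least 2 and the prefill sequence length to accommodate num_speculative_tokens + 1 logits. A small worked example of the constraint (values are illustrative only):

    num_speculative_tokens = 3                          # accepted: integer greater than 1
    num_logits_to_keep = num_speculative_tokens + 1     # 4 logits kept per decode step
    prefill_seq_len = 32                                # valid, since 32 >= 4; a value of 2 would raise ValueError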
@@ -1584,50 +1588,83 @@ def compile(
15841588 "enable `continuous_batching=True` in `from_pretrained`."
15851589 )
15861590
1587- # Infer kv_cache_batch_size if not provided
1588- kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
1589-
1590- # --- Specializations ---
1591- specializations = []
1591+ kv_cache_batch_size = (
1592+ kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size )
1593+ )
1594+ # Define prefill specialization
1595+ prefill_specialization = {
1596+ # Prefill is always run with single BS for continuous batching.
1597+ "batch_size" : 1 if self .continuous_batching else batch_size ,
1598+ "seq_len" : prefill_seq_len ,
1599+ "ctx_len" : ctx_len ,
1600+ # TODO: should be renamed to kv_cache_batch_size in specialization too
1601+ }
1602+ prefill_specialization .update ({"num_logits_to_keep" : 1 }) if self .is_tlm else ...
1603+ if self .continuous_batching :
1604+ prefill_specialization .update ({"full_batch_size" : kv_cache_batch_size })
1605+ else :
1606+ prefill_specialization .update ({"batch_size" : kv_cache_batch_size })
1607+ prefill_specialization .update ({"full_batch_exec_size" : full_batch_size }) if full_batch_size else ...
1608+ specializations = [
1609+ prefill_specialization ,
1610+ ]
15921611
1593- if prefill_only is None or prefill_only or prefill_seq_len == 1 :
1594- specializations .append (
1595- self .build_prefill_specialization (
1596- prefill_seq_len , ctx_len , batch_size , kv_cache_batch_size , full_batch_size
1597- )
+        # Skip the decode specialization when not continuous batching and prefill_seq_len == 1, as it would repeat the prefill specialization
+        if prefill_seq_len != 1 or self.continuous_batching:
+            decode_specialization = {
+                "batch_size": full_batch_size if self.continuous_batching else batch_size,
+                "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
+                "ctx_len": ctx_len,
+            }
+            if self.continuous_batching:
+                decode_specialization.update({"full_batch_size": kv_cache_batch_size})
+            else:
+                decode_specialization.update({"batch_size": kv_cache_batch_size})
+            decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
+            specializations.append(decode_specialization)
+
+        if enable_qnn:
+            if compiler_options:
+                logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
+
+            qpc_path = self._qnn_compile(
+                onnx_path,
+                compile_dir,
+                specializations=specializations,
+                prefill_seq_len=prefill_seq_len,
+                ctx_len=ctx_len,
+                batch_size=batch_size,
+                full_batch_size=full_batch_size,
+                mdp_ts_num_devices=num_devices,
+                num_cores=num_cores,
+                mxfp6_matmul=mxfp6_matmul,
+                mxint8_kv_cache=mxint8_kv_cache,
+                qnn_config=qnn_config,
+                kv_cache_batch_size=kv_cache_batch_size,
             )
-        if prefill_only is None or not prefill_only:
-            decode_spec = self.build_decode_specialization(
-                prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
+        else:
+            # Custom IO
+            custom_io = {}
+            kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
+            for suffix in ["", "_RetainedState"]:
+                for i in range(self.num_layers):
+                    for kv in ["key", "value"]:
+                        custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
+
+            qpc_path = self._compile(
+                onnx_path,
+                compile_dir,
+                compile_only=True,
+                retained_state=True,
+                specializations=specializations,
+                convert_to_fp16=True,
+                mxfp6_matmul=mxfp6_matmul,
+                custom_io=custom_io,
+                mdp_ts_num_devices=num_devices,
+                num_speculative_tokens=num_speculative_tokens,
+                aic_num_cores=num_cores,
+                **compiler_options,
             )
-            if decode_spec:
-                specializations.append(decode_spec)
-
-        # --- Compilation ---
-        kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
-        custom_io = {}
-
-        for suffix in ["", "_RetainedState"]:
-            for i in range(self.num_layers):
-                for kv in ["key", "value"]:
-                    custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype
-
-        qpc_path = self._compile(
-            onnx_path=onnx_path,
-            compile_dir=compile_dir,
-            compile_only=True,
-            retained_state=True,
-            specializations=specializations,
-            convert_to_fp16=True,
-            mxfp6_matmul=mxfp6_matmul,
-            custom_io=custom_io,
-            mdp_ts_num_devices=num_devices,
-            num_speculative_tokens=num_speculative_tokens,
-            aic_num_cores=num_cores,
-            mxint8_kv_cache=mxint8_kv_cache,
-            **compiler_options,
-        )
-
         return qpc_path
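Taken together, the non-QNN branch hands _compile one or two specializations plus a custom_io map covering the retained KV cache. The sketch below mirrors that construction for a hypothetical 2-layer, non-TLM model compiled with batch_size=1, prefill_seq_len=32, ctx_len=128, continuous_batching=False, and mxint8_kv_cache=False; it only illustrates the resulting dictionaries, not the actual compiler invocation:

    specializations = [
        {"batch_size": 1, "seq_len": 32, "ctx_len": 128},  # prefill
        {"batch_size": 1, "seq_len": 1, "ctx_len": 128},   # decode
    ]

    custom_io = {}
    for suffix in ["", "_RetainedState"]:
        for i in range(2):  # num_layers assumed to be 2 for this sketch
            for kv in ["key", "value"]:
                custom_io[f"past_{kv}.{i}{suffix}"] = "float16"  # "mxint8" when mxint8_kv_cache=True

    # e.g. custom_io["past_key.0"] == "float16", and "past_value.1_RetainedState" is mapped the same way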

     # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate