@@ -605,6 +605,8 @@ def compile(
         num_speculative_tokens: Optional[int] = None,
         enable_qnn: bool = False,
         qnn_config: Optional[str] = None,
+        skip_vision: Optional[bool] = False,
+        skip_lang: Optional[bool] = False,
         **compiler_options,
     ) -> str:
         if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]):
@@ -643,17 +645,18 @@ def compile(
         ):
             self.export()

-        self.vision_model._compile(
-            compile_dir,
-            compile_only=True,
-            specializations=specializations["vision"],
-            convert_to_fp16=True,
-            mxfp6_matmul=mxfp6_matmul,
-            mdp_ts_num_devices=num_devices,
-            aic_num_cores=num_cores,
-            custom_io=custom_io_vision,
-            **compiler_options,
-        )
+        if not skip_vision:
+            self.vision_model._compile(
+                compile_dir,
+                compile_only=True,
+                specializations=specializations["vision"],
+                convert_to_fp16=True,
+                mxfp6_matmul=mxfp6_matmul,
+                mdp_ts_num_devices=num_devices,
+                aic_num_cores=num_cores,
+                custom_io=custom_io_vision,
+                **compiler_options,
+            )

         if not skip_lang:
             custom_io_lang = {}
@@ -667,18 +670,18 @@ def compile(
                 if output_name.endswith("_RetainedState"):
                     custom_io_lang[output_name] = kv_cache_dtype

-        self.lang_model._compile(
-            compile_dir,
-            compile_only=True,
-            retained_state=True,
-            specializations=specializations["lang"],
-            convert_to_fp16=True,
-            mxfp6_matmul=mxfp6_matmul,
-            mdp_ts_num_devices=num_devices,
-            aic_num_cores=num_cores,
-            custom_io=custom_io_lang,
-            **compiler_options,
-        )
+            self.lang_model._compile(
+                compile_dir,
+                compile_only=True,
+                retained_state=True,
+                specializations=specializations["lang"],
+                convert_to_fp16=True,
+                mxfp6_matmul=mxfp6_matmul,
+                mdp_ts_num_devices=num_devices,
+                aic_num_cores=num_cores,
+                custom_io=custom_io_lang,
+                **compiler_options,
+            )
         return self.qpc_path

     def generate(
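A minimal usage sketch for the new skip flags. The flags come from this diff; the class name `QEFFAutoModelForImageTextToText` and the checkpoint placeholder are assumptions based on the library's naming:

```python
# Hypothetical sketch: compile the two sub-models independently via the new
# skip_vision / skip_lang flags. Checkpoint name is a placeholder.
from QEfficient import QEFFAutoModelForImageTextToText

model = QEFFAutoModelForImageTextToText.from_pretrained("<vlm-checkpoint>")

# Compile only the language model, leaving any existing vision QPC untouched.
model.compile(num_cores=16, num_devices=1, skip_vision=True)

# Conversely, refresh just the vision encoder QPC.
model.compile(num_cores=16, num_devices=1, skip_lang=True)
```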
@@ -1534,6 +1537,7 @@ def compile(
         num_speculative_tokens: Optional[int] = None,
         enable_qnn: bool = False,
         qnn_config: Optional[str] = None,
+        prefill_only: Optional[bool] = None,
         **compiler_options,
     ) -> str:
         """
@@ -1557,6 +1561,8 @@ def compile(
             :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
             :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
             :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
+            :prefill_only (bool): If ``True``, compile for prefill only; if ``False``, compile for decode only. ``Defaults to None``, which compiles for both ``prefill`` and ``decode``.
+            :compiler_options (dict, optional): Any other options that ``qaic-exec`` takes. ``Defaults to None``.

         Returns:
             :str: Path of the compiled ``qpc`` package.
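A short sketch of how `prefill_only` might be exercised; the flag semantics are from the docstring above, while the class usage and checkpoint are illustrative:

```python
# Hypothetical sketch of the three prefill_only modes.
from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")

model.compile(num_cores=16, prefill_only=True)   # prefill specialization only
model.compile(num_cores=16, prefill_only=False)  # decode specialization only
model.compile(num_cores=16)                      # default None: both
```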
@@ -1588,48 +1594,33 @@ def compile(
                 "enable `continuous_batching=True` in `from_pretrained`."
             )

-        kv_cache_batch_size = (
-            kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
-        )
-        # Define prefill specialization
-        prefill_specialization = {
-            # Prefill is always run with single BS for continuous batching.
-            "batch_size": 1 if self.continuous_batching else batch_size,
-            "seq_len": prefill_seq_len,
-            "ctx_len": ctx_len,
-            # TODO: should be renamed to kv_cache_batch_size in specialization too
-        }
-        prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
-        if self.continuous_batching:
-            prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
-        else:
-            prefill_specialization.update({"batch_size": kv_cache_batch_size})
-        prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
-        specializations = [
-            prefill_specialization,
-        ]
+        # Infer kv_cache_batch_size if not provided
+        kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size

-        # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
-        if prefill_seq_len != 1 or self.continuous_batching:
-            decode_specialization = {
-                "batch_size": full_batch_size if self.continuous_batching else batch_size,
-                "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
-                "ctx_len": ctx_len,
-            }
-            if self.continuous_batching:
-                decode_specialization.update({"full_batch_size": kv_cache_batch_size})
-            else:
-                decode_specialization.update({"batch_size": kv_cache_batch_size})
-            decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
-            specializations.append(decode_specialization)
+        # --- Specializations ---
+        specializations = []
+
+        if prefill_only is None or prefill_only or prefill_seq_len == 1:
+            specializations.append(
+                self.build_prefill_specialization(
+                    prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
+                )
+            )
+        if prefill_only is None or not prefill_only:
+            decode_spec = self.build_decode_specialization(
+                prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
+            )
+            if decode_spec:
+                specializations.append(decode_spec)

+        # --- Compilation ---
         if enable_qnn:
             if compiler_options:
-                logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
+                logger.warning("Extra arguments to QNN compilation are ignored. Use `qnn_config.json`.")

             qpc_path = self._qnn_compile(
-                onnx_path,
-                compile_dir,
+                onnx_path=onnx_path,
+                compile_dir=compile_dir,
                 specializations=specializations,
                 prefill_seq_len=prefill_seq_len,
                 ctx_len=ctx_len,
@@ -1643,17 +1634,17 @@ def compile(
                 kv_cache_batch_size=kv_cache_batch_size,
             )
         else:
-            # Custom IO
-            custom_io = {}
             kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
+            custom_io = {}
+
             for suffix in ["", "_RetainedState"]:
                 for i in range(self.num_layers):
                     for kv in ["key", "value"]:
                         custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype

             qpc_path = self._compile(
-                onnx_path,
-                compile_dir,
+                onnx_path=onnx_path,
+                compile_dir=compile_dir,
                 compile_only=True,
                 retained_state=True,
                 specializations=specializations,
@@ -1665,6 +1656,7 @@ def compile(
                 aic_num_cores=num_cores,
                 **compiler_options,
             )
+
        return qpc_path

     # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
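The refactor above delegates specialization construction to `build_prefill_specialization` and `build_decode_specialization`, whose bodies are outside this diff. A sketch of what they plausibly contain, reconstructed from the inline logic the diff removes; the method names come from the call sites, the bodies are an assumption:

```python
# Assumed reconstruction of the new helper methods on the causal-LM model
# class, mirroring the inline dict-building that this diff removes.
from typing import Optional


def build_prefill_specialization(
    self, prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
) -> dict:
    spec = {
        # Prefill always runs with batch_size=1 under continuous batching.
        "batch_size": 1 if self.continuous_batching else batch_size,
        "seq_len": prefill_seq_len,
        "ctx_len": ctx_len,
    }
    if self.is_tlm:
        spec["num_logits_to_keep"] = 1
    if self.continuous_batching:
        spec["full_batch_size"] = kv_cache_batch_size
    else:
        spec["batch_size"] = kv_cache_batch_size
    if full_batch_size:
        spec["full_batch_exec_size"] = full_batch_size
    return spec


def build_decode_specialization(
    self, prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
) -> Optional[dict]:
    # Without continuous batching, a decode spec at prefill_seq_len == 1 would
    # duplicate the prefill spec, so return None and let the caller skip it.
    if prefill_seq_len == 1 and not self.continuous_batching:
        return None
    spec = {
        "batch_size": full_batch_size if self.continuous_batching else batch_size,
        "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1,
        "ctx_len": ctx_len,
    }
    if self.is_tlm:
        spec["num_logits_to_keep"] = num_speculative_tokens + 1
    if self.continuous_batching:
        spec["full_batch_size"] = kv_cache_batch_size
    else:
        spec["batch_size"] = kv_cache_batch_size
    return spec
```

Returning `None` from the decode builder, rather than branching at every call site, keeps the caller's `if decode_spec:` guard simple and preserves the duplicate-skip behavior documented in the removed comment.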
@@ -1829,6 +1821,10 @@ def compile(
         num_devices: int = 1,
         num_cores: int = 16,  # FIXME: Make this mandatory arg
         mxfp6_matmul: bool = False,
+        mxint8_kv_cache: bool = False,
+        num_speculative_tokens: Optional[int] = None,
+        enable_qnn: bool = False,
+        qnn_config: Optional[str] = None,
         **compiler_options,
     ) -> str:
         """
@@ -1851,7 +1847,27 @@ def compile(
         Returns:
             :str: Path of the compiled ``qpc`` package.
         """
-        specializations = self.model.get_specializations(batch_size, encoder_ctx_len, decoder_ctx_len, feature_len)
+        specializations, compiler_options = self.model.get_specializations(
+            batch_size,
+            encoder_ctx_len,
+            ctx_len,
+            **compiler_options,
+        )
+
+        if full_batch_size:
+            logger.warning("Continuous batching is not yet enabled for AutoModelForSpeechSeq2Seq")
+
+        if kv_cache_batch_size:
+            logger.warning("Prefix caching is not yet enabled for AutoModelForSpeechSeq2Seq")
+
+        if mxint8_kv_cache:
+            logger.warning("mxint8 cache is not yet enabled for AutoModelForSpeechSeq2Seq")
+
+        if num_speculative_tokens:
+            logger.warning("Speculative decoding is not yet enabled for AutoModelForSpeechSeq2Seq")
+
+        if enable_qnn or qnn_config:
+            logger.warning("QNN compile is not yet enabled for AutoModelForSpeechSeq2Seq")

         return self._compile(
             onnx_path,
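A brief usage sketch for the Speech Seq2Seq path: the new kwargs are accepted for signature parity with the causal-LM `compile` but currently only log warnings. Class and checkpoint names here are assumptions based on the library's naming:

```python
# Hypothetical sketch: unsupported options warn and are otherwise ignored.
from QEfficient import QEFFAutoModelForSpeechSeq2Seq

model = QEFFAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")

qpc_path = model.compile(
    num_cores=16,
    mxint8_kv_cache=True,  # logs: "mxint8 cache is not yet enabled ..."
    enable_qnn=True,       # logs: "QNN compile is not yet enabled ..."
)
```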