@@ -605,6 +605,8 @@ def compile(
         num_speculative_tokens: Optional[int] = None,
         enable_qnn: bool = False,
         qnn_config: Optional[str] = None,
+        skip_vision: Optional[bool] = False,
+        skip_lang: Optional[bool] = False,
         **compiler_options,
     ) -> str:
         if any(param is not None for param in [full_batch_size, kv_cache_batch_size, num_speculative_tokens]):
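A usage sketch of the new `skip_vision` / `skip_lang` flags; the class, model card, and `kv_offload` argument below are illustrative assumptions, not taken from this diff:

```python
# Sketch only: assumes QEfficient's dual-QPC multimodal wrapper;
# the model card and from_pretrained arguments are illustrative.
from QEfficient import QEFFAutoModelForImageTextToText

model = QEFFAutoModelForImageTextToText.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",  # hypothetical model card
    kv_offload=True,  # assumption: selects the two-QPC (vision + lang) path
)

# Recompile only the language QPC; the vision encoder compile is skipped.
qpc_path = model.compile(num_cores=16, num_devices=1, skip_vision=True)
```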
@@ -646,17 +648,18 @@ def compile(
         ):
             self.export()

-        self.vision_model._compile(
-            compile_dir,
-            compile_only=True,
-            specializations=specializations["vision"],
-            convert_to_fp16=True,
-            mxfp6_matmul=mxfp6_matmul,
-            mdp_ts_num_devices=num_devices,
-            aic_num_cores=num_cores,
-            custom_io=custom_io_vision,
-            **compiler_options,
-        )
+        if not skip_vision:
+            self.vision_model._compile(
+                compile_dir,
+                compile_only=True,
+                specializations=specializations["vision"],
+                convert_to_fp16=True,
+                mxfp6_matmul=mxfp6_matmul,
+                mdp_ts_num_devices=num_devices,
+                aic_num_cores=num_cores,
+                custom_io=custom_io_vision,
+                **compiler_options,
+            )

         custom_io_lang = {}
         # Inputs
@@ -669,18 +672,18 @@ def compile(
             if output_name.endswith("_RetainedState"):
                 custom_io_lang[output_name] = kv_cache_dtype

-        self.lang_model._compile(
-            compile_dir,
-            compile_only=True,
-            retained_state=True,
-            specializations=specializations["lang"],
-            convert_to_fp16=True,
-            mxfp6_matmul=mxfp6_matmul,
-            mdp_ts_num_devices=num_devices,
-            aic_num_cores=num_cores,
-            custom_io=custom_io_lang,
-            **compiler_options,
-        )
+        self.lang_model._compile(
+            compile_dir,
+            compile_only=True,
+            retained_state=True,
+            specializations=specializations["lang"],
+            convert_to_fp16=True,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=num_devices,
+            aic_num_cores=num_cores,
+            custom_io=custom_io_lang,
+            **compiler_options,
+        )
         return self.qpc_path

     def generate(
@@ -1539,6 +1542,7 @@ def compile(
         num_speculative_tokens: Optional[int] = None,
         enable_qnn: bool = False,
         qnn_config: Optional[str] = None,
+        prefill_only: Optional[bool] = None,
         **compiler_options,
     ) -> str:
         """
@@ -1562,6 +1566,8 @@ def compile(
         :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
         :enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
         :qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
+        :prefill_only (bool, optional): If ``True``, compile for prefill only; if ``False``, compile for decode only. ``Defaults to None``, which compiles for both prefill and decode.
+        :compiler_options (dict, optional): Any other options that ``qaic-exec`` takes. ``Defaults to None``.

         Returns:
             :str: Path of the compiled ``qpc`` package.
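A usage sketch for the new `prefill_only` flag (class import, model card, and compile arguments are assumptions for illustration):

```python
# Sketch only: model card and argument values are illustrative.
from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")

# prefill_only=True -> only the prefill specialization is compiled.
prefill_qpc = model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16, prefill_only=True)

# prefill_only=False -> only the decode specialization is compiled.
decode_qpc = model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16, prefill_only=False)

# Default (None) keeps the old behavior: both specializations in one QPC.
qpc = model.compile(prefill_seq_len=128, ctx_len=256, num_cores=16)
```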
@@ -1583,48 +1589,33 @@ def compile(
                 "enable `continuous_batching=True` in `from_pretrained`."
             )

-        kv_cache_batch_size = (
-            kv_cache_batch_size if kv_cache_batch_size else (full_batch_size if full_batch_size else batch_size)
-        )
-        # Define prefill specialization
-        prefill_specialization = {
-            # Prefill is always run with single BS for continuous batching.
-            "batch_size": 1 if self.continuous_batching else batch_size,
-            "seq_len": prefill_seq_len,
-            "ctx_len": ctx_len,
-            # TODO: should be renamed to kv_cache_batch_size in specialization too
-        }
-        prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
-        if self.continuous_batching:
-            prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
-        else:
-            prefill_specialization.update({"batch_size": kv_cache_batch_size})
-        prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
-        specializations = [
-            prefill_specialization,
-        ]
+        # Infer kv_cache_batch_size if not provided
+        kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size

-        # Skip decode specialization if we are not in continuous batching and prefill_seq_len=1 as this repeats prefill specialization
-        if prefill_seq_len != 1 or self.continuous_batching:
-            decode_specialization = {
-                "batch_size": full_batch_size if self.continuous_batching else batch_size,
-                "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
-                "ctx_len": ctx_len,
-            }
-            if self.continuous_batching:
-                decode_specialization.update({"full_batch_size": kv_cache_batch_size})
-            else:
-                decode_specialization.update({"batch_size": kv_cache_batch_size})
-            decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
-            specializations.append(decode_specialization)
+        # --- Specializations ---
+        specializations = []
+
+        if prefill_only is None or prefill_only or prefill_seq_len == 1:
+            specializations.append(
+                self.build_prefill_specialization(
+                    prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
+                )
+            )
+        if prefill_only is None or not prefill_only:
+            decode_spec = self.build_decode_specialization(
+                prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
+            )
+            if decode_spec:
+                specializations.append(decode_spec)

+        # --- Compilation ---
         if enable_qnn:
             if compiler_options:
-                logger.warning("Extra arguments to QNN compilation are supported via qnn_config.json only")
+                logger.warning("Extra arguments to QNN compilation are ignored. Use `qnn_config.json`.")

             qpc_path = self._qnn_compile(
-                onnx_path,
-                compile_dir,
+                onnx_path=onnx_path,
+                compile_dir=compile_dir,
                 specializations=specializations,
                 prefill_seq_len=prefill_seq_len,
                 ctx_len=ctx_len,
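The `build_prefill_specialization` / `build_decode_specialization` helpers are defined outside this hunk. A hypothetical reconstruction, inferred only from the inline logic this diff removes, might look like:

```python
# Sketch only: names match the calls above, but bodies are inferred
# from the removed inline specialization-building code.
def build_prefill_specialization(
    self, prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
):
    spec = {
        # Prefill always runs with batch size 1 under continuous batching.
        "batch_size": 1 if self.continuous_batching else batch_size,
        "seq_len": prefill_seq_len,
        "ctx_len": ctx_len,
    }
    if self.is_tlm:
        spec["num_logits_to_keep"] = 1
    if self.continuous_batching:
        spec["full_batch_size"] = kv_cache_batch_size
    else:
        spec["batch_size"] = kv_cache_batch_size
    if full_batch_size:
        spec["full_batch_exec_size"] = full_batch_size
    return spec

def build_decode_specialization(
    self, prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
):
    # A decode spec with seq_len 1 and no continuous batching would just
    # repeat the prefill spec, so signal the caller to skip it.
    if prefill_seq_len == 1 and not self.continuous_batching:
        return None
    spec = {
        "batch_size": full_batch_size if self.continuous_batching else batch_size,
        "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1,
        "ctx_len": ctx_len,
    }
    if self.continuous_batching:
        spec["full_batch_size"] = kv_cache_batch_size
    else:
        spec["batch_size"] = kv_cache_batch_size
    if self.is_tlm:
        spec["num_logits_to_keep"] = num_speculative_tokens + 1
    return spec
```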
@@ -1638,17 +1629,17 @@ def compile(
                 kv_cache_batch_size=kv_cache_batch_size,
             )
         else:
-            # Custom IO
-            custom_io = {}
             kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
+            custom_io = {}
+
             for suffix in ["", "_RetainedState"]:
                 for i in range(self.num_layers):
                     for kv in ["key", "value"]:
                         custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype

             qpc_path = self._compile(
-                onnx_path,
-                compile_dir,
+                onnx_path=onnx_path,
+                compile_dir=compile_dir,
                 compile_only=True,
                 retained_state=True,
                 specializations=specializations,
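For reference, the custom-IO loop above produces a name-to-dtype map like the following (illustrative, for a hypothetical model with `num_layers == 2` and `mxint8_kv_cache=True`):

```python
# Illustrative output of the custom_io loop for a 2-layer model:
{
    "past_key.0": "mxint8",
    "past_value.0": "mxint8",
    "past_key.1": "mxint8",
    "past_value.1": "mxint8",
    "past_key.0_RetainedState": "mxint8",
    "past_value.0_RetainedState": "mxint8",
    "past_key.1_RetainedState": "mxint8",
    "past_value.1_RetainedState": "mxint8",
}
```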
@@ -1660,6 +1651,7 @@ def compile(
                 aic_num_cores=num_cores,
                 **compiler_options,
             )
+
         return qpc_path

     # FIXME: Update this method to match with transformers AutoModelForCausalLM.generate
@@ -1867,22 +1859,8 @@ def compile(
         if num_speculative_tokens:
             logger.warning("Speculative decoding is not yet enabled for AutoModelForSpeechSeq2Seq")

-        output_names = self.model.get_output_names()
-
-        kv_cache_dtype = "float16"
-        custom_io = {}
-
-        custom_io["input_features"] = kv_cache_dtype
-
-        # Slice output_names to get input names
-        for output_name in output_names:
-            if output_name.endswith("_RetainedState"):
-                custom_io[output_name[: -len("_RetainedState")]] = kv_cache_dtype
-
-        # Get output names
-        for output_name in output_names:
-            if output_name.endswith("_RetainedState"):
-                custom_io[output_name] = kv_cache_dtype
+        if enable_qnn or qnn_config:
+            logger.warning("QNN compile is not yet enabled for AutoModelForSpeechSeq2Seq")

         return self._compile(
             onnx_path,
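Finally, a usage sketch for the speech path, where QNN compilation now only warns (import path and model card are assumptions):

```python
# Sketch only: the import path and model card are illustrative.
from QEfficient import QEFFAutoModelForSpeechSeq2Seq

model = QEFFAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny")

# enable_qnn is accepted but currently logs a warning for this model class;
# compilation proceeds down the regular AIC path.
qpc_path = model.compile(num_cores=16, enable_qnn=True)
```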