@@ -579,49 +579,54 @@ def export_llama(
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
         args = export_options
-        llm_config = convert_args_to_llm_config(export_options)  # noqa: F841
+        llm_config = convert_args_to_llm_config(export_options)
     elif isinstance(export_options, DictConfig):
         # Hydra CLI.
-        llm_config = export_options  # noqa: F841
+        llm_config = export_options
+        # Create an args object for backward compatibility during transition
+        args = argparse.Namespace()
+        for key, value in llm_config.items():
+            setattr(args, key, value)
     else:
         raise ValueError(
             "Input to export_llama must be either of type argparse.Namespace or LlmConfig"
         )
 
-    # TODO: refactor rest of export_llama to use llm_config instead of args.
-
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:
-        repo_id = HUGGING_FACE_REPO_IDS[args.model]
-        if args.model == "qwen2_5":
+    model_name = llm_config.base.model_class
+    if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
+        repo_id = HUGGING_FACE_REPO_IDS[model_name]
+        if model_name == "qwen2_5":
             from executorch.examples.models.qwen2_5 import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model.startswith("qwen3"):
+        elif model_name.startswith("qwen3"):
             from executorch.examples.models.qwen3 import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model == "phi_4_mini":
+        elif model_name == "phi_4_mini":
             from executorch.examples.models.phi_4_mini import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model == "smollm2":
+        elif model_name == "smollm2":
             from executorch.examples.models.smollm2 import (  # pyre-ignore[21]
                 convert_weights,
             )
         else:
             raise ValueError(
-                f"Converting weights to meta format for {args.model} is not yet supported"
+                f"Converting weights to meta format for {model_name} is not yet supported"
             )
-        args.checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights)
+        checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights)
+        llm_config.base.checkpoint = checkpoint
+        args.checkpoint = checkpoint
 
-    if args.profile_path is not None:
+    if llm_config.debug.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph
 
-            with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(args)
+            with CProfilerFlameGraph(llm_config.debug.profile_path):
+                builder = _export_llama(llm_config, args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
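
Note on the hunk above: the Hydra path now mirrors each top-level group of the nested llm_config onto a plain argparse.Namespace so downstream code that still reads args keeps working during the transition. A minimal standalone sketch of that shim follows; the OmegaConf usage and the config values are illustrative assumptions, not part of this change.

```python
# Minimal sketch of the DictConfig -> Namespace shim above; values are made up.
from argparse import Namespace

from omegaconf import OmegaConf

llm_config = OmegaConf.create(
    {
        "base": {"model_class": "qwen2_5", "checkpoint": None},
        "debug": {"profile_path": None},
    }
)

args = Namespace()
for key, value in llm_config.items():
    # Each top-level group (base, debug, ...) becomes an attribute on args
    # (args.base, args.debug, ...), each still a nested DictConfig node.
    setattr(args, key, value)

assert args.base.model_class == "qwen2_5"
```
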
@@ -632,53 +637,53 @@ def export_llama(
             )
             return ""
     else:
-        builder = _export_llama(args)
+        builder = _export_llama(llm_config, args)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(args) -> LLMEdgeManager:
+def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
 
     Returns a LLMEdgeManager prior to calling export_to_edge with quantizers
     """
     # load model from checkpoint and params.json
-    checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
+    checkpoint_path = canonical_path(llm_config.base.checkpoint) if llm_config.base.checkpoint else None
     checkpoint_dir = (
-        canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
+        canonical_path(llm_config.base.checkpoint_dir) if llm_config.base.checkpoint_dir else None
     )
-    params_path = canonical_path(args.params) if args.params else None
-    output_dir_path = canonical_path(args.output_dir, dir=True)
-    weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
+    params_path = canonical_path(llm_config.base.params) if llm_config.base.params else None
+    output_dir_path = canonical_path(llm_config.export.output_dir, dir=True)
+    weight_type = WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA
 
-    # Convert dtype override string arg to actual type.
-    dtype_override = DType[args.dtype_override]
+    # Convert dtype override string to actual type
+    dtype_override = DType[llm_config.model.dtype_override]
 
     edge_manager = _load_llama_model(
-        args.model,
+        llm_config.base.model_class,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
-        use_kv_cache=args.use_kv_cache,
-        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-        generate_full_logits=args.generate_full_logits,
+        use_kv_cache=llm_config.model.use_kv_cache,
+        use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
+        generate_full_logits=llm_config.debug.generate_full_logits,
         weight_type=weight_type,
-        enable_dynamic_shape=args.enable_dynamic_shape,
-        calibration_tasks=args.calibration_tasks,
-        calibration_limit=args.calibration_limit,
-        calibration_seq_length=args.calibration_seq_length,
-        calibration_data=args.calibration_data,
-        tokenizer_path=args.tokenizer_path,
-        verbose=args.verbose,
-        max_seq_len=args.max_seq_length,
-        max_context_len=args.max_context_length,
-        input_prune_map_path=args.input_prune_map,
-        output_prune_map_path=args.output_prune_map,
-        metadata_str=args.metadata,
+        enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
+        calibration_tasks=llm_config.quantization.calibration_tasks,
+        calibration_limit=llm_config.quantization.calibration_limit,
+        calibration_seq_length=llm_config.quantization.calibration_seq_length,
+        calibration_data=llm_config.quantization.calibration_data,
+        tokenizer_path=llm_config.base.tokenizer_path,
+        verbose=llm_config.debug.verbose,
+        max_seq_len=llm_config.export.max_seq_length,
+        max_context_len=llm_config.export.max_context_length,
+        input_prune_map_path=llm_config.model.input_prune_map,
+        output_prune_map_path=llm_config.model.output_prune_map,
+        metadata_str=llm_config.base.metadata,
         dtype_override=dtype_override,
         args=args,
     )
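
The hunk above is where most of the flat args fields get remapped onto nested config groups. A rough sketch of the grouping, inferred only from the accesses in this diff (the defaults shown are made up for illustration):

```python
# Illustrative-only layout of the nested config these functions now read from.
from omegaconf import OmegaConf

llm_config = OmegaConf.create(
    {
        "base": {"checkpoint": None, "params": None, "tokenizer_path": None, "metadata": None},
        "model": {"dtype_override": "fp32", "use_kv_cache": True, "enable_dynamic_shape": True},
        "export": {"output_dir": ".", "max_seq_length": 128, "max_context_length": 128},
        "quantization": {"calibration_tasks": None, "calibration_limit": None},
        "debug": {"verbose": False, "generate_full_logits": False},
    }
)

# e.g. what used to be args.dtype_override is now read as:
print(llm_config.model.dtype_override)
```
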
@@ -710,63 +715,63 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
             dtype_override=dtype_override,
-            checkpoint=args.checkpoint,
+            checkpoint=llm_config.base.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            tokenizer_path=args.tokenizer_path,
-            use_spin_quant=args.use_spin_quant,
-            embedding_quantize=args.embedding_quantize,
-            use_shared_embedding=args.use_shared_embedding,
-            quantization_mode=args.quantization_mode,
-            group_size=args.group_size,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            expand_rope_table=args.expand_rope_table,
+            tokenizer_path=llm_config.base.tokenizer_path,
+            use_spin_quant=llm_config.quantization.use_spin_quant,
+            embedding_quantize=llm_config.quantization.embedding_quantize,
+            use_shared_embedding=llm_config.model.use_shared_embedding,
+            quantization_mode=llm_config.quantization.qmode,
+            group_size=llm_config.quantization.group_size,
+            calibration_tasks=llm_config.quantization.calibration_tasks,
+            calibration_limit=llm_config.quantization.calibration_limit,
+            calibration_seq_length=llm_config.quantization.calibration_seq_length,
+            expand_rope_table=llm_config.model.expand_rope_table,
             use_custom_sdpa_with_attention_mask=getattr(
-                args, "use_custom_sdpa_with_attention_mask", False
+                llm_config.model, "use_custom_sdpa_with_attention_mask", False
             ),
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            quantize_kv_cache=args.quantize_kv_cache,
-            use_kv_cache=args.use_kv_cache,
-            qnn=args.qnn,
-            use_qnn_sha=args.use_qnn_sha,
-            optimized_rotation_path=args.optimized_rotation_path,
-            mps=args.mps,
-            coreml=args.coreml,
-            coreml_ios=args.coreml_ios,
-            vulkan=args.vulkan,
-            use_qat=args.use_qat,
-            use_lora=args.use_lora,
-            preq_mode=args.preq_mode,
-            preq_group_size=args.preq_group_size,
-            preq_embedding_quantize=args.preq_embedding_quantize,
+            use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
+            quantize_kv_cache=llm_config.model.quantize_kv_cache,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            qnn=llm_config.backend.qnn.enabled,
+            use_qnn_sha=llm_config.backend.qnn.use_sha,
+            optimized_rotation_path=llm_config.backend.qnn.optimized_rotation_path,
+            mps=llm_config.backend.mps.enabled,
+            coreml=llm_config.backend.coreml.enabled,
+            coreml_ios=llm_config.backend.coreml.ios,
+            vulkan=llm_config.backend.vulkan.enabled,
+            use_qat=llm_config.quantization.use_qat,
+            use_lora=llm_config.base.use_lora,
+            preq_mode=llm_config.base.preq_mode,
+            preq_group_size=llm_config.base.preq_group_size,
+            preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
         )
     )
 
     return edge_manager
 
 
-def get_quantizer_and_quant_params(args):
+def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        args.pt2e_quantize, args.quantization_mode
+        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
     )
-    quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library)
+    quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
-    if args.qnn and args.pt2e_quantize:
+    if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            args.pt2e_quantize, args.quantization_mode
+            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
-    if args.coreml and args.pt2e_quantize:
+    if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
         quantizers.append(coreml_quantizer)
-    if args.vulkan and args.pt2e_quantize:
+    if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
@@ -789,28 +794,32 @@ def _qmode_type(value):
     )
 
 
-def _validate_args(args):
+def _validate_args(llm_config):
     """
     TODO: Combine all the backends under --backend args
     """
 
-    if args.max_context_length < args.max_seq_length:
+    if llm_config.export.max_context_length < llm_config.export.max_seq_length:
         raise ValueError(
-            f"max_context_length {args.max_context_length} must be >= max_seq_len {args.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
+            f"max_context_length {llm_config.export.max_context_length} must be >= max_seq_len {llm_config.export.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
         )
-    if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
+    if llm_config.model.enable_dynamic_shape and (
+        llm_config.backend.coreml.enabled or
+        llm_config.backend.mps.enabled or
+        llm_config.backend.qnn.enabled
+    ):
         raise ValueError(
             "Dynamic shape is not supported with coreml, MPS or qnn backends."
             " Please use --disable_dynamic_shape."
         )
 
-    if args.num_sharding > 0 and not args.qnn:
+    if llm_config.backend.qnn.num_sharding > 0 and not llm_config.backend.qnn.enabled:
         raise ValueError("Model shard is only supported with qnn backend now.")
 
-    if args.use_shared_embedding:
+    if llm_config.model.use_shared_embedding:
         if not (
-            args.embedding_quantize is not None
-            and args.embedding_quantize.startswith("torchao:")
+            llm_config.quantization.embedding_quantize is not None
+            and llm_config.quantization.embedding_quantize.startswith("torchao:")
         ):
             raise ValueError(
                 "Shared embedding is only supported with torchao quantization."
@@ -1038,38 +1047,39 @@ def _to_edge_and_lower_llama( # noqa: C901
     return builder
 
 
-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
+def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
+    _validate_args(llm_config)
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(llm_config)
 
     additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
+    builder_exported = _prepare_for_llama_export(llm_config, args).export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname
 
-    if args.export_only:
+    if llm_config.export.export_only:
         exit()
 
     if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
-        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        # Force xnnpack to be true if pt2e_quant_params is not None and xnnpack is False
+        llm_config.backend.xnnpack.enabled = True
         args.xnnpack = True
 
-    if args.xnnpack:
+    if llm_config.backend.xnnpack.enabled:
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
             additional_passes,
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            xnnpack_extended_ops=args.xnnpack_extended_ops,
-            generate_etrecord=args.generate_etrecord,
-            verbose=args.verbose,
+            xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
+            generate_etrecord=llm_config.debug.generate_etrecord,
+            verbose=llm_config.debug.verbose,
         )
     else:
         builder = _to_edge_and_lower_llama(
@@ -1079,33 +1089,33 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            vulkan=args.vulkan,
-            mps=args.mps,
-            coreml=args.coreml,
-            qnn=args.qnn,
-            dtype_override=args.dtype_override,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            use_kv_cache=args.use_kv_cache,
-            embedding_quantize=args.embedding_quantize,
-            pt2e_quantize=args.pt2e_quantize,
-            coreml_ios=args.coreml_ios,
-            coreml_quantize=args.coreml_quantize,
-            coreml_compute_units=args.coreml_compute_units,
-            use_qnn_sha=args.use_qnn_sha,
-            num_sharding=args.num_sharding,
-            soc_model=args.soc_model,
-            generate_etrecord=args.generate_etrecord,
-            verbose=args.verbose,
+            vulkan=llm_config.backend.vulkan.enabled,
+            mps=llm_config.backend.mps.enabled,
+            coreml=llm_config.backend.coreml.enabled,
+            qnn=llm_config.backend.qnn.enabled,
+            dtype_override=llm_config.model.dtype_override,
+            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            embedding_quantize=llm_config.quantization.embedding_quantize,
+            pt2e_quantize=llm_config.quantization.pt2e_quantize,
+            coreml_ios=llm_config.backend.coreml.ios_version,
+            coreml_quantize=llm_config.backend.coreml.quantize,
+            coreml_compute_units=llm_config.backend.coreml.compute_units,
+            use_qnn_sha=llm_config.backend.qnn.use_sha,
+            num_sharding=llm_config.backend.qnn.num_sharding,
+            soc_model=llm_config.backend.qnn.soc_model,
+            generate_etrecord=llm_config.debug.generate_etrecord,
+            verbose=llm_config.debug.verbose,
         )
 
-    if args.profile_memory:
+    if llm_config.debug.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
 
     if builder.dtype == DType.fp16:
         modelname = f"{modelname}_h"
 
-    if args.output_name:
-        modelname = args.output_name
+    if llm_config.export.output_name:
+        modelname = llm_config.export.output_name
     if modelname.endswith(".pte"):
         output_file = modelname
         modelname = modelname[:-4]