7
7
8
8
import argparse

from executorch.examples.models.llama.config.llm_config import (
    CoreMLComputeUnit,
    CoreMLQuantize,
    DtypeOverride,
    LlmConfig,
    ModelType,
    PreqMode,
    Pt2eQuantize,
    SpinQuant,
)
11
20
12
21
13
22
def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
    """Translate a parsed CLI ``Namespace`` into a structured ``LlmConfig``.

    Copies each recognized attribute of ``args`` onto the corresponding
    sub-config of a freshly constructed ``LlmConfig``, converting the
    string-valued choices (model class, dtype override, quantization
    modes, CoreML options) into their enum types.

    Args:
        args: Namespace produced by the export script's argument parser.
            Every attribute read here is expected to be present, except
            the two optional CoreML flags fetched via ``getattr`` with a
            ``False`` default.

    Returns:
        A fully populated ``LlmConfig``.
    """
    llm_config = LlmConfig()

    # Each helper fills in exactly one section of the config. They are
    # independent of one another and only mutate `llm_config` in place.
    _populate_base(llm_config, args)
    _populate_model(llm_config, args)
    _populate_export(llm_config, args)
    _populate_quantization(llm_config, args)
    _populate_backend(llm_config, args)
    _populate_debug(llm_config, args)

    return llm_config


def _populate_base(llm_config: LlmConfig, args: argparse.Namespace) -> None:
    """Fill BaseConfig: model identity, checkpoints, tokenizer, LoRA, preq."""
    llm_config.base.model_class = ModelType(args.model)
    llm_config.base.params = args.params
    llm_config.base.checkpoint = args.checkpoint
    llm_config.base.checkpoint_dir = args.checkpoint_dir
    llm_config.base.tokenizer_path = args.tokenizer_path
    llm_config.base.metadata = args.metadata
    llm_config.base.use_lora = bool(args.use_lora)
    llm_config.base.fairseq2 = args.fairseq2

    # Pre-quantized checkpoint settings are only meaningful when a preq
    # mode was requested; otherwise the LlmConfig defaults are kept.
    if args.preq_mode:
        llm_config.base.preq_mode = PreqMode(args.preq_mode)
        llm_config.base.preq_group_size = args.preq_group_size
        llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize


def _populate_model(llm_config: LlmConfig, args: argparse.Namespace) -> None:
    """Fill ModelConfig: dtype, cache/attention behavior, pruning maps."""
    llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
    llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
    llm_config.model.use_shared_embedding = args.use_shared_embedding
    llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
    llm_config.model.expand_rope_table = args.expand_rope_table
    llm_config.model.use_attention_sink = args.use_attention_sink
    llm_config.model.output_prune_map = args.output_prune_map
    llm_config.model.input_prune_map = args.input_prune_map
    llm_config.model.use_kv_cache = args.use_kv_cache
    llm_config.model.quantize_kv_cache = args.quantize_kv_cache
    llm_config.model.local_global_attention = args.local_global_attention


def _populate_export(llm_config: LlmConfig, args: argparse.Namespace) -> None:
    """Fill ExportConfig: sequence/context lengths and output artifacts."""
    llm_config.export.max_seq_length = args.max_seq_length
    llm_config.export.max_context_length = args.max_context_length
    llm_config.export.output_dir = args.output_dir
    llm_config.export.output_name = args.output_name
    llm_config.export.so_library = args.so_library
    llm_config.export.export_only = args.export_only


def _populate_quantization(
    llm_config: LlmConfig, args: argparse.Namespace
) -> None:
    """Fill QuantizationConfig: modes, group size, and calibration setup."""
    # NOTE: the CLI flag is `quantization_mode` but the config field is
    # `qmode` — keep these in sync with the parser definition.
    llm_config.quantization.qmode = args.quantization_mode
    llm_config.quantization.embedding_quantize = args.embedding_quantize
    if args.pt2e_quantize:
        llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
    llm_config.quantization.group_size = args.group_size
    if args.use_spin_quant:
        llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
    llm_config.quantization.use_qat = args.use_qat
    llm_config.quantization.calibration_tasks = args.calibration_tasks
    llm_config.quantization.calibration_limit = args.calibration_limit
    llm_config.quantization.calibration_seq_length = args.calibration_seq_length
    llm_config.quantization.calibration_data = args.calibration_data


def _populate_backend(llm_config: LlmConfig, args: argparse.Namespace) -> None:
    """Fill BackendConfig: XNNPack, CoreML, Vulkan, QNN, and MPS options."""
    # XNNPack
    llm_config.backend.xnnpack.enabled = args.xnnpack
    llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops

    # CoreML. The two getattr defaults cover parsers that do not define
    # these flags at all — TODO(review): confirm which entry points omit them.
    llm_config.backend.coreml.enabled = args.coreml
    llm_config.backend.coreml.enable_state = getattr(
        args, "coreml_enable_state", False
    )
    llm_config.backend.coreml.preserve_sdpa = getattr(
        args, "coreml_preserve_sdpa", False
    )
    if args.coreml_quantize:
        llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
    llm_config.backend.coreml.ios = args.coreml_ios
    llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
        args.coreml_compute_units
    )

    # Vulkan
    llm_config.backend.vulkan.enabled = args.vulkan

    # QNN
    llm_config.backend.qnn.enabled = args.qnn
    llm_config.backend.qnn.use_sha = args.use_qnn_sha
    llm_config.backend.qnn.soc_model = args.soc_model
    llm_config.backend.qnn.optimized_rotation_path = args.optimized_rotation_path
    llm_config.backend.qnn.num_sharding = args.num_sharding

    # MPS
    llm_config.backend.mps.enabled = args.mps


def _populate_debug(llm_config: LlmConfig, args: argparse.Namespace) -> None:
    """Fill DebugConfig: profiling, ETRecord, logits, and verbosity."""
    llm_config.debug.profile_memory = args.profile_memory
    llm_config.debug.profile_path = args.profile_path
    llm_config.debug.generate_etrecord = args.generate_etrecord
    llm_config.debug.generate_full_logits = args.generate_full_logits
    llm_config.debug.verbose = args.verbose
0 commit comments