Skip to content

Commit aed7c50

Browse files
committed
Convert args to LlmConfig
[ghstack-poisoned]
1 parent f8fc412 commit aed7c50

File tree

1 file changed

+98
-2
lines changed

1 file changed

+98
-2
lines changed

examples/models/llama/config/llm_config_utils.py

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,16 @@
77

88
import argparse
99

10-
from executorch.examples.models.llama.config.llm_config import LlmConfig
10+
from executorch.examples.models.llama.config.llm_config import (
11+
CoreMLComputeUnit,
12+
CoreMLQuantize,
13+
DtypeOverride,
14+
LlmConfig,
15+
ModelType,
16+
PreqMode,
17+
Pt2eQuantize,
18+
SpinQuant,
19+
)
1120

1221

1322
def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
@@ -17,6 +26,93 @@ def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
1726
"""
1827
llm_config = LlmConfig()
1928

20-
# TODO: conversion code.
29+
# BaseConfig
30+
llm_config.base.model_class = ModelType(args.model)
31+
llm_config.base.params = args.params
32+
llm_config.base.checkpoint = args.checkpoint
33+
llm_config.base.checkpoint_dir = args.checkpoint_dir
34+
llm_config.base.tokenizer_path = args.tokenizer_path
35+
llm_config.base.metadata = args.metadata
36+
llm_config.base.use_lora = bool(args.use_lora)
37+
llm_config.base.fairseq2 = args.fairseq2
38+
39+
# PreqMode settings
40+
if args.preq_mode:
41+
llm_config.base.preq_mode = PreqMode(args.preq_mode)
42+
llm_config.base.preq_group_size = args.preq_group_size
43+
llm_config.base.preq_embedding_quantize = args.preq_embedding_quantize
44+
45+
# ModelConfig
46+
llm_config.model.dtype_override = DtypeOverride(args.dtype_override)
47+
llm_config.model.enable_dynamic_shape = args.enable_dynamic_shape
48+
llm_config.model.use_shared_embedding = args.use_shared_embedding
49+
llm_config.model.use_sdpa_with_kv_cache = args.use_sdpa_with_kv_cache
50+
llm_config.model.expand_rope_table = args.expand_rope_table
51+
llm_config.model.use_attention_sink = args.use_attention_sink
52+
llm_config.model.output_prune_map = args.output_prune_map
53+
llm_config.model.input_prune_map = args.input_prune_map
54+
llm_config.model.use_kv_cache = args.use_kv_cache
55+
llm_config.model.quantize_kv_cache = args.quantize_kv_cache
56+
llm_config.model.local_global_attention = args.local_global_attention
57+
58+
# ExportConfig
59+
llm_config.export.max_seq_length = args.max_seq_length
60+
llm_config.export.max_context_length = args.max_context_length
61+
llm_config.export.output_dir = args.output_dir
62+
llm_config.export.output_name = args.output_name
63+
llm_config.export.so_library = args.so_library
64+
llm_config.export.export_only = args.export_only
65+
66+
# QuantizationConfig
67+
llm_config.quantization.qmode = args.quantization_mode
68+
llm_config.quantization.embedding_quantize = args.embedding_quantize
69+
if args.pt2e_quantize:
70+
llm_config.quantization.pt2e_quantize = Pt2eQuantize(args.pt2e_quantize)
71+
llm_config.quantization.group_size = args.group_size
72+
if args.use_spin_quant:
73+
llm_config.quantization.use_spin_quant = SpinQuant(args.use_spin_quant)
74+
llm_config.quantization.use_qat = args.use_qat
75+
llm_config.quantization.calibration_tasks = args.calibration_tasks
76+
llm_config.quantization.calibration_limit = args.calibration_limit
77+
llm_config.quantization.calibration_seq_length = args.calibration_seq_length
78+
llm_config.quantization.calibration_data = args.calibration_data
79+
80+
# BackendConfig
81+
# XNNPack
82+
llm_config.backend.xnnpack.enabled = args.xnnpack
83+
llm_config.backend.xnnpack.extended_ops = args.xnnpack_extended_ops
84+
85+
# CoreML
86+
llm_config.backend.coreml.enabled = args.coreml
87+
llm_config.backend.coreml.enable_state = getattr(args, "coreml_enable_state", False)
88+
llm_config.backend.coreml.preserve_sdpa = getattr(
89+
args, "coreml_preserve_sdpa", False
90+
)
91+
if args.coreml_quantize:
92+
llm_config.backend.coreml.quantize = CoreMLQuantize(args.coreml_quantize)
93+
llm_config.backend.coreml.ios = args.coreml_ios
94+
llm_config.backend.coreml.compute_units = CoreMLComputeUnit(
95+
args.coreml_compute_units
96+
)
97+
98+
# Vulkan
99+
llm_config.backend.vulkan.enabled = args.vulkan
100+
101+
# QNN
102+
llm_config.backend.qnn.enabled = args.qnn
103+
llm_config.backend.qnn.use_sha = args.use_qnn_sha
104+
llm_config.backend.qnn.soc_model = args.soc_model
105+
llm_config.backend.qnn.optimized_rotation_path = args.optimized_rotation_path
106+
llm_config.backend.qnn.num_sharding = args.num_sharding
107+
108+
# MPS
109+
llm_config.backend.mps.enabled = args.mps
110+
111+
# DebugConfig
112+
llm_config.debug.profile_memory = args.profile_memory
113+
llm_config.debug.profile_path = args.profile_path
114+
llm_config.debug.generate_etrecord = args.generate_etrecord
115+
llm_config.debug.generate_full_logits = args.generate_full_logits
116+
llm_config.debug.verbose = args.verbose
21117

22118
return llm_config

0 commit comments

Comments
 (0)