
Commit 6078b91

refactor: Use llm_config instead of args in export_llama functions
ghstack-source-id: a0e5237
Pull Request resolved: #11084
1 parent: b02cdce
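The commit keeps two entry points into export_llama in sync: the legacy argparse path and the Hydra path. A minimal sketch of both call paths, assuming the import path mirrors the changed file; the fields shown are only a small illustrative subset taken from this diff, and a real export needs the full option set:

    import argparse
    from omegaconf import OmegaConf

    # Legacy CLI path: the Namespace is converted internally via convert_args_to_llm_config().
    legacy_args = argparse.Namespace(model="llama3", checkpoint=None)  # illustrative subset only

    # Hydra path: the DictConfig is used as the LlmConfig directly, and a
    # backward-compatibility Namespace is synthesized from its top-level keys.
    hydra_cfg = OmegaConf.create({"base": {"model_class": "llama3", "checkpoint": None}})

    # from executorch.examples.models.llama.export_llama_lib import export_llama  # assumed module path
    # export_llama(legacy_args)   # or: export_llama(hydra_cfg)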

File tree: 1 file changed (+125, -115 lines)

examples/models/llama/export_llama_lib.py

Lines changed: 125 additions & 115 deletions
@@ -579,49 +579,54 @@ def export_llama(
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
         args = export_options
-        llm_config = convert_args_to_llm_config(export_options)  # noqa: F841
+        llm_config = convert_args_to_llm_config(export_options)
     elif isinstance(export_options, DictConfig):
         # Hydra CLI.
-        llm_config = export_options  # noqa: F841
+        llm_config = export_options
+        # Create an args object for backward compatibility during transition
+        args = argparse.Namespace()
+        for key, value in llm_config.items():
+            setattr(args, key, value)
     else:
         raise ValueError(
             "Input to export_llama must be either of type argparse.Namespace or LlmConfig"
         )
 
-    # TODO: refactor rest of export_llama to use llm_config instead of args.
-
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:
-        repo_id = HUGGING_FACE_REPO_IDS[args.model]
-        if args.model == "qwen2_5":
+    model_name = llm_config.base.model_class
+    if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
+        repo_id = HUGGING_FACE_REPO_IDS[model_name]
+        if model_name == "qwen2_5":
             from executorch.examples.models.qwen2_5 import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model.startswith("qwen3"):
+        elif model_name.startswith("qwen3"):
             from executorch.examples.models.qwen3 import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model == "phi_4_mini":
+        elif model_name == "phi_4_mini":
             from executorch.examples.models.phi_4_mini import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model == "smollm2":
+        elif model_name == "smollm2":
             from executorch.examples.models.smollm2 import (  # pyre-ignore[21]
                 convert_weights,
             )
         else:
             raise ValueError(
-                f"Converting weights to meta format for {args.model} is not yet supported"
+                f"Converting weights to meta format for {model_name} is not yet supported"
             )
-        args.checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights)
+        checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights)
+        llm_config.base.checkpoint = checkpoint
+        args.checkpoint = checkpoint
 
-    if args.profile_path is not None:
+    if llm_config.debug.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph
 
-            with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(args)
+            with CProfilerFlameGraph(llm_config.debug.profile_path):
+                builder = _export_llama(llm_config, args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -632,53 +637,53 @@ def export_llama(
             )
             return ""
     else:
-        builder = _export_llama(args)
+        builder = _export_llama(llm_config, args)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(args) -> LLMEdgeManager:
+def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
 
     Returns a LLMEdgeManager prior to calling export_to_edge with quantizers
     """
     # load model from checkpoint and params.json
-    checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
+    checkpoint_path = canonical_path(llm_config.base.checkpoint) if llm_config.base.checkpoint else None
     checkpoint_dir = (
-        canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
+        canonical_path(llm_config.base.checkpoint_dir) if llm_config.base.checkpoint_dir else None
     )
-    params_path = canonical_path(args.params) if args.params else None
-    output_dir_path = canonical_path(args.output_dir, dir=True)
-    weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
+    params_path = canonical_path(llm_config.base.params) if llm_config.base.params else None
+    output_dir_path = canonical_path(llm_config.export.output_dir, dir=True)
+    weight_type = WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA
 
-    # Convert dtype override string arg to actual type.
-    dtype_override = DType[args.dtype_override]
+    # Convert dtype override string to actual type
+    dtype_override = DType[llm_config.model.dtype_override]
 
     edge_manager = _load_llama_model(
-        args.model,
+        llm_config.base.model_class,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
-        use_kv_cache=args.use_kv_cache,
-        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-        generate_full_logits=args.generate_full_logits,
+        use_kv_cache=llm_config.model.use_kv_cache,
+        use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
+        generate_full_logits=llm_config.debug.generate_full_logits,
         weight_type=weight_type,
-        enable_dynamic_shape=args.enable_dynamic_shape,
-        calibration_tasks=args.calibration_tasks,
-        calibration_limit=args.calibration_limit,
-        calibration_seq_length=args.calibration_seq_length,
-        calibration_data=args.calibration_data,
-        tokenizer_path=args.tokenizer_path,
-        verbose=args.verbose,
-        max_seq_len=args.max_seq_length,
-        max_context_len=args.max_context_length,
-        input_prune_map_path=args.input_prune_map,
-        output_prune_map_path=args.output_prune_map,
-        metadata_str=args.metadata,
+        enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
+        calibration_tasks=llm_config.quantization.calibration_tasks,
+        calibration_limit=llm_config.quantization.calibration_limit,
+        calibration_seq_length=llm_config.quantization.calibration_seq_length,
+        calibration_data=llm_config.quantization.calibration_data,
+        tokenizer_path=llm_config.base.tokenizer_path,
+        verbose=llm_config.debug.verbose,
+        max_seq_len=llm_config.export.max_seq_length,
+        max_context_len=llm_config.export.max_context_length,
+        input_prune_map_path=llm_config.model.input_prune_map,
+        output_prune_map_path=llm_config.model.output_prune_map,
+        metadata_str=llm_config.base.metadata,
         dtype_override=dtype_override,
         args=args,
     )
@@ -710,63 +715,63 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
             dtype_override=dtype_override,
-            checkpoint=args.checkpoint,
+            checkpoint=llm_config.base.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            tokenizer_path=args.tokenizer_path,
-            use_spin_quant=args.use_spin_quant,
-            embedding_quantize=args.embedding_quantize,
-            use_shared_embedding=args.use_shared_embedding,
-            quantization_mode=args.quantization_mode,
-            group_size=args.group_size,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            expand_rope_table=args.expand_rope_table,
+            tokenizer_path=llm_config.base.tokenizer_path,
+            use_spin_quant=llm_config.quantization.use_spin_quant,
+            embedding_quantize=llm_config.quantization.embedding_quantize,
+            use_shared_embedding=llm_config.model.use_shared_embedding,
+            quantization_mode=llm_config.quantization.qmode,
+            group_size=llm_config.quantization.group_size,
+            calibration_tasks=llm_config.quantization.calibration_tasks,
+            calibration_limit=llm_config.quantization.calibration_limit,
+            calibration_seq_length=llm_config.quantization.calibration_seq_length,
+            expand_rope_table=llm_config.model.expand_rope_table,
             use_custom_sdpa_with_attention_mask=getattr(
-                args, "use_custom_sdpa_with_attention_mask", False
+                llm_config.model, "use_custom_sdpa_with_attention_mask", False
             ),
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            quantize_kv_cache=args.quantize_kv_cache,
-            use_kv_cache=args.use_kv_cache,
-            qnn=args.qnn,
-            use_qnn_sha=args.use_qnn_sha,
-            optimized_rotation_path=args.optimized_rotation_path,
-            mps=args.mps,
-            coreml=args.coreml,
-            coreml_ios=args.coreml_ios,
-            vulkan=args.vulkan,
-            use_qat=args.use_qat,
-            use_lora=args.use_lora,
-            preq_mode=args.preq_mode,
-            preq_group_size=args.preq_group_size,
-            preq_embedding_quantize=args.preq_embedding_quantize,
+            use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
+            quantize_kv_cache=llm_config.model.quantize_kv_cache,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            qnn=llm_config.backend.qnn.enabled,
+            use_qnn_sha=llm_config.backend.qnn.use_sha,
+            optimized_rotation_path=llm_config.backend.qnn.optimized_rotation_path,
+            mps=llm_config.backend.mps.enabled,
+            coreml=llm_config.backend.coreml.enabled,
+            coreml_ios=llm_config.backend.coreml.ios,
+            vulkan=llm_config.backend.vulkan.enabled,
+            use_qat=llm_config.quantization.use_qat,
+            use_lora=llm_config.base.use_lora,
+            preq_mode=llm_config.base.preq_mode,
+            preq_group_size=llm_config.base.preq_group_size,
+            preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
         )
     )
 
     return edge_manager
 
 
-def get_quantizer_and_quant_params(args):
+def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        args.pt2e_quantize, args.quantization_mode
+        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
     )
-    quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library)
+    quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
-    if args.qnn and args.pt2e_quantize:
+    if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            args.pt2e_quantize, args.quantization_mode
+            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
-    if args.coreml and args.pt2e_quantize:
+    if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
         quantizers.append(coreml_quantizer)
-    if args.vulkan and args.pt2e_quantize:
+    if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
@@ -789,28 +794,32 @@ def _qmode_type(value):
     )
 
 
-def _validate_args(args):
+def _validate_args(llm_config):
     """
     TODO: Combine all the backends under --backend args
     """
 
-    if args.max_context_length < args.max_seq_length:
+    if llm_config.export.max_context_length < llm_config.export.max_seq_length:
         raise ValueError(
-            f"max_context_length {args.max_context_length} must be >= max_seq_len {args.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
+            f"max_context_length {llm_config.export.max_context_length} must be >= max_seq_len {llm_config.export.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
         )
-    if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
+    if llm_config.model.enable_dynamic_shape and (
+        llm_config.backend.coreml.enabled or
+        llm_config.backend.mps.enabled or
+        llm_config.backend.qnn.enabled
+    ):
         raise ValueError(
             "Dynamic shape is not supported with coreml, MPS or qnn backends."
             " Please use --disable_dynamic_shape."
         )
 
-    if args.num_sharding > 0 and not args.qnn:
+    if llm_config.backend.qnn.num_sharding > 0 and not llm_config.backend.qnn.enabled:
         raise ValueError("Model shard is only supported with qnn backend now.")
 
-    if args.use_shared_embedding:
+    if llm_config.model.use_shared_embedding:
         if not (
-            args.embedding_quantize is not None
-            and args.embedding_quantize.startswith("torchao:")
+            llm_config.quantization.embedding_quantize is not None
+            and llm_config.quantization.embedding_quantize.startswith("torchao:")
         ):
             raise ValueError(
                 "Shared embedding is only supported with torchao quantization."
@@ -1038,38 +1047,39 @@ def _to_edge_and_lower_llama(  # noqa: C901
     return builder
 
 
-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
+def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
+    _validate_args(llm_config)
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(llm_config)
 
     additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
+    builder_exported = _prepare_for_llama_export(llm_config, args).export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname
 
-    if args.export_only:
+    if llm_config.export.export_only:
         exit()
 
     if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
-        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        # Force xnnpack to be true if pt2e_quant_params is not None and xnnpack is False
+        llm_config.backend.xnnpack.enabled = True
         args.xnnpack = True
 
-    if args.xnnpack:
+    if llm_config.backend.xnnpack.enabled:
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
             additional_passes,
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            xnnpack_extended_ops=args.xnnpack_extended_ops,
-            generate_etrecord=args.generate_etrecord,
-            verbose=args.verbose,
+            xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
+            generate_etrecord=llm_config.debug.generate_etrecord,
+            verbose=llm_config.debug.verbose,
         )
     else:
         builder = _to_edge_and_lower_llama(
@@ -1079,33 +1089,33 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            vulkan=args.vulkan,
-            mps=args.mps,
-            coreml=args.coreml,
-            qnn=args.qnn,
-            dtype_override=args.dtype_override,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            use_kv_cache=args.use_kv_cache,
-            embedding_quantize=args.embedding_quantize,
-            pt2e_quantize=args.pt2e_quantize,
-            coreml_ios=args.coreml_ios,
-            coreml_quantize=args.coreml_quantize,
-            coreml_compute_units=args.coreml_compute_units,
-            use_qnn_sha=args.use_qnn_sha,
-            num_sharding=args.num_sharding,
-            soc_model=args.soc_model,
-            generate_etrecord=args.generate_etrecord,
-            verbose=args.verbose,
+            vulkan=llm_config.backend.vulkan.enabled,
+            mps=llm_config.backend.mps.enabled,
+            coreml=llm_config.backend.coreml.enabled,
+            qnn=llm_config.backend.qnn.enabled,
+            dtype_override=llm_config.model.dtype_override,
+            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            embedding_quantize=llm_config.quantization.embedding_quantize,
+            pt2e_quantize=llm_config.quantization.pt2e_quantize,
+            coreml_ios=llm_config.backend.coreml.ios_version,
+            coreml_quantize=llm_config.backend.coreml.quantize,
+            coreml_compute_units=llm_config.backend.coreml.compute_units,
+            use_qnn_sha=llm_config.backend.qnn.use_sha,
+            num_sharding=llm_config.backend.qnn.num_sharding,
+            soc_model=llm_config.backend.qnn.soc_model,
+            generate_etrecord=llm_config.debug.generate_etrecord,
+            verbose=llm_config.debug.verbose,
         )
 
-    if args.profile_memory:
+    if llm_config.debug.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
 
     if builder.dtype == DType.fp16:
         modelname = f"{modelname}_h"
 
-    if args.output_name:
-        modelname = args.output_name
+    if llm_config.export.output_name:
+        modelname = llm_config.export.output_name
     if modelname.endswith(".pte"):
         output_file = modelname
         modelname = modelname[:-4]