Skip to content

Introduce hydra framework with backwards compatibility #11029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: gh/jackzhxng/11/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/models/llama/config/llm_config_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse

from executorch.examples.models.llama.config.llm_config import LlmConfig


def convert_args_to_llm_config(args: argparse.Namespace) -> LlmConfig:
    """
    Translate a legacy argparse Namespace (the old CLI) into the
    LlmConfig structure consumed by the LLM export process.
    """
    config = LlmConfig()

    # TODO: populate `config` fields from `args`.

    return config
38 changes: 29 additions & 9 deletions examples/models/llama/export_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,50 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting Llama2 to flatbuffer

import logging

# force=True to ensure logging while in debugger. Set up logger before any
# other imports.
import logging

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT, force=True)

import argparse
import runpy
import sys

import torch

from .export_llama_lib import build_args_parser, export_llama

sys.setrecursionlimit(4096)


def parse_hydra_arg():
    """Split the --hydra flag off from the rest of the CLI arguments.

    Returns:
        A (use_hydra, remaining_args) tuple, where remaining_args holds
        every argument that was not consumed here.
    """
    flag_parser = argparse.ArgumentParser(add_help=True)
    flag_parser.add_argument("--hydra", action="store_true")
    parsed, leftover = flag_parser.parse_known_args()
    return parsed.hydra, leftover


def main() -> None:
    """Entry point for export_llama.

    Dispatches to either the new Hydra CLI (selected via --hydra) or the
    legacy argparse CLI, forwarding all remaining arguments to the chosen
    implementation.
    """
    seed = 42
    torch.manual_seed(seed)

    use_hydra, remaining_args = parse_hydra_arg()
    if use_hydra:
        # Strip the --hydra flag, then re-run the Hydra entry point as a
        # module: run_module executes export_llama_hydra's main() with the
        # remaining args under the Hydra framework.
        sys.argv = [arg for arg in sys.argv if arg != "--hydra"]
        # Use the module-configured logger rather than a bare debug print.
        logging.info("Running export_llama with Hydra args: %s", sys.argv)
        runpy.run_module(
            "executorch.examples.models.llama.export_llama_hydra", run_name="__main__"
        )
    else:
        # Use the legacy version of the export_llama script which uses argparse.
        from executorch.examples.models.llama.export_llama_args import (
            main as export_llama_args_main,
        )

        export_llama_args_main(remaining_args)


if __name__ == "__main__":
Expand Down
21 changes: 21 additions & 0 deletions examples/models/llama/export_llama_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Run export_llama with the legacy argparse setup.
"""

from .export_llama_lib import build_args_parser, export_llama


def main(args=None) -> None:
    """Run export_llama with the legacy argparse CLI.

    Args:
        args: Optional list of CLI argument strings. When None (e.g. when
            this module is executed directly), argparse falls back to
            sys.argv[1:].
    """
    parser = build_args_parser()
    parsed_args = parser.parse_args(args)
    export_llama(parsed_args)


if __name__ == "__main__":
    # Previously `main()` raised TypeError here because `args` had no
    # default; `args=None` keeps the direct-script path working while
    # remaining backward-compatible for callers that pass args explicitly.
    main()
27 changes: 27 additions & 0 deletions examples/models/llama/export_llama_hydra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Run export_llama using the new Hydra CLI.
"""

import hydra

from executorch.examples.models.llama.config.llm_config import LlmConfig
from executorch.examples.models.llama.export_llama_lib import export_llama
from hydra.core.config_store import ConfigStore

# Register the structured config so Hydra can compose/validate it by name.
cs = ConfigStore.instance()
cs.store(name="llm_config", node=LlmConfig)


@hydra.main(version_base=None, config_name="llm_config")
def main(llm_config: LlmConfig) -> None:
    """Hydra entry point: hand the composed LlmConfig to export_llama."""
    export_llama(llm_config)


if __name__ == "__main__":
    main()
23 changes: 22 additions & 1 deletion examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from executorch.devtools.backend_debug import print_delegation_info

from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func

from executorch.examples.models.llama.config.llm_config_utils import (
convert_args_to_llm_config,
)
from executorch.examples.models.llama.hf_download import (
download_and_convert_hf_checkpoint,
)
Expand All @@ -51,6 +55,7 @@
get_vulkan_quantizer,
)
from executorch.util.activation_memory_profiler import generate_memory_trace
from omegaconf.dictconfig import DictConfig

from ..model_factory import EagerModelFactory
from .source_transformation.apply_spin_quant_r1_r2 import (
Expand Down Expand Up @@ -568,7 +573,23 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
return return_val


def export_llama(args) -> str:
def export_llama(
export_options: Union[argparse.Namespace, DictConfig],
) -> str:
if isinstance(export_options, argparse.Namespace):
# Legacy CLI.
args = export_options
llm_config = convert_args_to_llm_config(export_options) # noqa: F841
elif isinstance(export_options, DictConfig):
# Hydra CLI.
llm_config = export_options # noqa: F841
else:
raise ValueError(
"Input to export_llama must be either of type argparse.Namespace or LlmConfig"
)

# TODO: refactor rest of export_llama to use llm_config instead of args.

# If a checkpoint isn't provided for an HF OSS model, download and convert the
# weights first.
if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:
Expand Down
2 changes: 1 addition & 1 deletion examples/models/llama/install_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# Install tokenizers for hf .json tokenizer.
# Install snakeviz for cProfile flamegraph
# Install lm-eval for Model Evaluation with lm-evaluation-harness.
pip install huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile

# Call the install helper for further setup
python examples/models/llama/install_requirement_helper.py
Loading