Export a lora model #11045

Open · wants to merge 1 commit into base: gh/lucylq/83/base
7 changes: 4 additions & 3 deletions backends/xnnpack/operators/node_visitor.py
@@ -595,9 +595,10 @@ def get_serialized_buffer_index(
        )

        external_tag = tensor.meta.get("delegate_constant_tag", None)
-       logging.info(
-           f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
-       )
+       if external_tag is not None:
+           logging.info(
+               f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
+           )
        self._named_data_store.add_named_data(
            named_key,
            bytes(array),
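For readers skimming the diff: the only behavioral change in this file is that the info log is now emitted only when the tensor's metadata carries a delegate_constant_tag. A standalone sketch of the new behavior (the function name and arguments are illustrative, not from the PR):

import logging

def log_external_tag(tensor_meta: dict, tensor_name: str, named_key: str) -> None:
    # Mirrors the updated node_visitor logic: only log when an external tag is present.
    external_tag = tensor_meta.get("delegate_constant_tag", None)
    if external_tag is not None:
        logging.info(
            "Adding constant data with name %s, key %s and external_tag %s to named_data_store",
            tensor_name,
            named_key,
            external_tag,
        )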
29 changes: 29 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -235,6 +235,18 @@ def build_args_parser() -> argparse.ArgumentParser:
help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
)

parser.add_argument(
"--adapter_checkpoint",
required=False,
help="Path to the adapter.pt file from torchtune. Used if the model has trained LoRA adapters. Must provide adapter_config.json",
)

parser.add_argument(
"--adapter_config",
required=False,
help="Path to the adapter_config.json file. Used if the model has trained LoRA adapters. Must provide adapter_checkpoint.",
)

parser.add_argument(
"--use_qnn_sha",
action="store_true",
@@ -631,6 +643,17 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
        canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
    )
    params_path = canonical_path(args.params) if args.params else None

    assert (args.adapter_checkpoint is None and args.adapter_config is None) or (
        args.adapter_checkpoint is not None and args.adapter_config is not None
    ), "Must provide both adapter_checkpoint and adapter_config, or neither"
    adapter_checkpoint_path = (
        canonical_path(args.adapter_checkpoint) if args.adapter_checkpoint else None
    )
    adapter_config_path = (
        canonical_path(args.adapter_config) if args.adapter_config else None
    )

    output_dir_path = canonical_path(args.output_dir, dir=True)
    weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA

@@ -642,6 +665,8 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
        checkpoint=checkpoint_path,
        checkpoint_dir=checkpoint_dir,
        params_path=params_path,
        adapter_checkpoint=adapter_checkpoint_path,
        adapter_config=adapter_config_path,
        use_kv_cache=args.use_kv_cache,
        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
        generate_full_logits=args.generate_full_logits,
@@ -1141,6 +1166,8 @@ def _load_llama_model(
    checkpoint: Optional[str] = None,
    checkpoint_dir: Optional[str] = None,
    params_path: Optional[str] = None,
    adapter_checkpoint: Optional[str] = None,
    adapter_config: Optional[str] = None,
    use_kv_cache: bool = False,
    use_sdpa_with_kv_cache: bool = False,
    generate_full_logits: bool = False,
@@ -1188,6 +1215,8 @@ def _load_llama_model(
        checkpoint=checkpoint,
        checkpoint_dir=checkpoint_dir,
        params=params_path,
        adapter_checkpoint=adapter_checkpoint,
        adapter_config=adapter_config,
        use_kv_cache=use_kv_cache,
        use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
        generate_full_logits=generate_full_logits,
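Taken together, the export_llama_lib.py changes add two paired flags and thread them through to the model loader. A tiny standalone reproduction of the pairing rule (not the real export entry point; the argv values are placeholders):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--adapter_checkpoint", required=False)
parser.add_argument("--adapter_config", required=False)

# Both flags must be given together, matching the assert added in _prepare_for_llama_export.
args = parser.parse_args(
    ["--adapter_checkpoint", "adapter.pt", "--adapter_config", "adapter_config.json"]
)
assert (args.adapter_checkpoint is None) == (args.adapter_config is None), (
    "Must provide both adapter_checkpoint and adapter_config, or neither"
)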
21 changes: 20 additions & 1 deletion examples/models/llama/model.py
@@ -8,7 +8,6 @@

import json
import os
-from typing import Dict, Tuple

import torch
from executorch.examples.models.checkpoint import (
@@ -47,6 +46,10 @@ def __init__(self, **kwargs):
        # Params file.
        params_path = kwargs.get("params", None)

        # Adapter
        adapter_checkpoint = kwargs.get("adapter_checkpoint", None)
        adapter_config = kwargs.get("adapter_config", None)

        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
@@ -130,6 +133,21 @@ def __init__(self, **kwargs):
        with open(params_path, "r") as f:
            params = json.loads(f.read())

        # Get adapter checkpoint and config.
        adapter_checkpoint = {}
        adapter_config = {}
        adapter_checkpoint_path = kwargs.get("adapter_checkpoint", None)
        if adapter_checkpoint_path:
[Reviewer comment (Contributor): nit: use os.exists or similar from Path?]

            adapter_checkpoint = torch.load(
                adapter_checkpoint_path, map_location=device, mmap=True
            )
            from torchtune.models import convert_weights
            adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
            adapter_config = kwargs.get("adapter_config", None)
            with open(adapter_config, "r") as f:
                adapter_config = json.loads(f.read())
            checkpoint.update(adapter_checkpoint)

        output_prune_map = None
        if self.output_prune_map_path is not None:
            with open(self.output_prune_map_path, "r") as f:
@@ -154,6 +172,7 @@ def __init__(self, **kwargs):
            output_prune_map=output_prune_map,
            enable_dynamic_shape=self.enable_dynamic_shape,
            **params,
            **adapter_config,
        )

        if model_args.use_scaled_rope:
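The model.py changes load the torchtune adapter weights, remap them to Meta/llama parameter names, and fold them into the base checkpoint before ModelArgs is built. A self-contained sketch of that step, including the Path-existence guard the reviewer asked about (the helper name and the FileNotFoundError are assumptions, not part of this PR):

import json
from pathlib import Path

import torch
from torchtune.models import convert_weights


def load_adapter(adapter_checkpoint_path: str, adapter_config_path: str, device: str = "cpu"):
    """Return (adapter_state_dict, adapter_config) ready to merge into the base checkpoint."""
    if not Path(adapter_checkpoint_path).exists():
        raise FileNotFoundError(f"Adapter checkpoint not found: {adapter_checkpoint_path}")
    adapter_state = torch.load(adapter_checkpoint_path, map_location=device, mmap=True)
    # Remap torchtune parameter names to the Meta/llama naming used by the base checkpoint.
    adapter_state = convert_weights.tune_to_meta(adapter_state)
    with open(adapter_config_path, "r") as f:
        adapter_config = json.load(f)
    return adapter_state, adapter_config

The caller would then do checkpoint.update(adapter_state) and pass **adapter_config into ModelArgs, as the diff does.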
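One consequence of passing **params and **adapter_config into the same ModelArgs call: Python raises a TypeError if the two dicts share a key, so the adapter config must only contain fields that params.json does not. A standalone illustration (field names here are hypothetical):

def build_args(**kwargs):
    return kwargs

params = {"dim": 4096, "n_layers": 32}
adapter_config = {"r": 8, "lora_alpha": 16.0}  # hypothetical LoRA fields

merged = build_args(**params, **adapter_config)  # fine: the key sets are disjoint
# build_args(**params, **{"dim": 2048})  # TypeError: got multiple values for keyword argument 'dim'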