Commit aac7fd6 (parent 287e1dd)

Export a LoRA model

Program and data are combined in a single exported artifact currently, using the LoRA linear definition.

Differential Revision: D75153377 (https://our.internmc.facebook.com/intern/diff/D75153377/)

File tree: 3 files changed, +53 -4 lines


backends/xnnpack/operators/node_visitor.py

Lines changed: 4 additions & 3 deletions

@@ -595,9 +595,10 @@ def get_serialized_buffer_index(
         )

         external_tag = tensor.meta.get("delegate_constant_tag", None)
-        logging.info(
-            f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
-        )
+        if external_tag is not None:
+            logging.info(
+                f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
+            )
         self._named_data_store.add_named_data(
             named_key,
             bytes(array),
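
The effect of the new guard, shown in isolation (a minimal sketch; the condition and log message mirror the diff, while the surrounding visitor machinery is stubbed out):

import logging

logging.basicConfig(level=logging.INFO)

def add_constant_data(tensor_name, named_key, external_tag):
    # After this change, the log line fires only for externally tagged
    # constants; untagged data is added to the store silently.
    if external_tag is not None:
        logging.info(
            f"Adding constant data with name {tensor_name}, key {named_key} "
            f"and external_tag {external_tag} to named_data_store"
        )
    # self._named_data_store.add_named_data(named_key, ...) would follow here.

add_constant_data("w0", "key0", None)    # no log output
add_constant_data("w1", "key1", "lora")  # logged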

examples/models/llama/export_llama_lib.py

Lines changed: 29 additions & 0 deletions

@@ -235,6 +235,18 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )

+    parser.add_argument(
+        "--adapter_checkpoint",
+        required=False,
+        help="Path to the adapter.pt file from torchtune. Used if the model has trained LoRA adapters. Must provide adapter_config.json",
+    )
+
+    parser.add_argument(
+        "--adapter_config",
+        required=False,
+        help="Path to the adapter_config.json file. Used if the model has trained LoRA adapters. Must provide adapter_checkpoint.",
+    )
+
     parser.add_argument(
         "--use_qnn_sha",
         action="store_true",
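
A minimal sketch of the two new flags in isolation (the parser fragment is lifted from the diff; the sample argv is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--adapter_checkpoint",
    required=False,
    help="Path to the adapter.pt file from torchtune.",
)
parser.add_argument(
    "--adapter_config",
    required=False,
    help="Path to the adapter_config.json file.",
)

args = parser.parse_args(
    ["--adapter_checkpoint", "adapter.pt", "--adapter_config", "adapter_config.json"]
)
print(args.adapter_checkpoint, args.adapter_config)

Each flag is optional on its own; the pairing requirement is enforced later, in _prepare_for_llama_export (next hunk).
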
@@ -631,6 +643,17 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
     params_path = canonical_path(args.params) if args.params else None
+
+    assert (args.adapter_checkpoint is None and args.adapter_config is None) or (
+        args.adapter_checkpoint is not None and args.adapter_config is not None
+    ), "Must provide both adapter_checkpoint and adapter_config, or neither"
+    adapter_checkpoint_path = (
+        canonical_path(args.adapter_checkpoint) if args.adapter_checkpoint else None
+    )
+    adapter_config_path = (
+        canonical_path(args.adapter_config) if args.adapter_config else None
+    )
+
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
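
The assert implements a both-or-neither contract on the two flags. A standalone sketch of the same check (the function name is hypothetical):

def validate_adapter_args(adapter_checkpoint, adapter_config):
    # Mirrors the assert added in _prepare_for_llama_export.
    assert (adapter_checkpoint is None and adapter_config is None) or (
        adapter_checkpoint is not None and adapter_config is not None
    ), "Must provide both adapter_checkpoint and adapter_config, or neither"

validate_adapter_args(None, None)                # passes
validate_adapter_args("adapter.pt", "cfg.json")  # passes
# validate_adapter_args("adapter.pt", None)      # AssertionError
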

@@ -642,6 +665,8 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
+        adapter_checkpoint=adapter_checkpoint_path,
+        adapter_config=adapter_config_path,
         use_kv_cache=args.use_kv_cache,
         use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
         generate_full_logits=args.generate_full_logits,

@@ -1141,6 +1166,8 @@ def _load_llama_model(
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: Optional[str] = None,
+    adapter_checkpoint: Optional[str] = None,
+    adapter_config: Optional[str] = None,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
     generate_full_logits: bool = False,

@@ -1188,6 +1215,8 @@ def _load_llama_model(
         checkpoint=checkpoint,
         checkpoint_dir=checkpoint_dir,
         params=params_path,
+        adapter_checkpoint=adapter_checkpoint,
+        adapter_config=adapter_config,
         use_kv_cache=use_kv_cache,
         use_sdpa_with_kv_cache=use_sdpa_with_kv_cache,
         generate_full_logits=generate_full_logits,

examples/models/llama/model.py

Lines changed: 20 additions & 1 deletion

@@ -8,7 +8,6 @@

 import json
 import os
-from typing import Dict, Tuple

 import torch
 from executorch.examples.models.checkpoint import (

@@ -47,6 +46,10 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)

+        # Adapter
+        adapter_checkpoint = kwargs.get("adapter_checkpoint", None)
+        adapter_config = kwargs.get("adapter_config", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)

@@ -130,6 +133,21 @@ def __init__(self, **kwargs):
         with open(params_path, "r") as f:
             params = json.loads(f.read())

+        # Get adapter checkpoint and config.
+        adapter_checkpoint = {}
+        adapter_config = {}
+        adapter_checkpoint_path = kwargs.get("adapter_checkpoint", None)
+        if adapter_checkpoint_path:
+            adapter_checkpoint = torch.load(
+                adapter_checkpoint_path, map_location=device, mmap=True
+            )
+            from torchtune.models import convert_weights
+            adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
+            adapter_config = kwargs.get("adapter_config", None)
+            with open(adapter_config, "r") as f:
+                adapter_config = json.loads(f.read())
+            checkpoint.update(adapter_checkpoint)
+
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
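
The merge step is a plain dict update: the torchtune adapter weights are converted to the Meta naming scheme via convert_weights.tune_to_meta and folded into the base state dict, which is how, per the commit message, program and data end up combined in one export. A toy illustration (key names and shapes are made up for the example):

import torch

# Base checkpoint in Meta llama naming (illustrative key/shape).
checkpoint = {"layers.0.attention.wq.weight": torch.zeros(8, 8)}

# Converted LoRA adapter tensors, added alongside the base weights
# (illustrative names; real keys come from convert_weights.tune_to_meta).
adapter_checkpoint = {
    "layers.0.attention.wq.lora_a.weight": torch.zeros(2, 8),
    "layers.0.attention.wq.lora_b.weight": torch.zeros(8, 2),
}

checkpoint.update(adapter_checkpoint)
print(sorted(checkpoint))  # base and adapter keys now live in one state dict
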
@@ -154,6 +172,7 @@ def __init__(self, **kwargs):
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
+            **adapter_config,
         )

         if model_args.use_scaled_rope:
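
Because adapter_config is unpacked last, its fields arrive as extra keyword arguments to the ModelArgs constructor alongside **params. A dataclass stand-in shows the mechanics (the r/lora_alpha/target_modules fields are an assumption about what torchtune writes to adapter_config.json, not confirmed by this diff):

from dataclasses import dataclass, field

@dataclass
class Args:  # stand-in for ModelArgs, trimmed to the illustration
    dim: int = 4096
    r: int = 0  # LoRA rank; 0 means no adapters
    lora_alpha: int = 0
    target_modules: list = field(default_factory=list)

params = {"dim": 2048}
# Assumed shape of adapter_config.json contents:
adapter_config = {"r": 8, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]}

args = Args(**params, **adapter_config)
print(args)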
