
Commit 54dccc9

ANE-friendly static llama (#8436)
* init * up * up * up * up * up * lint * up * up * up * up * lint
1 parent cae89c5 commit 54dccc9

File tree

4 files changed: +1028, -0 lines changed


examples/apple/coreml/llama/export.py

Lines changed: 285 additions & 0 deletions
@@ -0,0 +1,285 @@

# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

import argparse
import json

import sys

import coremltools as ct
import torch
from executorch.backends.apple.coreml.compiler import CoreMLBackend  # pyre-ignore
from executorch.backends.apple.coreml.partition import CoreMLPartitioner  # pyre-ignore
from executorch.examples.models.llama.source_transformation.quantize import (
    EmbeddingQuantHandler,
)

from executorch.exir.backend.utils import format_delegated_graph
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.extension.export_util.utils import export_to_edge, save_pte_program

sys.path.insert(0, ".")
from llama_transformer import InputManager, ModelArgs, Transformer


class SplitLinearModule(torch.nn.Module):
    def __init__(self, in_features, out_features, target_split_size, max_splits):
        super(SplitLinearModule, self).__init__()
        num_splits = max(out_features // target_split_size, 1)
        if num_splits > max_splits:
            num_splits = max_splits

        self.split_size = out_features // num_splits
        self.split_remainder = out_features % num_splits
        self.splits = torch.nn.ModuleList(
            [torch.nn.Linear(in_features, self.split_size) for _ in range(num_splits)]
        )
        print(
            f"Splitting out_features={out_features} into {num_splits} of size {self.split_size}"
        )
        if self.split_remainder > 0:
            print(
                f"Warning: remainder {self.split_remainder} after splitting out_features={out_features} into {num_splits} of size {self.split_size}"
            )
            self.splits.append(torch.nn.Linear(in_features, self.split_remainder))

    def split_sizes(self):
        return [split.out_features for split in self.splits]

    def forward(self, x):
        return torch.cat([split(x) for split in self.splits], dim=-1)


def replace_linear_with_split_linear(model, target_split_size, max_splits):
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            new_module = SplitLinearModule(
                module.in_features, module.out_features, target_split_size, max_splits
            )
            split_sizes = new_module.split_sizes()
            if module.bias is not None:
                split_bias = module.bias.split(split_sizes)
            split_weights = module.weight.split(split_sizes, dim=0)
            for i, split in enumerate(new_module.splits):
                split.weight = torch.nn.Parameter(split_weights[i])
                if module.bias is not None:
                    split.bias = torch.nn.Parameter(split_bias[i])
                else:
                    split.bias = None
            setattr(model, name, new_module)
        else:
            replace_linear_with_split_linear(module, target_split_size, max_splits)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-n",
        "--output_name",
        default="model.pte",
        help="Override the output filename of the saved pte model file.",
    )
    parser.add_argument(
        "-p",
        "--params",
        help="config.json",
    )
    parser.add_argument(
        "-c",
        "--checkpoint",
        help="checkpoint path",
    )
    parser.add_argument(
        "--seq_length",
        type=int,
        default=1,
        help="length sequence to evaluate",
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=128,
        help="maximum length sequence to evaluate",
    )
    parser.add_argument(
        "--cache_size",
        type=int,
        default=None,
        help="Cache size. Old items are evicted from cache",
    )
    parser.add_argument(
        "-E",
        "--embedding-quantize",
        default=None,
        type=str,
        help="type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '8,1024'.",
    )
    parser.add_argument(
        "--coreml-quantize",
        default=None,
        choices=["b4w", "c4w"],
        help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight), c4w (for channelwise 4 bit weight)",
    )
    parser.add_argument(
        "--use_cache_list",
        action="store_true",
        help="Use cache list to speed up model computation (does not work in pybindings)",
    )
    parser.add_argument(
        "--target_split_size",
        type=int,
        default=None,
        help="Split linear layers into smaller chunks of target_split_size.",
    )
    parser.add_argument(
        "--max_splits",
        type=int,
        default=8,
        help="Maximum number of splits to divide linear layers",
    )

    export_args = parser.parse_args()
    params_path = export_args.params
    checkpoint_path = export_args.checkpoint

    # Load model args
    with open(params_path, "r") as f:
        params = json.loads(f.read())

    args = ModelArgs(
        max_seq_len=export_args.max_seq_length,
        generate_full_logits=False,
        use_cache_list=export_args.use_cache_list,
        **params,
    )

    with torch.device("meta"):
        model = Transformer(args)

    checkpoint = torch.load(
        checkpoint_path, map_location="cpu", mmap=True, weights_only=True
    )
    if "model" in checkpoint:
        checkpoint = checkpoint["model"]

    missing, unexpected = model.load_state_dict(
        checkpoint,
        strict=False,
        assign=True,
    )
    print("Missing keys: ", missing)
    print("Unexpected keys: ", unexpected)

    float_dtype = torch.float16  # dtype for model/inputs
    model.eval()
    model.to(float_dtype)

    if export_args.embedding_quantize:
        bitwidth, group_size = export_args.embedding_quantize.split(",")
        if group_size == "none" or group_size == "None" or group_size == "0":
            group_size = None
        else:
            group_size = int(group_size)
        bitwidth = int(bitwidth)
        model = EmbeddingQuantHandler(
            model,
            bitwidth=bitwidth,
            group_size=group_size,
            packed=(bitwidth in [2, 4]),
        ).quantized_model()

    if export_args.target_split_size is not None:
        replace_linear_with_split_linear(
            model, export_args.target_split_size, export_args.max_splits
        )

    model = model.to(float_dtype)

    op_linear_quantizer_config = None
    if export_args.coreml_quantize == "b4w":
        op_linear_quantizer_config = {
            "mode": "linear_symmetric",
            "dtype": "int4",
            "granularity": "per_block",
            "block_size": 32,
            "weight_threshold": 512,
        }
    elif export_args.coreml_quantize == "c4w":
        op_linear_quantizer_config = {
            "mode": "linear_symmetric",
            "dtype": "int4",
            "granularity": "per_channel",
        }

    compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
        minimum_deployment_target=ct.target.iOS18,
        compute_precision=ct.precision(ct.precision.FLOAT16.value),
        compute_unit=ct.ComputeUnit.CPU_AND_NE,
        model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
        op_linear_quantizer_config=op_linear_quantizer_config,
    )
    partitioner = CoreMLPartitioner(  # pyre-fixme[16]
        compile_specs=compile_specs,
        take_over_mutable_buffer=False,
        skip_ops_for_coreml_delegation=[
            "quantized_decomposed.embedding_4bit.dtype",
            "aten.embedding.default",
        ],
    )

    input_manager = InputManager(
        n_layers=args.n_layers,
        max_batch_size=args.max_batch_size,
        n_kv_heads=args.n_kv_heads,
        max_seq_length=args.max_seq_len,
        head_dim=args.head_dim,
        use_cache_list=export_args.use_cache_list,
        seq_length=export_args.seq_length,
        dtype=float_dtype,
        minus_infinity=-30000,
        cache_size=export_args.cache_size,
    )
    example_inputs = input_manager.get_inputs(tokens=[0])

    edge_manager = export_to_edge(
        model,
        example_inputs,
        edge_compile_config=EdgeCompileConfig(
            _check_ir_validity=False,
            _skip_type_promotion=(float_dtype == torch.float16),
            _skip_dim_order=True,
        ),
    )
    print("Edge program")
    print(edge_manager.exported_program())

    for node in edge_manager.exported_program().graph_module.graph.nodes:
        print(node.name, node.target, node.args, node.kwargs)

    edge_manager = edge_manager.to_backend(partitioner)

    print("Delegated program")

    print(format_delegated_graph(edge_manager.exported_program().graph_module))

    executorch_program = edge_manager.to_executorch(
        ExecutorchBackendConfig(
            extract_delegate_segments=True,
            passes=[
                QuantFusionPass(),
            ],
            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
        )
    )

    filename = save_pte_program(executorch_program, export_args.output_name)
    print(f"Saved Executorch program to local {filename}")


if __name__ == "__main__":
    main()  # pragma: no cover
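
For readers skimming the diff, here is a minimal sketch (not part of the commit) of what the SplitLinearModule rewrite above does. It assumes it is run from examples/apple/coreml/llama with executorch and coremltools installed, so that export.py and its llama_transformer import resolve; the toy Sequential model and the sizes 64/96/32 are made up for illustration. It checks that a split layer reproduces the original layer's output.

# Sketch only (not part of the commit): verify that replace_linear_with_split_linear
# preserves a layer's output while breaking it into smaller Linear chunks.
import torch

from export import replace_linear_with_split_linear  # definitions from the file above

torch.manual_seed(0)

# Toy stand-in for a transformer projection: one oversized Linear layer.
model = torch.nn.Sequential(torch.nn.Linear(64, 96, bias=True))
x = torch.randn(2, 64)
expected = model(x)

# Split out_features=96 into chunks of ~32, capped at 8 splits -> three Linear(64, 32) pieces.
replace_linear_with_split_linear(model, target_split_size=32, max_splits=8)
print(model[0].split_sizes())  # [32, 32, 32]

# The concatenation of the split outputs should match the original layer numerically.
print(torch.allclose(model(x), expected, atol=1e-6))  # True

Presumably the splitting keeps individual weight tensors within sizes that compile well for the Neural Engine, in line with the "ANE-friendly" commit title; the export script only applies it when --target_split_size is passed.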
