
Commit cffa508

committed: up
1 parent 92528e9 commit cffa508

File tree

3 files changed: +58 -2 lines changed

examples/apple/coreml/llama/export.py

Lines changed: 56 additions & 0 deletions
@@ -26,6 +26,50 @@
 from llama.llama_transformer import InputManager, ModelArgs, Transformer
 
 
+class SplitLinearModule(torch.nn.Module):
+    def __init__(self, in_features, out_features, target_size):
+        super(SplitLinearModule, self).__init__()
+        self.num_splits = max(out_features // target_size, 1)
+        self.common_size = out_features // self.num_splits
+        self.remainder = out_features % self.num_splits
+        self.splits = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(in_features, self.common_size)
+                for _ in range(self.num_splits)
+            ]
+        )
+        if self.remainder > 0:
+            self.splits.append(torch.nn.Linear(in_features, self.remainder))
+
+    def split_sizes(self):
+        return [split.out_features for split in self.splits]
+
+    def forward(self, x):
+        return torch.cat([split(x) for split in self.splits], dim=-1)
+
+
+def replace_linear_with_split_linear(model, target_size):
+    for name, module in model.named_children():
+        if isinstance(module, torch.nn.Linear):
+            new_module = SplitLinearModule(
+                module.in_features, module.out_features, target_size
+            )
+            split_sizes = new_module.split_sizes()
+            if module.bias is not None:
+                split_bias = module.bias.split(split_sizes)
+            split_weights = module.weight.split(split_sizes, dim=0)
+            for i, split in enumerate(new_module.splits):
+                split.weight = torch.nn.Parameter(split_weights[i])
+                if module.bias is not None:
+                    split.bias = torch.nn.Parameter(split_bias[i])
+                else:
+                    split.bias = None
+            setattr(model, name, new_module)
+        else:
+            replace_linear_with_split_linear(module, target_size)
+
+
 def main() -> None:
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -80,6 +124,12 @@ def main() -> None:
         action="store_true",
         help="Use cache list to speed up model computation (does not work in pybindings)",
     )
+    parser.add_argument(
+        "--target_size",
+        type=int,
+        default=None,
+        help="Split linear layers into smaller chunks of target_size",
+    )
 
     export_args = parser.parse_args()
     params_path = export_args.params
@@ -129,6 +179,9 @@ def main() -> None:
             packed=(bitwidth in [2, 4]),
         ).quantized_model()
 
+    if export_args.target_size is not None:
+        replace_linear_with_split_linear(model, export_args.target_size)
+
     model = model.to(float_dtype)
 
     op_linear_quantizer_config = None
@@ -184,6 +237,9 @@ def main() -> None:
     print("Edge program")
     print(edge_manager.exported_program())
 
+    for node in edge_manager.exported_program().graph_module.graph.nodes:
+        print(node.name, node.target, node.args, node.kwargs)
+
     edge_manager = edge_manager.to_backend(partitioner)
 
     print("Delegated program")

examples/apple/coreml/llama/llama_transformer.py

Lines changed: 1 addition & 2 deletions
@@ -120,7 +120,6 @@ def __post_init__(self):
         if self.head_dim is None:
             self.head_dim = self.dim // self.n_heads
 
-
 class Rope(torch.nn.Module):
     def __init__(self, params: ModelArgs):
         super().__init__()
@@ -401,7 +400,7 @@ def forward(
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
-            h = h[:, input_length - 1, :]
+            h = h[:, input_length - 1, :].squeeze(1)
 
         h = self.norm(h)
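A quick shape check of the .squeeze(1) change, under the assumption (implied by the added .squeeze(1)) that input_length reaches forward() as a 1-element tensor, so indexing with it keeps a size-1 dim:

```python
import torch

h = torch.randn(1, 64, 2048)      # (batch, seq_length, dim); values are illustrative
input_length = torch.tensor([3])  # number of valid tokens, as a 1-element tensor

last = h[:, input_length - 1, :]  # tensor indexing keeps the indexed dim: (1, 1, 2048)
print(last.shape, last.squeeze(1).shape)  # torch.Size([1, 1, 2048]) torch.Size([1, 2048])
```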

examples/apple/coreml/llama/readme.md

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ Export model with:
 python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w
 ```
 
+For better performance, use the "--use_cache_list" export arg (it does not work with pybindings). You can also set "--target_size", which splits linear layers into smaller chunks for the ANE (the default is no splitting). This can have a substantial impact on performance: for example, on Llama1B, setting "--target_size" to 1024 gives a 1.34x increase in inference speed on an M1 Pro, although load time increases. Further experiments are needed to tune this value.
 
 The runner is written in python and is only intended to serve as an example for how the model inputs should be processed; it is not performant.
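For instance, an export invocation using both flags might look like the following (the paths are the same placeholders as in the readme command above; 1024 is the target size from the example cited in the added readme text):

```
python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w --use_cache_list --target_size 1024
```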
