Commit efc5382

committed
up
1 parent cffa508 commit efc5382

3 files changed: +36 -86 lines changed

examples/apple/coreml/llama/export.py

Lines changed: 21 additions & 6 deletions
@@ -27,9 +27,11 @@
 
 
 class SplitLinearModule(torch.nn.Module):
-    def __init__(self, in_features, out_features, target_size):
+    def __init__(self, in_features, out_features, target_size, max_splits):
         super(SplitLinearModule, self).__init__()
         self.num_splits = max(out_features // target_size, 1)
+        if self.num_splits > max_splits:
+            self.num_splits = max_splits
         self.common_size = out_features // self.num_splits
         self.remainder = out_features % self.num_splits
         self.splits = torch.nn.ModuleList(
@@ -38,7 +40,13 @@ def __init__(self, in_features, out_features, target_size):
                 for _ in range(self.num_splits)
             ]
         )
+        print(
+            f"Splitting out_features={out_features} into {self.num_splits} of size {self.common_size}"
+        )
         if self.remainder > 0:
+            print(
+                f"Warning: remainder {self.remainder} after splitting out_features={out_features} into {self.num_splits} of size {self.common_size}"
+            )
             self.splits.append(torch.nn.Linear(in_features, self.remainder))
 
     def split_sizes(self):
@@ -48,11 +56,11 @@ def forward(self, x):
         return torch.cat([split(x) for split in self.splits], dim=-1)
 
 
-def replace_linear_with_split_linear(model, target_size):
+def replace_linear_with_split_linear(model, target_size, max_splits):
     for name, module in model.named_children():
         if isinstance(module, torch.nn.Linear):
             new_module = SplitLinearModule(
-                module.in_features, module.out_features, target_size
+                module.in_features, module.out_features, target_size, max_splits
            )
             split_sizes = new_module.split_sizes()
             if module.bias is not None:
@@ -66,8 +74,7 @@ def replace_linear_with_split_linear(model, target_size):
                 split.bias = None
             setattr(model, name, new_module)
         else:
-            replace_linear_with_split_linear(module, target_size)
-
+            replace_linear_with_split_linear(module, target_size, max_splits)
 
 
 def main() -> None:
@@ -130,6 +137,12 @@ def main() -> None:
         default=None,
         help="Split linear layers into smaller chunks of target_size",
     )
+    parser.add_argument(
+        "--max_splits",
+        type=int,
+        default=8,
+        help="Maximum number of splits to divide linear layers",
+    )
 
     export_args = parser.parse_args()
     params_path = export_args.params
@@ -180,7 +193,9 @@ def main() -> None:
     ).quantized_model()
 
     if export_args.target_size is not None:
-        replace_linear_with_split_linear(model, export_args.target_size)
+        replace_linear_with_split_linear(
+            model, export_args.target_size, export_args.max_splits
+        )
 
     model = model.to(float_dtype)

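For intuition about what the new max_splits cap changes, here is a minimal, self-contained sketch (not the export.py code itself; the dimensions are made up for illustration). It reproduces the splitting arithmetic from SplitLinearModule, copies the weight rows of one large nn.Linear into the splits, and checks that concatenating the split outputs reproduces the original layer.

```python
import torch

# Illustrative dimensions (hypothetical, not taken from any real config).
in_features, out_features = 512, 4096
target_size, max_splits = 1024, 8

# Same arithmetic as SplitLinearModule: split out_features into chunks of
# roughly target_size, but never more than max_splits chunks.
num_splits = max(out_features // target_size, 1)
if num_splits > max_splits:
    num_splits = max_splits
common_size = out_features // num_splits
remainder = out_features % num_splits
print(f"num_splits={num_splits}, common_size={common_size}, remainder={remainder}")

# Original layer and its split replacement (bias omitted for brevity).
original = torch.nn.Linear(in_features, out_features, bias=False)
splits = torch.nn.ModuleList(
    [torch.nn.Linear(in_features, common_size, bias=False) for _ in range(num_splits)]
)
if remainder > 0:
    splits.append(torch.nn.Linear(in_features, remainder, bias=False))

# Copy consecutive row blocks of the original weight into each split.
offset = 0
with torch.no_grad():
    for split in splits:
        rows = split.out_features
        split.weight.copy_(original.weight[offset : offset + rows])
        offset += rows

# The concatenated split outputs match the single large layer.
x = torch.randn(2, in_features)
combined = torch.cat([split(x) for split in splits], dim=-1)
assert torch.allclose(combined, original(x), atol=1e-5)
```

With target_size=1024 and max_splits=8, a 4096-wide projection becomes 4 chunks of 1024; a 16384-wide projection would be capped at 8 chunks of 2048 rather than 16 chunks of 1024.
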
examples/apple/coreml/llama/extract_and_combine.py

Lines changed: 0 additions & 77 deletions
This file was deleted.

examples/apple/coreml/llama/readme.md

Lines changed: 15 additions & 3 deletions
@@ -7,8 +7,6 @@ Export model with:
 python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w
 ```
 
-For better performance, use "--use_cache_list" export arg (does not work with pybindings). You can also set "--target_size", which splits linear layers into smaller sizes for the ANE (it defaults to no splitting). This can have substantial impact on performance. For example, on Llama1B by setting "--target_size" to 1024, I see 1.34x increase in inference speed on M1 Pro (but loading time is increased). We need further experiments to tune this.
-
 The runner is written in python and is only intended to serve as an example for how the model inputs should be processed; it is not performant.
 
 
@@ -17,4 +15,18 @@ Run model with:
 python run.py -m /path/to/model.pte -p /path/to/params.json -t /path/to/tokenizer.model --seq_length 64 --max_seq_length 1024 --prompt "Once upon a time," --n_steps 512
 ```
 
-The model here is based on a "sliding" cache, where old tokens are evicted from the cache. By default, the cache size is max_seq_length - seq_length, but you can explicitly pass in a smaller cache size (e.g., --cache_size 512). This can speed up computation and reduce memory. Keep in mind that once cache_size is reached, older tokens get evicted from the cache and do not participate in attention.
+The model here is based on a "sliding" cache, where old tokens are evicted from the cache. There is no actual sliding in the implementation, though.
+
+
+## Export args
+* seq_length: the number of tokens processed by the model in one call. Sequences shorter than seq_length must be padded, and sequences longer than it must be chunked.
+* max_seq_length: the maximum number of context tokens that can be processed.
+* cache_size: the size of the KV cache sequences. This parameter is optional and defaults to max_seq_length - seq_length. If a smaller cache_size is used, older tokens are evicted from the cache and no longer play a role in attention. For example, if max_seq_length=1024 but cache_size is 512, the model can generate up to 1024 tokens, but only the current tokens and the previous 512 will participate in attention. In terms of computation, cache_size plays a similar role to max_seq_length in models without cache eviction.
+* use_cache_list: a boolean option that controls whether KV caches are passed as a list of 4D tensors, one per layer, or as one 5D tensor. (Note that use_cache_list does not work with ExecuTorch pybindings.)
+* target_size: splits linear layers into chunks of target_size. For example, if target_size is 1024, a linear layer with (in_features=512, out_features=8192) will be split into 8 linear layers with (in_features=512, out_features=1024) and the results concatenated. If not specified, the default is no splitting.
+* max_splits: the maximum number of splits for linear layers. It is only relevant if target_size is passed and defaults to 8.
+
+## Llama1B on iPhone 15
+
+We are actively experimenting with different settings, but here are ones we've found that work well on iPhone 15 Pro for Llama1B:
+* max_seq_length=1024, seq_length=64, use_cache_list, target_size=1024, max_splits=8

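Putting the recommended settings together, an export invocation might look like the following (paths are placeholders, and the c4w quantization flag is carried over from the example command earlier in the readme):

```
python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w --use_cache_list --target_size 1024 --max_splits 8
```
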
0 commit comments

Comments
 (0)