
Commit 64f0321

Commit message: up
1 parent efc5382 · commit 64f0321

File tree: 3 files changed (+35, -26 lines)


examples/apple/coreml/llama/export.py

Lines changed: 20 additions & 22 deletions
@@ -27,27 +27,25 @@
 
 
 class SplitLinearModule(torch.nn.Module):
-    def __init__(self, in_features, out_features, target_size, max_splits):
+    def __init__(self, in_features, out_features, target_split_size, max_splits):
         super(SplitLinearModule, self).__init__()
-        self.num_splits = max(out_features // target_size, 1)
-        if self.num_splits > max_splits:
-            self.num_splits = max_splits
-        self.common_size = out_features // self.num_splits
-        self.remainder = out_features % self.num_splits
+        num_splits = max(out_features // target_split_size, 1)
+        if num_splits > max_splits:
+            num_splits = max_splits
+
+        self.split_size = out_features // num_splits
+        self.split_remainder = out_features % num_splits
         self.splits = torch.nn.ModuleList(
-            [
-                torch.nn.Linear(in_features, self.common_size)
-                for _ in range(self.num_splits)
-            ]
+            [torch.nn.Linear(in_features, self.split_size) for _ in range(num_splits)]
         )
         print(
-            f"Splitting out_features={out_features} into {self.num_splits} of size {self.common_size}"
+            f"Splitting out_features={out_features} into {num_splits} of size {self.split_size}"
         )
-        if self.remainder > 0:
+        if self.split_remainder > 0:
             print(
-                f"Warning: remainder {self.remainder} after splitting out_features={out_features} into {self.num_splits} of size {self.common_size}"
+                f"Warning: remainder {self.split_remainder} after splitting out_features={out_features} into {num_splits} of size {self.split_size}"
             )
-            self.splits.append(torch.nn.Linear(in_features, self.remainder))
+            self.splits.append(torch.nn.Linear(in_features, self.split_remainder))
 
     def split_sizes(self):
         return [split.out_features for split in self.splits]
@@ -56,11 +54,11 @@ def forward(self, x):
         return torch.cat([split(x) for split in self.splits], dim=-1)
 
 
-def replace_linear_with_split_linear(model, target_size, max_splits):
+def replace_linear_with_split_linear(model, target_split_size, max_splits):
     for name, module in model.named_children():
         if isinstance(module, torch.nn.Linear):
             new_module = SplitLinearModule(
-                module.in_features, module.out_features, target_size, max_splits
+                module.in_features, module.out_features, target_split_size, max_splits
             )
             split_sizes = new_module.split_sizes()
             if module.bias is not None:
@@ -74,7 +72,7 @@ def replace_linear_with_split_linear(model, target_size, max_splits):
                 split.bias = None
             setattr(model, name, new_module)
         else:
-            replace_linear_with_split_linear(module, target_size, max_splits)
+            replace_linear_with_split_linear(module, target_split_size, max_splits)
 
 
 def main() -> None:
@@ -98,7 +96,7 @@ def main() -> None:
     parser.add_argument(
         "--seq_length",
         type=int,
-        default=1,  # set to 1 for decode
+        default=1,
         help="length sequence to evaluate",
     )
     parser.add_argument(
@@ -132,10 +130,10 @@ def main() -> None:
         help="Use cache list to speed up model computation (does not work in pybindings)",
     )
     parser.add_argument(
-        "--target_size",
+        "--target_split_size",
         type=int,
         default=None,
-        help="Split linear layers into smaller chunks of target_size",
+        help="Split linear layers into smaller chunks of target_split_size.",
     )
     parser.add_argument(
         "--max_splits",
@@ -192,9 +190,9 @@ def main() -> None:
             packed=(bitwidth in [2, 4]),
         ).quantized_model()
 
-    if export_args.target_size is not None:
+    if export_args.target_split_size is not None:
         replace_linear_with_split_linear(
-            model, export_args.target_size, export_args.max_splits
+            model, export_args.target_split_size, export_args.max_splits
        )
 
     model = model.to(float_dtype)
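For context on how the renamed splitting path is meant to be used, here is a minimal sketch (not part of the commit). It assumes export.py is importable, e.g. when run from examples/apple/coreml/llama; ToyModel and its 512 -> 4096 layer are made-up illustrative values, and only output shapes are checked.

```python
import torch

# Assumes this runs from examples/apple/coreml/llama so export.py is importable.
# ToyModel and the 512 -> 4096 layer are illustrative, not from the commit.
from export import SplitLinearModule, replace_linear_with_split_linear


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(512, 4096, bias=False)

    def forward(self, x):
        return self.proj(x)


model = ToyModel()
# Split the 512 -> 4096 linear into chunks of ~1024 output features, capped at
# 8 splits (mirrors the renamed --target_split_size / --max_splits arguments).
replace_linear_with_split_linear(model, target_split_size=1024, max_splits=8)

assert isinstance(model.proj, SplitLinearModule)
assert model.proj.split_sizes() == [1024, 1024, 1024, 1024]

x = torch.randn(2, 512)
print(model(x).shape)  # torch.Size([2, 4096]) -- concatenation of the split outputs
```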

examples/apple/coreml/llama/llama_transformer.py

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ def __post_init__(self):
         if self.head_dim is None:
             self.head_dim = self.dim // self.n_heads
 
+
 class Rope(torch.nn.Module):
     def __init__(self, params: ModelArgs):
         super().__init__()

examples/apple/coreml/llama/readme.md

Lines changed: 14 additions & 4 deletions
@@ -7,6 +7,8 @@ Export model with:
 python export.py -n /path/to/output/model.pte -p /path/to/params.json -c /path/to/model.pth --seq_length 64 --max_seq_length 1024 --coreml-quantize c4w
 ```
 
+(Note the script should be run from the executorch/examples/apple/coreml/llama directory.)
+
 The runner is written in python and is only intended to serve as an example for how the model inputs should be processed; it is not performant.
 
 
@@ -15,18 +17,26 @@ Run model with:
 python run.py -m /path/to/model.pte -p /path/to/params.json -t /path/to/tokenizer.model --seq_length 64 --max_seq_length 1024 --prompt "Once upon a time," --n_steps 512
 ```
 
-The model here is based on a "sliding" cache, where old tokens are evicted from the cache. There is no actual sliding in the implementation, though.tion.
+
+(Note the script should be run from the executorch/examples/apple/coreml/llama directory.)
+
+The model here is based on a "sliding" cache, where old tokens are evicted from the cache. There is no actual sliding in the implementation, though.
 
 
 ## Export args
 * seq_length: the number of tokens processed by the model. Sequences shorter than seq_length must be padded, and sequences longer than it must be chunked.
 * max_seq_length: the maximum context tokens that can be processed.
 * cache_size: the size of the KV cache sequences. This parameter is optional, and defaults to max_seq_length - seq_length. If a smaller cache_size is used, older tokens are evicted from the cache and no longer play a role in attention. For example, if max_seq_length=1024, but cache_size is 512, the model can generate up to 1024 tokens, but only the current tokens and the previous 512 will participate in attention. In terms of computation, cache_size plays a similar role to max_seq_length in models without cache eviction.
 * use_cache_list: boolean option that controls whether KV caches are passed as a list of 4D tensors, one per layer, or if they are passed as one 5D tensor. (Note that use_cache_list does not work with ExecuTorch pybindings.)
-* target_size: this option splits linear layers into chunks of target size. For example, if target_size is 1024, a linear layer with (in_features=512, out_features=8096) will be split into 8 linear layers with (in_features=512, out_features=1024) and the results concatted. If not specified, the default is no splitting.
+* target_split_size: this option splits linear layers into chunks of size target_split_size. For example, if target_split_size is 1024, a linear layer with (in_features=512, out_features=8192) will be split into 8 linear layers with (in_features=512, out_features=1024) and the results concatenated. If not specified, the default is no splitting.
 * max_splits: this controls the maximum number of splits for linear layers. It is only relevant if target_split_size is passed and defaults to 8.
 
 ## Llama1B on iPhone 15
 
-We are actively experimenting with different settings, but here are ones we've found that work well on iPhone 15 Pro for Llama1B:
-* max_seq_length=1024, seq_length=64, use_cache_list, target_size=1024, max_splits=8
+We are actively experimenting with different settings, but here are ones we've found work well for Llama1B on iPhone 15 Pro:
+
+* Set use_cache_list
+* Split linear layers with target_split_size=1024, max_splits=8
+* Use seq_length=32 or seq_length=64, both of which offer reasonable tradeoffs for prefill and decode performance. seq_length=32 is better at decode and seq_length=64 is better at prefill.
+
+In our tests, we set max_seq_length=1024, but if your application allows for it, performance can improve with max_seq_length=512 or by keeping max_seq_length=1024 and setting cache_size=512-seq_length.
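As a worked example of the target_split_size/max_splits arithmetic used in the readme above, the sketch below mirrors the computation in SplitLinearModule.__init__ from export.py; the 8192-wide layer and the concrete settings are illustrative, not taken from a real checkpoint.

```python
# Mirrors the split arithmetic in SplitLinearModule.__init__ (export.py).
# The 8192-wide layer and the settings below are illustrative only.
out_features = 8192
target_split_size = 1024  # --target_split_size
max_splits = 8            # --max_splits (defaults to 8)

num_splits = max(out_features // target_split_size, 1)
if num_splits > max_splits:
    num_splits = max_splits

split_size = out_features // num_splits      # 1024
split_remainder = out_features % num_splits  # 0 -> no extra remainder layer appended

print(num_splits, split_size, split_remainder)  # 8 1024 0
```

With the recommended settings, such a projection becomes eight 1024-wide linear layers whose outputs are concatenated, matching the example given under target_split_size in the Export args section.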
