
Commit cbc72a4

[aoti] Remove need for -l in cmake (#1159)

1 parent 654bb03 · commit cbc72a4

4 files changed (+79, -63 lines)

- .github/workflows/runner-cuda-dtype.yml
- README.md
- runner/run.cpp
- torchchat/export.py

.github/workflows/runner-cuda-dtype.yml

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ jobs:
 
   python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --output-aoti-package-path /tmp/model.pt2
 
-  ./cmake-out/aoti_run /tmp/model.pt2 -d CUDA -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}"
+  ./cmake-out/aoti_run /tmp/model.pt2 -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}"
 
 done

README.md

Lines changed: 1 addition & 1 deletion
@@ -341,7 +341,7 @@ torchchat/utils/scripts/build_native.sh aoti
 
 Then run the compiled executable, with the pt2.
 ```bash
-cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
+cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -i "Once upon a time"
 ```
 
 ## Mobile Execution
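
Note: the `-l 3` flag here (and `-d CUDA` in the workflow above) can be dropped because the exported `.pt2` package now carries this information as AOTI metadata, which the runner reads back at load time (see the `runner/run.cpp` and `torchchat/export.py` diffs below). A minimal Python-style sketch of the selection rules, assuming the metadata keys used in this commit (`AOTI_DEVICE_KEY`, `tokenizer_type`); the real logic is the C++ in `runner/run.cpp`, and the `*_MODEL` constants below are stand-ins, not the runner's actual enum values:

```python
from typing import Dict, Tuple

# Stand-in values for the runner's ModelType enum (sketch only).
UNKNOWN_MODEL, LLAMA2_MODEL, LLAMA3_MODEL = 0, 2, 3

def choose_runtime(aoti_metadata: Dict[str, str]) -> Tuple[str, int]:
    # Device: previously passed with -d, now taken from AOTI_DEVICE_KEY.
    device = "cpu" if aoti_metadata.get("AOTI_DEVICE_KEY") == "cpu" else "cuda"
    # Model family: previously passed with -l, now taken from tokenizer_type
    # ("2" = llama2, "3" = llama3, "0" = unknown).
    version = int(aoti_metadata.get("tokenizer_type", "0"))
    model_type = {2: LLAMA2_MODEL, 3: LLAMA3_MODEL}.get(version, UNKNOWN_MODEL)
    return device, model_type

# Example: a llama3 model exported for CUDA.
print(choose_runtime({"AOTI_DEVICE_KEY": "cuda", "tokenizer_type": "3"}))
```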

runner/run.cpp

Lines changed: 25 additions & 36 deletions
@@ -102,6 +102,7 @@ typedef struct {
 typedef struct {
   Config config; // the hyperparameters of the architecture (the blueprint)
   RunState state; // buffers for the "wave" of activations in the forward pass
+  std::unordered_map<std::string, std::string> metadata;
 
 #ifdef __AOTI_MODEL__
   torch::inductor::AOTIModelPackageLoader *runner;
@@ -141,20 +142,9 @@ void read_checkpoint(char *checkpoint, Config *config) {
   config->vocab_size = abs(config->vocab_size);
 }
 
-void build_transformer(Transformer *t, char *model_path, int vocab_size,
-                       int seq_len) {
-  // read in the Config and the Weights from the model
-  // read_checkpoint(model_path, &t->config);
-  // allocate the RunState buffers
-  t->config.vocab_size = vocab_size;
-  t->config.seq_len = seq_len;
-  malloc_run_state(&t->state, &t->config);
-
+void build_transformer(Transformer *t, char *model_path) {
 #ifdef __AOTI_MODEL__
   t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
-  aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu"
-                    ? torch::Device(torch::kCPU)
-                    : torch::Device(torch::kCUDA);
 #else //__ET_MODEL__
   t->runner = new Module(
       /* path to PTE model */ model_path,
@@ -776,9 +766,6 @@ void error_usage() {
           " -v <int> (optional) vocab size, default is model-specific.\n");
   fprintf(stderr,
           " -l <int> (optional) llama version (2 or 3), default 2.\n");
-  fprintf(
-      stderr,
-      " -d <string> (optional) device(CUDA or CPU) model was exported for\n");
   exit(EXIT_FAILURE);
 }
 
@@ -848,37 +835,35 @@ int main(int argc, char *argv[]) {
       system_prompt = argv[i + 1];
     } else if (argv[i][1] == 'l') {
      llama_ver = atoi(argv[i + 1]);
-#ifdef __AOTI_MODEL__
-    } else if (argv[i][1] == 'd') {
-#ifdef USE_CUDA
-      if (strcasecmp(argv[i + 1], "CUDA") == 0) {
-        aoti_device = torch::Device(torch::kCUDA);
-      } else
-#endif
-      if (strcasecmp(argv[i + 1], "CPU") == 0) {
-        aoti_device = torch::Device(torch::kCPU);
-      } else {
-        fprintf(stderr, "Unknown device %s", argv[i + 1]);
-        exit(1);
-      }
-#endif
     } else {
      error_usage();
     }
   }
 
+  if (model_path == NULL) {
+    fprintf(stderr, "No model_path provided.");
+    error_usage();
+  }
+
+  Transformer transformer;
+  build_transformer(&transformer, model_path);
+
+#ifdef __AOTI_MODEL__
+  auto aoti_metadata = transformer.runner->get_metadata();
+  aoti_device = aoti_metadata["AOTI_DEVICE_KEY"] == "cpu"
+                    ? torch::Device(torch::kCPU)
+                    : torch::Device(torch::kCUDA);
+  ModelType model_type = get_model_type(std::stoi(aoti_metadata["tokenizer_type"]));
+#else // __ET_MODEL__
   ModelType model_type = get_model_type(llama_ver);
+#endif
+
   if (model_type == UNKNOWN_MODEL) {
     fprintf(stderr, "Unknown model type passed by -l argument. Received l=%d.",
             llama_ver);
     error_usage();
   }
 
-  if (model_path == NULL) {
-    fprintf(stderr, "No model_path provided.");
-    error_usage();
-  }
-
   if (tokenizer_path == NULL) {
     fprintf(stderr, "No tokenizer_path provided.");
     error_usage();
@@ -901,8 +886,12 @@ int main(int argc, char *argv[]) {
     vocab_size = tokenizer->vocab_size();
   }
 
-  Transformer transformer;
-  build_transformer(&transformer, model_path, vocab_size, steps);
+  // read in the Config and the Weights from the model
+  // read_checkpoint(model_path, &t->config);
+  // allocate the RunState buffers
+  transformer.config.vocab_size = vocab_size;
+  transformer.config.seq_len = steps;
+  malloc_run_state(&transformer.state, &transformer.config);
 
   Sampler sampler;
   build_sampler(&sampler, vocab_size, temperature, topp, rng_seed);
torchchat/export.py

Lines changed: 52 additions & 25 deletions
@@ -5,13 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-from typing import Optional
+from typing import Dict, Optional
 
 import torch
+import torch._inductor
 import torch.nn as nn
 
 from torch.export import Dim
-import torch._inductor
 
 from torchchat.cli.builder import (
     _initialize_model,
@@ -39,6 +39,7 @@ def export_for_server(
     output_path: str = "model.pt2",
     dynamic_shapes: bool = False,
     package: bool = True,
+    metadata: Optional[Dict[str, str]] = None,
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.
@@ -67,8 +68,10 @@ def export_for_server(
         dynamic_shapes = None
 
     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        metadata = {}  # TODO: put more metadata here
-        options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
+        options = {
+            "aot_inductor.package": package,
+            "aot_inductor.metadata": metadata or {},
+        }
         if not package:
             options = {"aot_inductor.output_path": output_path}
 
@@ -81,6 +84,7 @@ def export_for_server(
 
     if package:
         from torch._inductor.package import package_aoti
+
         path = package_aoti(output_path, path)
 
     print(f"The generated packaged model can be found at: {path}")
@@ -102,13 +106,13 @@ def export_for_server(
     from typing import Any, Dict, Tuple, Union
 
     import executorch.exir as exir
+    from executorch.backends.xnnpack._passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )
 
     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
-    from executorch.backends.xnnpack._passes.convert_to_linear import (
-        ConvertToLinearPass,
-    )
     from executorch.exir import EdgeProgramManager, to_edge
 
     from executorch.exir.capture._config import (
@@ -166,18 +170,22 @@ def __init__(self, attention: Attention):
 
         self.wo = attention.wo
 
-        max_batch_size, n_heads, max_seq_length, head_dim = (
-            attention.kv_cache[0].k_cache.shape
-        )
+        max_batch_size, n_heads, max_seq_length, head_dim = attention.kv_cache[
+            0
+        ].k_cache.shape
         cache_dtype = attention.kv_cache[0].k_cache.dtype
         # The `Attention` module being replaced can have multiple KV caches
         # (denoted by `cache_lanes`). Thus we follow the same setup format
         # as in `Attention.setup_cache`.
         cache_lanes = len(attention.kv_cache)
-        self.kv_cache = nn.ModuleList([
-            CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
-            for _ in range(cache_lanes)
-        ])
+        self.kv_cache = nn.ModuleList(
+            [
+                CustomKVCache(
+                    max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
+                )
+                for _ in range(cache_lanes)
+            ]
+        )
 
         self.n_heads = attention.n_heads
         self.head_dim = attention.head_dim
@@ -215,9 +223,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
         return self.wo(output)
 
 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
-    from executorch.extension.llm.custom_ops import (  # noqa
-        sdpa_with_kv_cache,
-    )
+    from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa
 
     for name, child in module.named_children():
         if isinstance(child, Attention):
@@ -238,7 +244,9 @@ def _to_core_aten(
         raise ValueError(
             f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
         )
-    core_aten_ep = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes)
+    core_aten_ep = export_for_training(
+        model, example_inputs, dynamic_shapes=dynamic_shapes
+    )
     if verbose:
         logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
     return core_aten_ep
@@ -350,7 +358,11 @@ def main(args):
 
     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
-    set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
+    set_backend(
+        dso=args.output_dso_path,
+        pte=args.output_pte_path,
+        aoti_package=args.output_aoti_package_path,
+    )
 
     builder_args.dso_path = None
     builder_args.pte_path = None
@@ -372,6 +384,7 @@ def main(args):
 
     # TODO: clean this up
     # This mess is because ET does not support _weight_int4pack_mm right now
+    tokenizer_args = None
     if not builder_args.gguf_path:
         # tokenizer needed for quantization so get that here,
         try:
@@ -382,9 +395,8 @@ def main(args):
 
     if builder_args.max_seq_length is None:
         if (
-            (output_dso_path is not None or output_aoti_package_path is not None)
-            and not builder_args.dynamic_shapes
-        ):
+            output_dso_path is not None or output_aoti_package_path is not None
+        ) and not builder_args.dynamic_shapes:
            print("Setting max_seq_length to 300 for DSO export.")
            builder_args.max_seq_length = 300
         elif output_pte_path is not None:
@@ -397,7 +409,8 @@ def main(args):
         quantize,
         tokenizer,
         max_seq_length=builder_args.max_seq_length,
-        support_tensor_subclass=output_dso_path is None and output_aoti_package_path is None,
+        support_tensor_subclass=output_dso_path is None
+        and output_aoti_package_path is None,
     )
     model_to_pte = model
     model_to_dso = model
@@ -435,7 +448,9 @@ def main(args):
     if output_dso_path:
         output_dso_path = str(os.path.abspath(output_dso_path))
         print(f"Exporting model using AOT Inductor to {output_dso_path}")
-        print("WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead.")
+        print(
+            "WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
+        )
         export_for_server(
             model_to_dso,
             builder_args.device,
@@ -446,11 +461,23 @@ def main(args):
 
     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
-        print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
+
+        if tokenizer_args is None:
+            tokenizer_type = "0"
+        elif tokenizer_args.is_sentencepiece:
+            tokenizer_type = "2"  # Corresponding to llama2
+        else:
+            tokenizer_type = "3"  # Corresponding to llama3
+
+        metadata = {"tokenizer_type": tokenizer_type}
+        print(
+            "Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
+        )
         export_for_server(
             model_to_aoti_package,
             builder_args.device,
             output_aoti_package_path,
             builder_args.dynamic_shapes,
             package=True,
+            metadata=metadata,
         )
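
The upshot on the export side: `export_for_server` now accepts an optional `metadata` dict, which is forwarded to AOT Inductor via the `aot_inductor.metadata` option and baked into the `.pt2` package. A hedged usage sketch follows; `model` and `builder_args` are assumed to come from torchchat's builder as in `main` above, and the output path and tokenizer choice are illustrative:

```python
# Sketch of the new export path; values below are example placeholders.
from torchchat.export import export_for_server

package_path = export_for_server(
    model,                                   # an initialized torchchat model (assumed)
    builder_args.device,                     # e.g. "cuda" or "cpu" (assumed)
    "exportedModels/llama3_1_artifacts.pt2",
    builder_args.dynamic_shapes,
    package=True,
    metadata={"tokenizer_type": "3"},        # "2" = llama2, "3" = llama3
)
```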
