@@ -5,13 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-from typing import Optional
+from typing import Dict, Optional
 
 import torch
+import torch._inductor
 import torch.nn as nn
 
 from torch.export import Dim
-import torch._inductor
 
 from torchchat.cli.builder import (
     _initialize_model,
@@ -39,6 +39,7 @@ def export_for_server(
     output_path: str = "model.pt2",
     dynamic_shapes: bool = False,
     package: bool = True,
+    metadata: Optional[Dict[str, str]] = None,
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.
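This hunk adds an optional `metadata` parameter to `export_for_server`. A minimal sketch of a call that threads metadata through it, consistent with the call sites later in this diff; the `torchchat.export` import path and the toy module are assumptions for illustration:

```python
import torch.nn as nn

from torchchat.export import export_for_server  # assumed import path

toy_model = nn.Linear(4, 4)  # stand-in; torchchat passes an initialized Transformer

pt2_path = export_for_server(
    toy_model,
    "cpu",  # device, matching the positional call sites later in this diff
    output_path="model.pt2",
    dynamic_shapes=False,
    package=True,
    metadata={"tokenizer_type": "3"},  # ends up in aot_inductor.metadata
)
```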
@@ -67,8 +68,10 @@ def export_for_server(
         dynamic_shapes = None
 
     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        metadata = {}  # TODO: put more metadata here
-        options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
+        options = {
+            "aot_inductor.package": package,
+            "aot_inductor.metadata": metadata or {},
+        }
         if not package:
             options = {"aot_inductor.output_path": output_path}
 
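The hard-coded empty `metadata` dict is replaced by the caller-supplied one, falling back to `{}`. A sketch of how an options dict like this is typically consumed, assuming the `torch._export.aot_compile` entry point this export path is built on (not shown in this hunk; module and inputs are placeholders):

```python
import torch

mod = torch.nn.Linear(8, 8)            # placeholder module
example_inputs = (torch.randn(1, 8),)  # placeholder inputs

# "aot_inductor.*" keys configure AOT Inductor; the metadata dict is embedded
# in the produced artifact so downstream loaders can inspect it.
artifact = torch._export.aot_compile(
    mod,
    example_inputs,
    options={
        "aot_inductor.package": True,
        "aot_inductor.metadata": {"tokenizer_type": "0"},
    },
)
```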
@@ -81,6 +84,7 @@ def export_for_server(
 
         if package:
            from torch._inductor.package import package_aoti
+
            path = package_aoti(output_path, path)
 
    print(f"The generated packaged model can be found at: {path}")
@@ -102,13 +106,13 @@
     from typing import Any, Dict, Tuple, Union
 
     import executorch.exir as exir
+    from executorch.backends.xnnpack._passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )
 
     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
-    from executorch.backends.xnnpack._passes.convert_to_linear import (
-        ConvertToLinearPass,
-    )
     from executorch.exir import EdgeProgramManager, to_edge
 
     from executorch.exir.capture._config import (
@@ -166,18 +170,22 @@ def __init__(self, attention: Attention):
 
             self.wo = attention.wo
 
-            max_batch_size, n_heads, max_seq_length, head_dim = (
-                attention.kv_cache[0].k_cache.shape
-            )
+            max_batch_size, n_heads, max_seq_length, head_dim = attention.kv_cache[
+                0
+            ].k_cache.shape
             cache_dtype = attention.kv_cache[0].k_cache.dtype
             # The `Attention` module being replaced can have multiple KV caches
             # (denoted by `cache_lanes`). Thus we follow the same setup format
             # as in `Attention.setup_cache`.
             cache_lanes = len(attention.kv_cache)
-            self.kv_cache = nn.ModuleList([
-                CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
-                for _ in range(cache_lanes)
-            ])
+            self.kv_cache = nn.ModuleList(
+                [
+                    CustomKVCache(
+                        max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
+                    )
+                    for _ in range(cache_lanes)
+                ]
+            )
 
             self.n_heads = attention.n_heads
             self.head_dim = attention.head_dim
@@ -215,9 +223,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
             return self.wo(output)
 
     def replace_attention_with_custom_sdpa_attention(module: nn.Module):
-        from executorch.extension.llm.custom_ops import (  # noqa
-            sdpa_with_kv_cache,
-        )
+        from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa
 
         for name, child in module.named_children():
             if isinstance(child, Attention):
@@ -238,7 +244,9 @@ def _to_core_aten(
             raise ValueError(
                 f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
             )
-        core_aten_ep = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes)
+        core_aten_ep = export_for_training(
+            model, example_inputs, dynamic_shapes=dynamic_shapes
+        )
         if verbose:
             logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
         return core_aten_ep
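Only line wrapping changes here, but for context: `export_for_training` traces the module into an `ExportedProgram` before edge lowering. A standalone sketch using `torch.export.export_for_training`, the assumed origin of the name:

```python
import torch
from torch.export import export_for_training

# Trace a toy module; torchchat passes the (possibly quantized) Transformer.
ep = export_for_training(torch.nn.Linear(4, 4), (torch.randn(2, 4),))
print(ep.graph)  # the graph _to_core_aten logs when verbose
```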
@@ -350,7 +358,11 @@ def main(args):
 
     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
-    set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
+    set_backend(
+        dso=args.output_dso_path,
+        pte=args.output_pte_path,
+        aoti_package=args.output_aoti_package_path,
+    )
 
     builder_args.dso_path = None
     builder_args.pte_path = None
@@ -372,6 +384,7 @@ def main(args):
 
     # TODO: clean this up
     # This mess is because ET does not support _weight_int4pack_mm right now
+    tokenizer_args = None
    if not builder_args.gguf_path:
        # tokenizer needed for quantization so get that here,
        try:
@@ -382,9 +395,8 @@ def main(args):
 
     if builder_args.max_seq_length is None:
         if (
-            (output_dso_path is not None or output_aoti_package_path is not None)
-            and not builder_args.dynamic_shapes
-        ):
+            output_dso_path is not None or output_aoti_package_path is not None
+        ) and not builder_args.dynamic_shapes:
             print("Setting max_seq_length to 300 for DSO export.")
             builder_args.max_seq_length = 300
         elif output_pte_path is not None:
@@ -397,7 +409,8 @@ def main(args):
         quantize,
         tokenizer,
         max_seq_length=builder_args.max_seq_length,
-        support_tensor_subclass=output_dso_path is None and output_aoti_package_path is None,
+        support_tensor_subclass=output_dso_path is None
+        and output_aoti_package_path is None,
     )
     model_to_pte = model
     model_to_dso = model
@@ -435,7 +448,9 @@ def main(args):
     if output_dso_path:
         output_dso_path = str(os.path.abspath(output_dso_path))
         print(f"Exporting model using AOT Inductor to {output_dso_path}")
-        print("WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead.")
+        print(
+            "WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
+        )
         export_for_server(
             model_to_dso,
             builder_args.device,
@@ -446,11 +461,23 @@ def main(args):
 
     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
-        print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
+
+        if tokenizer_args is None:
+            tokenizer_type = "0"
+        elif tokenizer_args.is_sentencepiece:
+            tokenizer_type = "2"  # Corresponding to llama2
+        else:
+            tokenizer_type = "3"  # Corresponding to llama3
+
+        metadata = {"tokenizer_type": tokenizer_type}
+        print(
+            "Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
+        )
         export_for_server(
             model_to_aoti_package,
             builder_args.device,
             output_aoti_package_path,
             builder_args.dynamic_shapes,
             package=True,
+            metadata=metadata,
         )
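The AOTI package now records which tokenizer family the model was exported with: "0" when no tokenizer was resolved, "2" for SentencePiece (llama2-style), "3" otherwise (llama3-style). A hypothetical decode-side helper mirroring that encoding; the function name is invented for illustration:

```python
# Hypothetical consumer of the stored metadata; mirrors the encoding above.
def tokenizer_family(metadata: dict) -> str:
    return {
        "2": "llama2 (SentencePiece)",
        "3": "llama3",
    }.get(metadata.get("tokenizer_type", "0"), "unknown")

print(tokenizer_family({"tokenizer_type": "3"}))  # -> llama3
```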