1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+
9
+ import copy
10
+ import os
11
+ import shutil
12
+ import tempfile
13
+ import unittest
14
+
15
+ import torch
16
+ from executorch .devtools import BundledProgram
17
+ from executorch .devtools .etrecord import generate_etrecord , parse_etrecord
18
+ from executorch .devtools .inspector import Inspector
19
+ from executorch .exir import EdgeCompileConfig , EdgeProgramManager , to_edge
20
+ from executorch .exir .capture ._config import CaptureConfig
21
+ from executorch .exir .program import ExecutorchProgram
22
+ from torch .export import export , ExportedProgram
23
+ from executorch .devtools .bundled_program .config import MethodTestCase , MethodTestSuite
24
+ from executorch .devtools .bundled_program .serialize import (
25
+ serialize_from_bundled_program_to_flatbuffer ,
26
+ )
27
+
28
+ # 定义一个简单的模型用于测试
29
+ class SimpleAddModel (torch .nn .Module ):
30
+ def __init__ (self ):
31
+ super ().__init__ ()
32
+
33
+ def forward (self , x , y ):
34
+ return x + y
35
+
36
+
37
+ class TestDevtoolsEndToEnd (unittest .TestCase ):
38
+ def setUp (self ):
39
+ self .tmp_dir = tempfile .mkdtemp ()
40
+ self .etrecord_path = os .path .join (self .tmp_dir , "etrecord.bin" )
41
+ self .etdump_path = os .path .join (self .tmp_dir , "etdump.bin" )
42
+
43
+ self .model = SimpleAddModel ()
44
+
45
+ def tearDown (self ):
46
+ shutil .rmtree (self .tmp_dir )
47
+
48
+ def generate_etrecord (self ):
49
+ aten_model : ExportedProgram = export (
50
+ self .model ,
51
+ (torch .randn (1 , 1 , 32 , 32 ),),
52
+ )
53
+ edge_program_manager = to_edge (
54
+ aten_model ,
55
+ compile_config = EdgeCompileConfig (
56
+ _use_edge_ops = False ,
57
+ _check_ir_validity = False ,
58
+ ),
59
+ )
60
+ edge_program_manager_copy = copy .deepcopy (edge_program_manager )
61
+ et_program_manager = edge_program_manager .to_executorch ()
62
+
63
+ generate_etrecord (self .etrecord_path , edge_program_manager_copy , et_program_manager )
64
+
65
+ def generate_bundled_program (self ):
66
+ method_name = "forward"
67
+ method_graphs = {method_name : export (self .model , (torch .randn (1 , 1 , 32 , 32 ),))}
68
+
69
+ inputs = [torch .randn (1 , 1 , 32 , 32 )]
70
+ method_test_suites = [
71
+ MethodTestSuite (
72
+ method_name = method_name ,
73
+ test_cases = [MethodTestCase (inputs = inp , expected_outputs = self .model (inp )) for inp in inputs ],
74
+ )
75
+ ]
76
+
77
+ executorch_program = to_edge (method_graphs ).to_executorch ()
78
+ bundled_program = BundledProgram (
79
+ executorch_program = executorch_program ,
80
+ method_test_suites = method_test_suites ,
81
+ )
82
+
83
+ return bundled_program
0 commit comments