1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ # All rights reserved.
3+ #
4+ # This source code is licensed under the BSD-style license found in the
5+ # LICENSE file in the root directory of this source tree.
6+
7+ # pyre-unsafe
8+
9+ import copy
10+ import os
11+ import shutil
12+ import tempfile
13+ import unittest
14+
15+ import torch
16+ from executorch .devtools import BundledProgram
17+ from executorch .devtools .etrecord import generate_etrecord , parse_etrecord
18+ from executorch .devtools .inspector import Inspector
19+ from executorch .exir import EdgeCompileConfig , EdgeProgramManager , to_edge
20+ from executorch .exir .capture ._config import CaptureConfig
21+ from executorch .exir .program import ExecutorchProgram
22+ from torch .export import export , ExportedProgram
23+ from executorch .devtools .bundled_program .config import MethodTestCase , MethodTestSuite
24+ from executorch .devtools .bundled_program .serialize import (
25+ serialize_from_bundled_program_to_flatbuffer ,
26+ )
27+
28+ # 定义一个简单的模型用于测试
29+ class SimpleAddModel (torch .nn .Module ):
30+ def __init__ (self ):
31+ super ().__init__ ()
32+
33+ def forward (self , x , y ):
34+ return x + y
35+
36+
37+ class TestDevtoolsEndToEnd (unittest .TestCase ):
38+ def setUp (self ):
39+ self .tmp_dir = tempfile .mkdtemp ()
40+ self .etrecord_path = os .path .join (self .tmp_dir , "etrecord.bin" )
41+ self .etdump_path = os .path .join (self .tmp_dir , "etdump.bin" )
42+
43+ self .model = SimpleAddModel ()
44+
45+ def tearDown (self ):
46+ shutil .rmtree (self .tmp_dir )
47+
48+ def generate_etrecord (self ):
49+ aten_model : ExportedProgram = export (
50+ self .model ,
51+ (torch .randn (1 , 1 , 32 , 32 ),),
52+ )
53+ edge_program_manager = to_edge (
54+ aten_model ,
55+ compile_config = EdgeCompileConfig (
56+ _use_edge_ops = False ,
57+ _check_ir_validity = False ,
58+ ),
59+ )
60+ edge_program_manager_copy = copy .deepcopy (edge_program_manager )
61+ et_program_manager = edge_program_manager .to_executorch ()
62+
63+ generate_etrecord (self .etrecord_path , edge_program_manager_copy , et_program_manager )
64+
65+ def generate_bundled_program (self ):
66+ method_name = "forward"
67+ method_graphs = {method_name : export (self .model , (torch .randn (1 , 1 , 32 , 32 ),))}
68+
69+ inputs = [torch .randn (1 , 1 , 32 , 32 )]
70+ method_test_suites = [
71+ MethodTestSuite (
72+ method_name = method_name ,
73+ test_cases = [MethodTestCase (inputs = inp , expected_outputs = self .model (inp )) for inp in inputs ],
74+ )
75+ ]
76+
77+ executorch_program = to_edge (method_graphs ).to_executorch ()
78+ bundled_program = BundledProgram (
79+ executorch_program = executorch_program ,
80+ method_test_suites = method_test_suites ,
81+ )
82+
83+ return bundled_program
0 commit comments