Skip to content

Commit 79bb43f

Browse files
committed
generate etrecord and bundle program
1 parent 4717459 commit 79bb43f

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

devtools/test/test_end2end.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import copy
10+
import os
11+
import shutil
12+
import tempfile
13+
import unittest
14+
15+
import torch
16+
from executorch.devtools import BundledProgram
17+
from executorch.devtools.etrecord import generate_etrecord, parse_etrecord
18+
from executorch.devtools.inspector import Inspector
19+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
20+
from executorch.exir.capture._config import CaptureConfig
21+
from executorch.exir.program import ExecutorchProgram
22+
from torch.export import export, ExportedProgram
23+
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
24+
from executorch.devtools.bundled_program.serialize import (
25+
serialize_from_bundled_program_to_flatbuffer,
26+
)
27+
28+
# 定义一个简单的模型用于测试
29+
class SimpleAddModel(torch.nn.Module):
30+
def __init__(self):
31+
super().__init__()
32+
33+
def forward(self, x, y):
34+
return x + y
35+
36+
37+
class TestDevtoolsEndToEnd(unittest.TestCase):
38+
def setUp(self):
39+
self.tmp_dir = tempfile.mkdtemp()
40+
self.etrecord_path = os.path.join(self.tmp_dir, "etrecord.bin")
41+
self.etdump_path = os.path.join(self.tmp_dir, "etdump.bin")
42+
43+
self.model = SimpleAddModel()
44+
45+
def tearDown(self):
46+
shutil.rmtree(self.tmp_dir)
47+
48+
def generate_etrecord(self):
49+
aten_model: ExportedProgram = export(
50+
self.model,
51+
(torch.randn(1, 1, 32, 32),),
52+
)
53+
edge_program_manager = to_edge(
54+
aten_model,
55+
compile_config=EdgeCompileConfig(
56+
_use_edge_ops=False,
57+
_check_ir_validity=False,
58+
),
59+
)
60+
edge_program_manager_copy = copy.deepcopy(edge_program_manager)
61+
et_program_manager = edge_program_manager.to_executorch()
62+
63+
generate_etrecord(self.etrecord_path, edge_program_manager_copy, et_program_manager)
64+
65+
def generate_bundled_program(self):
66+
method_name = "forward"
67+
method_graphs = {method_name: export(self.model, (torch.randn(1, 1, 32, 32),))}
68+
69+
inputs = [torch.randn(1, 1, 32, 32)]
70+
method_test_suites = [
71+
MethodTestSuite(
72+
method_name=method_name,
73+
test_cases=[MethodTestCase(inputs=inp, expected_outputs=self.model(inp)) for inp in inputs],
74+
)
75+
]
76+
77+
executorch_program = to_edge(method_graphs).to_executorch()
78+
bundled_program = BundledProgram(
79+
executorch_program=executorch_program,
80+
method_test_suites=method_test_suites,
81+
)
82+
83+
return bundled_program

0 commit comments

Comments
 (0)