Skip to content

Commit 8dcba74

Browse files
committed
use _load_for_executorch_from_buffer instead
1 parent 181c8ab commit 8dcba74

File tree

1 file changed

+9
-40
lines changed

1 file changed

+9
-40
lines changed

devtools/test/test_end2end.py

Lines changed: 9 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@
2525
serialize_from_bundled_program_to_flatbuffer,
2626
)
2727
from executorch.extension.pybindings.portable_lib import (
28-
_load_for_executorch_from_bundled_program,
29-
_load_bundled_program_from_buffer
28+
_load_for_executorch_from_buffer,
3029
)
3130

32-
# 定义一个简单的模型用于测试
3331
class SimpleAddModel(torch.nn.Module):
3432
def __init__(self):
3533
super().__init__()
@@ -43,6 +41,7 @@ def __init__(self):
4341
self.tmp_dir = "./"
4442
self.etrecord_path = os.path.join(self.tmp_dir, "etrecord.bin")
4543
self.etdump_path = os.path.join(self.tmp_dir, "etdump.bin")
44+
self.et_program_manager = None
4645

4746
self.model = SimpleAddModel()
4847

@@ -64,45 +63,16 @@ def generate_etrecord_(self):
6463
edge_program_manager_copy = copy.deepcopy(edge_program_manager)
6564
et_program_manager = edge_program_manager.to_executorch()
6665

67-
generate_etrecord(self.etrecord_path, edge_program_manager_copy, et_program_manager)
68-
69-
def generate_bundled_program(self):
70-
method_name = "forward"
71-
method_graphs = {method_name: export(self.model, (torch.randn(1, 1, 32, 32), torch.randn(1, 1, 32, 32)))}
72-
73-
inputs = [(torch.randn(1, 1, 32, 32), torch.randn(1, 1, 32, 32))]
74-
method_test_suites = [
75-
MethodTestSuite(
76-
method_name=method_name,
77-
test_cases=[MethodTestCase(inputs=inp, expected_outputs=self.model(*inp)) for inp in inputs],
78-
)
79-
]
80-
81-
executorch_program = to_edge(method_graphs).to_executorch()
82-
bundled_program = BundledProgram(
83-
executorch_program=executorch_program,
84-
method_test_suites=method_test_suites,
85-
)
66+
self.et_program_manager = et_program_manager
8667

87-
return bundled_program
68+
generate_etrecord(self.etrecord_path, edge_program_manager_copy, et_program_manager)
8869

8970
def generate_etdump(self):
90-
bundled_program_py = self.generate_bundled_program()
91-
92-
bundled_program_bytes = serialize_from_bundled_program_to_flatbuffer(
93-
bundled_program_py
94-
)
95-
96-
bundled_program_cpp = _load_bundled_program_from_buffer(bundled_program_bytes)
97-
98-
program = _load_for_executorch_from_bundled_program(
99-
bundled_program_cpp,
100-
enable_etdump=True
101-
)
102-
103-
example_inputs = (torch.randn(1, 1, 32, 32), torch.randn(1, 1, 32, 32))
104-
program.forward(example_inputs)
105-
71+
# load executorch program from buffer, and set enable_etdump to True
72+
program = _load_for_executorch_from_buffer(self.et_program_manager.buffer, enable_etdump=True)
73+
# run program with example inputs to generate etdump
74+
program.forward((torch.randn(1, 1, 32, 32), torch.randn(1, 1, 32, 32)))
75+
# write etdump to file
10676
program.write_etdump_result_to_file(self.etdump_path)
10777

10878
def test_profile(self):
@@ -112,5 +82,4 @@ def test_profile(self):
11282
if __name__ == "__main__":
11383
tester = TestDevtoolsEndToEnd()
11484
tester.generate_etrecord_()
115-
tester.generate_bundled_program()
11685
tester.generate_etdump()

0 commit comments

Comments
 (0)