10
10
11
11
import torch
12
12
from executorch import exir
13
+ from executorch .exir import to_edge
13
14
from executorch .exir .backend .backend_api import to_backend
14
15
from executorch .exir .backend .compile_spec_schema import CompileSpec
15
16
from executorch .exir .backend .test .backend_with_compiler_demo import (
22
23
_load_for_executorch_from_buffer ,
23
24
)
24
25
from hypothesis import given , settings , strategies as st
26
+ from torch .export import export
25
27
26
28
27
29
class TestBackendAPI (unittest .TestCase ):
@@ -44,7 +46,7 @@ def validate_lowered_module_program(self, program: Program) -> None:
44
46
)
45
47
46
48
def get_program_from_wrapped_module (
47
- self , lowered_module , example_inputs , capture_config , edge_compile_config
49
+ self , lowered_module , example_inputs , edge_compile_config
48
50
):
49
51
class WrappedModule (torch .nn .Module ):
50
52
def __init__ (self ):
@@ -55,17 +57,16 @@ def forward(self, *args):
55
57
return self .one_module (* args )
56
58
57
59
return (
58
- exir .capture (WrappedModule (), example_inputs , capture_config )
59
- .to_edge (edge_compile_config )
60
+ to_edge (
61
+ export (WrappedModule (), example_inputs ),
62
+ compile_config = edge_compile_config ,
63
+ )
60
64
.to_executorch ()
61
- .program
65
+ .executorch_program
62
66
)
63
67
64
- @given (
65
- unlift = st .booleans (), # verify both lifted and unlifted graph
66
- )
67
68
@settings (deadline = 500000 )
68
- def test_emit_lowered_backend_module_end_to_end (self , unlift ):
69
+ def test_emit_lowered_backend_module_end_to_end (self ):
69
70
class SinModule (torch .nn .Module ):
70
71
def __init__ (self ):
71
72
super ().__init__ ()
@@ -76,15 +77,19 @@ def forward(self, x):
76
77
sin_module = SinModule ()
77
78
model_inputs = (torch .ones (1 ),)
78
79
expected_res = sin_module (* model_inputs )
79
- edgeir_m = exir .capture (
80
- sin_module ,
81
- model_inputs ,
82
- exir .CaptureConfig (pt2_mode = True , enable_aot = True , _unlift = unlift ),
83
- ).to_edge (exir .EdgeCompileConfig (_check_ir_validity = False , _use_edge_ops = True ))
80
+ edgeir_m = to_edge (
81
+ export (
82
+ sin_module ,
83
+ model_inputs ,
84
+ ),
85
+ compile_config = exir .EdgeCompileConfig (
86
+ _check_ir_validity = False , _use_edge_ops = True
87
+ ),
88
+ )
84
89
max_value = model_inputs [0 ].shape [0 ]
85
90
compile_specs = [CompileSpec ("max_value" , bytes ([max_value ]))]
86
91
lowered_sin_module = to_backend (
87
- BackendWithCompilerDemo .__name__ , edgeir_m .exported_program , compile_specs
92
+ BackendWithCompilerDemo .__name__ , edgeir_m .exported_program () , compile_specs
88
93
)
89
94
90
95
new_res = lowered_sin_module (* model_inputs )
@@ -120,26 +125,22 @@ def test_emit_lowered_backend_module(self, unlift):
120
125
models .ModelWithUnusedArg (),
121
126
]
122
127
123
- capture_config = (
124
- exir .CaptureConfig (enable_aot = True ) if unlift else exir .CaptureConfig ()
125
- )
126
-
127
128
edge_compile_config = exir .EdgeCompileConfig (
128
129
_check_ir_validity = False , _use_edge_ops = True
129
130
)
130
131
131
132
for model in module_list :
132
133
model_inputs = model .get_random_inputs ()
133
134
134
- edgeir_m = exir . capture ( model , model_inputs , capture_config ). to_edge (
135
- edge_compile_config
135
+ edgeir_m = to_edge (
136
+ export ( model , model_inputs ), compile_config = edge_compile_config
136
137
)
137
138
lowered_model = to_backend (
138
- QnnBackend .__name__ , edgeir_m .exported_program , []
139
+ QnnBackend .__name__ , edgeir_m .exported_program () , []
139
140
)
140
141
program = lowered_model .program ()
141
142
reference_program = self .get_program_from_wrapped_module (
142
- lowered_model , model_inputs , capture_config , edge_compile_config
143
+ lowered_model , model_inputs , edge_compile_config
143
144
)
144
145
145
146
# Check program is fairly equal to the reference program
@@ -180,22 +181,18 @@ def test_emit_nested_lowered_backend_module(self, unlift):
180
181
models .ModelWithUnusedArg (),
181
182
]
182
183
183
- capture_config = (
184
- exir .CaptureConfig (enable_aot = True ) if unlift else exir .CaptureConfig ()
185
- )
186
-
187
184
edge_compile_config = exir .EdgeCompileConfig (
188
185
_check_ir_validity = False , _use_edge_ops = True
189
186
)
190
187
191
188
for model in module_list :
192
189
model_inputs = model .get_random_inputs ()
193
190
194
- edgeir_m = exir . capture ( model , model_inputs , capture_config ). to_edge (
195
- edge_compile_config
191
+ edgeir_m = to_edge (
192
+ export ( model , model_inputs ), compile_config = edge_compile_config
196
193
)
197
194
lowered_module = to_backend (
198
- QnnBackend .__name__ , edgeir_m .exported_program , []
195
+ QnnBackend .__name__ , edgeir_m .exported_program () , []
199
196
)
200
197
201
198
# This module will include one operator and two delegate call
@@ -208,12 +205,12 @@ def forward(self, *args):
208
205
return self .one_module (* args )
209
206
210
207
wrapped_module = WrappedModule (lowered_module )
211
- wrapped_module_edge = exir . capture (
212
- wrapped_module , model_inputs , capture_config
213
- ). to_edge ( edge_compile_config )
208
+ wrapped_module_edge = to_edge (
209
+ export ( wrapped_module , model_inputs ), compile_config = edge_compile_config
210
+ )
214
211
215
212
nested_lowered_model = to_backend (
216
- QnnBackend .__name__ , wrapped_module_edge .exported_program , []
213
+ QnnBackend .__name__ , wrapped_module_edge .exported_program () , []
217
214
)
218
215
219
216
program = nested_lowered_model .program ()
0 commit comments