Skip to content

Commit dfd98b8

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
remove exir.capture from test_lowered_backend_module
Summary: title Differential Revision: D56368215
1 parent db17853 commit dfd98b8

File tree

1 file changed

+30
-33
lines changed

1 file changed

+30
-33
lines changed

exir/backend/test/test_lowered_backend_module.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212
from executorch import exir
13+
from executorch.exir import to_edge
1314
from executorch.exir.backend.backend_api import to_backend
1415
from executorch.exir.backend.compile_spec_schema import CompileSpec
1516
from executorch.exir.backend.test.backend_with_compiler_demo import (
@@ -22,6 +23,7 @@
2223
_load_for_executorch_from_buffer,
2324
)
2425
from hypothesis import given, settings, strategies as st
26+
from torch.export import export
2527

2628

2729
class TestBackendAPI(unittest.TestCase):
@@ -44,7 +46,7 @@ def validate_lowered_module_program(self, program: Program) -> None:
4446
)
4547

4648
def get_program_from_wrapped_module(
47-
self, lowered_module, example_inputs, capture_config, edge_compile_config
49+
self, lowered_module, example_inputs, edge_compile_config
4850
):
4951
class WrappedModule(torch.nn.Module):
5052
def __init__(self):
@@ -55,17 +57,16 @@ def forward(self, *args):
5557
return self.one_module(*args)
5658

5759
return (
58-
exir.capture(WrappedModule(), example_inputs, capture_config)
59-
.to_edge(edge_compile_config)
60+
to_edge(
61+
export(WrappedModule(), example_inputs),
62+
compile_config=edge_compile_config,
63+
)
6064
.to_executorch()
61-
.program
65+
.executorch_program
6266
)
6367

64-
@given(
65-
unlift=st.booleans(), # verify both lifted and unlifted graph
66-
)
6768
@settings(deadline=500000)
68-
def test_emit_lowered_backend_module_end_to_end(self, unlift):
69+
def test_emit_lowered_backend_module_end_to_end(self):
6970
class SinModule(torch.nn.Module):
7071
def __init__(self):
7172
super().__init__()
@@ -76,15 +77,19 @@ def forward(self, x):
7677
sin_module = SinModule()
7778
model_inputs = (torch.ones(1),)
7879
expected_res = sin_module(*model_inputs)
79-
edgeir_m = exir.capture(
80-
sin_module,
81-
model_inputs,
82-
exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=unlift),
83-
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True))
80+
edgeir_m = to_edge(
81+
export(
82+
sin_module,
83+
model_inputs,
84+
),
85+
compile_config=exir.EdgeCompileConfig(
86+
_check_ir_validity=False, _use_edge_ops=True
87+
),
88+
)
8489
max_value = model_inputs[0].shape[0]
8590
compile_specs = [CompileSpec("max_value", bytes([max_value]))]
8691
lowered_sin_module = to_backend(
87-
BackendWithCompilerDemo.__name__, edgeir_m.exported_program, compile_specs
92+
BackendWithCompilerDemo.__name__, edgeir_m.exported_program(), compile_specs
8893
)
8994

9095
new_res = lowered_sin_module(*model_inputs)
@@ -120,26 +125,22 @@ def test_emit_lowered_backend_module(self, unlift):
120125
models.ModelWithUnusedArg(),
121126
]
122127

123-
capture_config = (
124-
exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig()
125-
)
126-
127128
edge_compile_config = exir.EdgeCompileConfig(
128129
_check_ir_validity=False, _use_edge_ops=True
129130
)
130131

131132
for model in module_list:
132133
model_inputs = model.get_random_inputs()
133134

134-
edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge(
135-
edge_compile_config
135+
edgeir_m = to_edge(
136+
export(model, model_inputs), compile_config=edge_compile_config
136137
)
137138
lowered_model = to_backend(
138-
QnnBackend.__name__, edgeir_m.exported_program, []
139+
QnnBackend.__name__, edgeir_m.exported_program(), []
139140
)
140141
program = lowered_model.program()
141142
reference_program = self.get_program_from_wrapped_module(
142-
lowered_model, model_inputs, capture_config, edge_compile_config
143+
lowered_model, model_inputs, edge_compile_config
143144
)
144145

145146
# Check program is fairly equal to the reference program
@@ -180,22 +181,18 @@ def test_emit_nested_lowered_backend_module(self, unlift):
180181
models.ModelWithUnusedArg(),
181182
]
182183

183-
capture_config = (
184-
exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig()
185-
)
186-
187184
edge_compile_config = exir.EdgeCompileConfig(
188185
_check_ir_validity=False, _use_edge_ops=True
189186
)
190187

191188
for model in module_list:
192189
model_inputs = model.get_random_inputs()
193190

194-
edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge(
195-
edge_compile_config
191+
edgeir_m = to_edge(
192+
export(model, model_inputs), compile_config=edge_compile_config
196193
)
197194
lowered_module = to_backend(
198-
QnnBackend.__name__, edgeir_m.exported_program, []
195+
QnnBackend.__name__, edgeir_m.exported_program(), []
199196
)
200197

201198
# This module will include one operator and two delegate call
@@ -208,12 +205,12 @@ def forward(self, *args):
208205
return self.one_module(*args)
209206

210207
wrapped_module = WrappedModule(lowered_module)
211-
wrapped_module_edge = exir.capture(
212-
wrapped_module, model_inputs, capture_config
213-
).to_edge(edge_compile_config)
208+
wrapped_module_edge = to_edge(
209+
export(wrapped_module, model_inputs), compile_config=edge_compile_config
210+
)
214211

215212
nested_lowered_model = to_backend(
216-
QnnBackend.__name__, wrapped_module_edge.exported_program, []
213+
QnnBackend.__name__, wrapped_module_edge.exported_program(), []
217214
)
218215

219216
program = nested_lowered_model.program()

0 commit comments

Comments
 (0)