Skip to content

remove exir.capture from test_lowered_backend_module #3169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 30 additions & 33 deletions exir/backend/test/test_lowered_backend_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torch
from executorch import exir
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.test.backend_with_compiler_demo import (
Expand All @@ -22,6 +23,7 @@
_load_for_executorch_from_buffer,
)
from hypothesis import given, settings, strategies as st
from torch.export import export


class TestBackendAPI(unittest.TestCase):
Expand All @@ -44,7 +46,7 @@ def validate_lowered_module_program(self, program: Program) -> None:
)

def get_program_from_wrapped_module(
self, lowered_module, example_inputs, capture_config, edge_compile_config
self, lowered_module, example_inputs, edge_compile_config
):
class WrappedModule(torch.nn.Module):
def __init__(self):
Expand All @@ -55,17 +57,16 @@ def forward(self, *args):
return self.one_module(*args)

return (
exir.capture(WrappedModule(), example_inputs, capture_config)
.to_edge(edge_compile_config)
to_edge(
export(WrappedModule(), example_inputs),
compile_config=edge_compile_config,
)
.to_executorch()
.program
.executorch_program
)

@given(
unlift=st.booleans(), # verify both lifted and unlifted graph
)
@settings(deadline=500000)
def test_emit_lowered_backend_module_end_to_end(self, unlift):
def test_emit_lowered_backend_module_end_to_end(self):
class SinModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -76,15 +77,19 @@ def forward(self, x):
sin_module = SinModule()
model_inputs = (torch.ones(1),)
expected_res = sin_module(*model_inputs)
edgeir_m = exir.capture(
sin_module,
model_inputs,
exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=unlift),
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True))
edgeir_m = to_edge(
export(
sin_module,
model_inputs,
),
compile_config=exir.EdgeCompileConfig(
_check_ir_validity=False, _use_edge_ops=True
),
)
max_value = model_inputs[0].shape[0]
compile_specs = [CompileSpec("max_value", bytes([max_value]))]
lowered_sin_module = to_backend(
BackendWithCompilerDemo.__name__, edgeir_m.exported_program, compile_specs
BackendWithCompilerDemo.__name__, edgeir_m.exported_program(), compile_specs
)

new_res = lowered_sin_module(*model_inputs)
Expand Down Expand Up @@ -120,26 +125,22 @@ def test_emit_lowered_backend_module(self, unlift):
models.ModelWithUnusedArg(),
]

capture_config = (
exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig()
)

edge_compile_config = exir.EdgeCompileConfig(
_check_ir_validity=False, _use_edge_ops=True
)

for model in module_list:
model_inputs = model.get_random_inputs()

edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge(
edge_compile_config
edgeir_m = to_edge(
export(model, model_inputs), compile_config=edge_compile_config
)
lowered_model = to_backend(
QnnBackend.__name__, edgeir_m.exported_program, []
QnnBackend.__name__, edgeir_m.exported_program(), []
)
program = lowered_model.program()
reference_program = self.get_program_from_wrapped_module(
lowered_model, model_inputs, capture_config, edge_compile_config
lowered_model, model_inputs, edge_compile_config
)

# Check program is fairly equal to the reference program
Expand Down Expand Up @@ -180,22 +181,18 @@ def test_emit_nested_lowered_backend_module(self, unlift):
models.ModelWithUnusedArg(),
]

capture_config = (
exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig()
)

edge_compile_config = exir.EdgeCompileConfig(
_check_ir_validity=False, _use_edge_ops=True
)

for model in module_list:
model_inputs = model.get_random_inputs()

edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge(
edge_compile_config
edgeir_m = to_edge(
export(model, model_inputs), compile_config=edge_compile_config
)
lowered_module = to_backend(
QnnBackend.__name__, edgeir_m.exported_program, []
QnnBackend.__name__, edgeir_m.exported_program(), []
)

# This module will include one operator and two delegate call
Expand All @@ -208,12 +205,12 @@ def forward(self, *args):
return self.one_module(*args)

wrapped_module = WrappedModule(lowered_module)
wrapped_module_edge = exir.capture(
wrapped_module, model_inputs, capture_config
).to_edge(edge_compile_config)
wrapped_module_edge = to_edge(
export(wrapped_module, model_inputs), compile_config=edge_compile_config
)

nested_lowered_model = to_backend(
QnnBackend.__name__, wrapped_module_edge.exported_program, []
QnnBackend.__name__, wrapped_module_edge.exported_program(), []
)

program = nested_lowered_model.program()
Expand Down
Loading