Skip to content

[pybinding] Add mapping from C++ Program::Verification to Python #6140

Closed · wants to merge 1 commit
1 change: 1 addition & 0 deletions extension/pybindings/portable_lib.py
@@ -45,6 +45,7 @@
_reset_profile_results, # noqa: F401
BundledModule, # noqa: F401
ExecuTorchModule, # noqa: F401
Verification, # noqa: F401
)

# Clean up so that `dir(portable_lib)` is the same as `dir(_portable_lib)`
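With Verification re-exported here, callers can import it next to the existing wrapper types. A minimal sketch, assuming the extension is built and the standard executorch.extension.pybindings.portable_lib module path:

# Sketch: import the newly re-exported enum (module path assumed, not part of this diff).
from executorch.extension.pybindings.portable_lib import Verification

# pybind11 enums expose their members via __members__.
print(list(Verification.__members__))  # ['Minimal', 'InternalConsistency']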
69 changes: 51 additions & 18 deletions extension/pybindings/pybindings.cpp
@@ -157,13 +157,15 @@ class Module final {
explicit Module(
std::unique_ptr<DataLoader> loader,
std::unique_ptr<ETDumpGen> tracer = nullptr,
size_t debug_buffer_size = 0)
size_t debug_buffer_size = 0,
Program::Verification program_verification =
Program::Verification::InternalConsistency)
: loader_(std::move(loader)),
event_tracer_(std::move(tracer)),
debug_buffer_size_(debug_buffer_size) {
::executorch::runtime::runtime_init();
Result<Program> program = Program::load(
loader_.get(), Program::Verification::InternalConsistency);
Result<Program> program =
Program::load(loader_.get(), program_verification);
THROW_IF_ERROR(
program.error(),
"loading program failed with error: 0x%" PRIx32,
@@ -375,19 +377,22 @@ inline std::unique_ptr<Module> load_module_from_buffer(
const void* ptr,
size_t ptr_len,
bool enable_etdump,
size_t debug_buffer_size) {
size_t debug_buffer_size,
Program::Verification program_verification) {
EXECUTORCH_SCOPE_PROF("load_module_from_buffer");
auto loader = std::make_unique<BufferDataLoader>(ptr, ptr_len);
return std::make_unique<Module>(
std::move(loader),
enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
debug_buffer_size);
debug_buffer_size,
program_verification);
}

inline std::unique_ptr<Module> load_module_from_file(
const std::string& path,
bool enable_etdump,
size_t debug_buffer_size) {
size_t debug_buffer_size,
Program::Verification program_verification) {
EXECUTORCH_SCOPE_PROF("load_module_from_file");

Result<MmapDataLoader> res = MmapDataLoader::from(
@@ -402,7 +407,8 @@ inline std::unique_ptr<Module> load_module_from_file(
return std::make_unique<Module>(
std::move(loader),
enable_etdump ? std::make_unique<torch::executor::ETDumpGen>() : nullptr,
debug_buffer_size);
debug_buffer_size,
program_verification);
}

static constexpr size_t kDEFAULT_BUNDLED_INPUT_POOL_SIZE = 16 * 1024U;
@@ -452,30 +458,41 @@ struct PyModule final {
explicit PyModule(
const py::bytes& buffer,
bool enable_etdump,
size_t debug_buffer_size = 0)
size_t debug_buffer_size = 0,
Program::Verification program_verification =
Program::Verification::InternalConsistency)
: module_(load_module_from_buffer(
buffer.cast<std::string_view>().data(),
py::len(buffer),
enable_etdump,
debug_buffer_size)) {}
debug_buffer_size,
program_verification)) {}

explicit PyModule(
const void* ptr,
size_t ptr_len,
bool enable_etdump,
size_t debug_buffer_size = 0)
size_t debug_buffer_size = 0,
Program::Verification program_verification =
Program::Verification::InternalConsistency)
: module_(load_module_from_buffer(
ptr,
ptr_len,
enable_etdump,
debug_buffer_size)) {}
debug_buffer_size,
program_verification)) {}

explicit PyModule(
const std::string& path,
bool enable_etdump,
size_t debug_buffer_size = 0)
: module_(load_module_from_file(path, enable_etdump, debug_buffer_size)) {
}
size_t debug_buffer_size = 0,
Program::Verification program_verification =
Program::Verification::InternalConsistency)
: module_(load_module_from_file(
path,
enable_etdump,
debug_buffer_size,
program_verification)) {}

PyModule(const PyModule&) = delete;
PyModule& operator=(const PyModule&) = delete;
@@ -486,14 +503,20 @@ struct PyModule final {
static std::unique_ptr<PyModule> load_from_buffer(
const py::bytes& buffer,
bool enable_etdump,
size_t debug_buffer_size = 0) {
return std::make_unique<PyModule>(buffer, enable_etdump, debug_buffer_size);
size_t debug_buffer_size = 0,
Program::Verification program_verification =
Program::Verification::InternalConsistency) {
return std::make_unique<PyModule>(
buffer, enable_etdump, debug_buffer_size, program_verification);
}
static std::unique_ptr<PyModule> load_from_file(
const std::string& path,
bool enable_etdump,
size_t debug_buffer_size = 0) {
return std::make_unique<PyModule>(path, enable_etdump, debug_buffer_size);
size_t debug_buffer_size = 0,
Program::Verification program_verification =
Program::Verification::InternalConsistency) {
return std::make_unique<PyModule>(
path, enable_etdump, debug_buffer_size, program_verification);
}

static std::unique_ptr<PyModule> load_from_bundled_program(
@@ -805,19 +828,29 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
// Redirects cout and cerr to the Python env for the function calls this guards.
auto call_guard = py::
call_guard<py::scoped_ostream_redirect, py::scoped_estream_redirect>();

// Bind the verification enum to python.
py::enum_<Program::Verification>(m, "Verification")
.value("Minimal", Program::Verification::Minimal)
.value("InternalConsistency", Program::Verification::InternalConsistency);

m.def(
"_load_for_executorch",
PyModule::load_from_file,
py::arg("path"),
py::arg("enable_etdump") = false,
py::arg("debug_buffer_size") = 0,
py::arg("program_verification") =
Program::Verification::InternalConsistency,
call_guard);
m.def(
"_load_for_executorch_from_buffer",
&PyModule::load_from_buffer,
py::arg("buffer"),
py::arg("enable_etdump") = false,
py::arg("debug_buffer_size") = 0,
py::arg("program_verification") =
Program::Verification::InternalConsistency,
call_guard);
m.def(
"_load_for_executorch_from_bundled_program",
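Together, the enum binding and the new py::arg defaults above give the Python loaders an optional program_verification keyword. A usage sketch; the module path and model.pte are assumptions, not part of this diff:

from executorch.extension.pybindings.portable_lib import (
    Verification,
    _load_for_executorch_from_buffer,
)

with open("model.pte", "rb") as f:  # hypothetical program file
    program_buffer = f.read()

# Skip full flatbuffer verification for a program from a trusted source.
module = _load_for_executorch_from_buffer(
    program_buffer,
    program_verification=Verification.Minimal,
)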
26 changes: 23 additions & 3 deletions extension/pybindings/pybindings.pyi
@@ -5,10 +5,24 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
from typing import Any, Dict, List, Optional, Sequence, Tuple
from __future__ import annotations

from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Tuple

from executorch.exir._warnings import experimental

@experimental("This API is experimental and subject to change without notice.")
class Verification(Enum):
"""Verification maps C++ Program::Verification to Python.

.. warning::

This API is experimental and subject to change without notice.
"""

Minimal: ...
InternalConsistency: ...

@experimental("This API is experimental and subject to change without notice.")
class ExecuTorchModule:
"""ExecuTorchModule is a Python wrapper around a C++ ExecuTorch program.
@@ -56,7 +70,10 @@ class BundledModule:

@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch(
path: str, enable_etdump: bool = False, debug_buffer_size: int = 0
path: str,
enable_etdump: bool = False,
debug_buffer_size: int = 0,
program_verification: Verification = Verification.InternalConsistency,
) -> ExecuTorchModule:
"""Load an ExecuTorch Program from a file.

@@ -79,7 +96,10 @@ def _load_for_executorch(

@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch_from_buffer(
buffer: bytes, enable_etdump: bool = False, debug_buffer_size: int = 0
buffer: bytes,
enable_etdump: bool = False,
debug_buffer_size: int = 0,
program_verification: Verification = Verification.InternalConsistency,
) -> ExecuTorchModule:
"""Same as _load_for_executorch, but takes a byte buffer instead of a file path.

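Because the stub defaults mirror the C++ side, omitting the new argument keeps today's behavior; passing it explicitly is equivalent. A sketch of the file-path entry point, with a placeholder path:

from executorch.extension.pybindings.portable_lib import (
    Verification,
    _load_for_executorch,
)

# Explicitly requesting the default level; identical to omitting the argument.
module = _load_for_executorch(
    "model.pte",  # placeholder path
    program_verification=Verification.InternalConsistency,
)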
117 changes: 115 additions & 2 deletions extension/pybindings/test/make_test.py
@@ -7,7 +7,8 @@
# pyre-unsafe

import unittest
from typing import Any, Callable, Tuple
from types import ModuleType
from typing import Any, Callable, Optional, Tuple

import torch
from executorch.exir import ExecutorchProgramManager, to_edge
@@ -16,7 +17,7 @@

def make_test( # noqa: C901
tester: unittest.TestCase,
load_fn: Callable,
runtime: ModuleType,
) -> Callable[[unittest.TestCase], None]:
"""
Returns a function that operates as a test case within a unittest.TestCase class.
Expand All @@ -25,6 +26,7 @@ def make_test( # noqa: C901
which will all have different load functions. In this case each individual test case is a
subfunction of wrapper.
"""
load_fn: Callable = runtime._load_for_executorch_from_buffer

def wrapper(tester: unittest.TestCase) -> None:
class ModuleAdd(torch.nn.Module):
@@ -251,12 +253,123 @@ def test_quantized_ops(tester):
expected = example_inputs[0] + example_inputs[1]
tester.assertEqual(str(expected), str(executorch_output))

def test_constant_output_not_memory_planned(tester):
# Create an ExecuTorch program from ModuleAdd.
exported_program, inputs = create_program(
ModuleAddConstReturn(),
et_config=ExecutorchBackendConfig(
memory_planning_pass=MemoryPlanningPass(alloc_graph_output=False)
),
)

exported_program.dump_executorch_program(verbose=True)

# Use pybindings to load and execute the program.
executorch_module = load_fn(exported_program.buffer)
# Invoke the callable on executorch_module instead of calling module.forward.
# Use only one input to test this case.
executorch_output = executorch_module((torch.ones(2, 2),))
print(executorch_output)

# The test module adds the input to torch.ones(2,2), so its output should be the same
# as adding them directly.
expected = torch.ones(2, 2) + torch.ones(2, 2)
tester.assertEqual(str(expected), str(executorch_output[0]))

# The test module returns the state. Check that its value is correct.
tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))

def test_method_meta(tester) -> None:
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
exported_program, inputs = create_program(ModuleAdd())

# Use pybindings to load the program and query its metadata.
executorch_module = load_fn(exported_program.buffer)
meta = executorch_module.method_meta("forward")

# Ensure that all these APIs work even if the module object is destroyed.
del executorch_module
tester.assertEqual(meta.name(), "forward")
tester.assertEqual(meta.num_inputs(), 2)
tester.assertEqual(meta.num_outputs(), 1)
# Common string for all these tensors.
tensor_info = "TensorInfo(sizes=[2, 2], dtype=Float, is_memory_planned=True, nbytes=16)"
float_dtype = 6
tester.assertEqual(
str(meta),
"MethodMeta(name='forward', num_inputs=2, "
f"input_tensor_meta=['{tensor_info}', '{tensor_info}'], "
f"num_outputs=1, output_tensor_meta=['{tensor_info}'])",
)

input_tensors = [meta.input_tensor_meta(i) for i in range(2)]
output_tensor = meta.output_tensor_meta(0)
# Check that accessing out of bounds raises IndexError.
with tester.assertRaises(IndexError):
meta.input_tensor_meta(2)
# Test that tensor metadata can outlive method metadata.
del meta
tester.assertEqual([t.sizes() for t in input_tensors], [(2, 2), (2, 2)])
tester.assertEqual(
[t.dtype() for t in input_tensors], [float_dtype, float_dtype]
)
tester.assertEqual(
[t.is_memory_planned() for t in input_tensors], [True, True]
)
tester.assertEqual([t.nbytes() for t in input_tensors], [16, 16])
tester.assertEqual(str(input_tensors), f"[{tensor_info}, {tensor_info}]")

tester.assertEqual(output_tensor.sizes(), (2, 2))
tester.assertEqual(output_tensor.dtype(), float_dtype)
tester.assertEqual(output_tensor.is_memory_planned(), True)
tester.assertEqual(output_tensor.nbytes(), 16)
tester.assertEqual(str(output_tensor), tensor_info)

def test_bad_name(tester) -> None:
# Create an ExecuTorch program from ModuleAdd.
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
exported_program, inputs = create_program(ModuleAdd())

# Use pybindings to load and execute the program.
executorch_module = load_fn(exported_program.buffer)
# Invoke the callable on executorch_module instead of calling module.forward.
with tester.assertRaises(RuntimeError):
executorch_module.run_method("not_a_real_method", inputs)

def test_verification_config(tester) -> None:
# Create an ExecuTorch program from ModuleAdd.
# pyre-fixme[16]: Callable `make_test` has no attribute `wrapper`.
exported_program, inputs = create_program(ModuleAdd())
Verification = runtime.Verification

# Use pybindings to load and execute the program.
for config in [Verification.Minimal, Verification.InternalConsistency]:
executorch_module = load_fn(
exported_program.buffer,
enable_etdump=False,
debug_buffer_size=0,
program_verification=config,
)

executorch_output = executorch_module.forward(inputs)[0]

# The test module adds the two inputs, so its output should be the same
# as adding them directly.
expected = inputs[0] + inputs[1]

tester.assertEqual(str(expected), str(executorch_output))

######### RUN TEST CASES #########
test_e2e(tester)
test_multiple_entry(tester)
test_output_lifespan(tester)
test_module_callable(tester)
test_module_single_input(tester)
test_stderr_redirect(tester)
test_quantized_ops(tester)
test_constant_output_not_memory_planned(tester)
test_method_meta(tester)
test_bad_name(tester)
test_verification_config(tester)

return wrapper
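Since make_test now receives the runtime module itself, so the wrapper can reach both the load function and the Verification enum, a caller would look roughly like the sketch below. The test class name and module path are illustrative, not part of this diff:

import unittest

from executorch.extension.pybindings import portable_lib
from executorch.extension.pybindings.test.make_test import make_test

class TestPortableLib(unittest.TestCase):
    def test_all(self) -> None:
        # Build the shared test suite against this runtime and run it.
        make_test(self, portable_lib)(self)

if __name__ == "__main__":
    unittest.main()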