Add MethodMeta object for python visibility (#5571)
Summary:
Pull Request resolved: #5571

Some clients and consumers of ExecuTorch program files (.pte) have been
requesting ways to access metadata such as tensor sizes and the number of
bytes a tensor requires. When I showed them how to access this in C++, they
asked for Python wrappers, since their processing scripts are written in
Python.

Add Python-visible implementations of MethodMeta and TensorInfo and a
subset of their methods. Note that these are more expensive than their C++
counterparts because they must allocate Python objects, but these APIs are
unlikely to appear in performance-sensitive paths anyway. Managing the
lifetimes of mixed C++/Python objects is complex, so I favored simple
lifetimes: each metadata object holds a shared_ptr that keeps its Module
alive.
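
As a reference, a minimal sketch of the intended Python-side usage (the
add.pte file name is a hypothetical stand-in for any exported program with
a two-tensor forward method; the portable_lib import path is one of the
built variants):

    from executorch.extension.pybindings.portable_lib import _load_for_executorch

    module = _load_for_executorch("add.pte")  # hypothetical path
    meta = module.method_meta("forward")
    print(meta.name(), meta.num_inputs(), meta.num_outputs())
    for i in range(meta.num_inputs()):
        info = meta.input_tensor_meta(i)
        print(info.sizes(), info.dtype(), info.is_memory_planned(), info.nbytes())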

Differential Revision: D63288433
dulinriley authored and facebook-github-bot committed Sep 24, 2024
1 parent 99ee547 commit 11f2716
Showing 3 changed files with 271 additions and 2 deletions.
157 changes: 156 additions & 1 deletion extension/pybindings/pybindings.cpp
@@ -24,6 +24,7 @@
#include <executorch/extension/data_loader/mmap_data_loader.h>
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
#include <executorch/runtime/core/data_loader.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/kernel/operator_registry.h>
@@ -55,6 +56,16 @@
} \
})

#define THROW_INDEX_IF_ERROR(error, message, ...) \
({ \
if ((error) != Error::Ok) { \
char msg_buf[128]; \
snprintf(msg_buf, sizeof(msg_buf), message, ##__VA_ARGS__); \
/* pybind will convert this to a python exception. */ \
throw std::out_of_range(msg_buf); \
} \
})

// Our logs work by writing to stderr. By default this is done through fprintf
// (as defined in posix.cpp) which then does not show up in python environments.
// Here we override the pal to use std::cerr which can be properly redirected by
@@ -448,6 +459,119 @@ struct PyBundledModule final {
size_t program_len_;
};

/// Expose a subset of TensorInfo information to python.
struct PyTensorInfo final {
explicit PyTensorInfo(
std::shared_ptr<Module> module,
torch::executor::TensorInfo info)
: module_(std::move(module)), info_(info) {}

py::tuple sizes() const {
const auto shape = info_.sizes();
py::tuple tup(shape.size());
for (size_t i = 0; i < shape.size(); ++i) {
tup[i] = py::cast(shape[i]);
}
return tup;
}

int8_t dtype() const {
return static_cast<std::underlying_type<exec_aten::ScalarType>::type>(
info_.scalar_type());
}

bool is_memory_planned() const {
return info_.is_memory_planned();
}

size_t nbytes() const {
return info_.nbytes();
}

std::string repr() const {
std::string size_str = "[";
for (const auto& d : info_.sizes()) {
size_str.append(std::to_string(d));
size_str.append(", ");
}
if (size_str.length() >= 2) {
// Pop the last two characters (comma and space) and add the closing bracket.
size_str.pop_back();
size_str.pop_back();
}
size_str.append("]");
return "TensorInfo(sizes=" + size_str + ", dtype=" +
std::string(executorch::runtime::toString(info_.scalar_type())) +
", is_memory_planned=" +
(info_.is_memory_planned() ? "True" : "False") +
", nbytes=" + std::to_string(info_.nbytes()) + ")";
}

private:
// TensorInfo relies on module to be alive.
std::shared_ptr<Module> module_;
torch::executor::TensorInfo info_;
};

/// Expose a subset of MethodMeta information to python.
struct PyMethodMeta final {
explicit PyMethodMeta(
std::shared_ptr<Module> module,
torch::executor::MethodMeta meta)
: module_(std::move(module)), meta_(meta) {}

const char* name() const {
return meta_.name();
}

size_t num_inputs() const {
return meta_.num_inputs();
}

std::unique_ptr<PyTensorInfo> input_tensor_meta(size_t index) const {
const auto result = meta_.input_tensor_meta(index);
THROW_INDEX_IF_ERROR(
result.error(), "Cannot get input tensor meta at %zu", index);
return std::make_unique<PyTensorInfo>(module_, result.get());
}

size_t num_outputs() const {
return meta_.num_outputs();
}

std::unique_ptr<PyTensorInfo> output_tensor_meta(size_t index) const {
const auto result = meta_.output_tensor_meta(index);
THROW_INDEX_IF_ERROR(
result.error(), "Cannot get output tensor meta at %zu", index);
return std::make_unique<PyTensorInfo>(module_, result.get());
}

py::str repr() const {
py::list input_meta_strs;
for (size_t i = 0; i < meta_.num_inputs(); ++i) {
input_meta_strs.append(py::str(input_tensor_meta(i)->repr()));
}
py::list output_meta_strs;
for (size_t i = 0; i < meta_.num_outputs(); ++i) {
output_meta_strs.append(py::str(output_tensor_meta(i)->repr()));
}
// Add quotes to be more similar to Python's repr for strings.
py::str format =
"MethodMeta(name='{}', num_inputs={}, input_tensor_meta={}, num_outputs={}, output_tensor_meta={})";
return format.format(
std::string(meta_.name()),
std::to_string(meta_.num_inputs()),
input_meta_strs,
std::to_string(meta_.num_outputs()),
output_meta_strs);
}

private:
// Must keep the Module object alive or else the meta object is invalidated.
std::shared_ptr<Module> module_;
torch::executor::MethodMeta meta_;
};

struct PyModule final {
explicit PyModule(
const py::bytes& buffer,
@@ -751,8 +875,13 @@ struct PyModule final {
return list;
}

std::unique_ptr<PyMethodMeta> method_meta(const std::string method_name) {
auto& method = module_->get_method(method_name);
return std::make_unique<PyMethodMeta>(module_, method.method_meta());
}

private:
std::unique_ptr<Module> module_;
std::shared_ptr<Module> module_;
// Need to keep-alive output storages until they can be compared in case of
// bundled programs.
std::vector<std::vector<uint8_t>> output_storages_;
@@ -866,6 +995,11 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
py::arg("method_name"),
py::arg("clone_outputs") = true,
call_guard)
.def(
"method_meta",
&PyModule::method_meta,
py::arg("method_name"),
call_guard)
.def(
"run_method",
&PyModule::run_method,
@@ -900,6 +1034,27 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
call_guard);

py::class_<PyBundledModule>(m, "BundledModule");
py::class_<PyTensorInfo>(m, "TensorInfo")
.def("sizes", &PyTensorInfo::sizes, call_guard)
.def("dtype", &PyTensorInfo::dtype, call_guard)
.def("is_memory_planned", &PyTensorInfo::is_memory_planned, call_guard)
.def("nbytes", &PyTensorInfo::nbytes, call_guard)
.def("__repr__", &PyTensorInfo::repr, call_guard);
py::class_<PyMethodMeta>(m, "MethodMeta")
.def("name", &PyMethodMeta::name, call_guard)
.def("num_inputs", &PyMethodMeta::num_inputs, call_guard)
.def("num_outputs", &PyMethodMeta::num_outputs, call_guard)
.def(
"input_tensor_meta",
&PyMethodMeta::input_tensor_meta,
py::arg("index"),
call_guard)
.def(
"output_tensor_meta",
&PyMethodMeta::output_tensor_meta,
py::arg("index"),
call_guard)
.def("__repr__", &PyMethodMeta::repr, call_guard);
}

} // namespace pybindings
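
Because PyTensorInfo and PyMethodMeta each hold a shared_ptr<Module>, the
metadata objects remain valid even after the Python-side module reference
is dropped. A small sketch of that guarantee, continuing the hypothetical
usage from the summary:

    meta = module.method_meta("forward")
    info = meta.input_tensor_meta(0)
    del module, meta  # info keeps the underlying C++ Module alive
    print(info.sizes())  # still safe to query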
69 changes: 69 additions & 0 deletions extension/pybindings/pybindings.pyi
@@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple

from executorch.exir._warnings import experimental
@@ -43,6 +45,7 @@ class ExecuTorchModule:
def write_etdump_result_to_file(
self, path: str, debug_buffer_path: Optional[str] = None
) -> None: ...
def method_meta(self, method_name: str) -> MethodMeta: ...

@experimental("This API is experimental and subject to change without notice.")
class BundledModule:
@@ -54,6 +57,72 @@ class BundledModule:

...

@experimental("This API is experimental and subject to change without notice.")
class TensorInfo:
"""Metadata about a tensor such as the shape and dtype.
.. warning::
This API is experimental and subject to change without notice.
"""

def sizes(self) -> Tuple[int, ...]:
"""Shape of the tensor as a tuple"""
...

def dtype(self) -> int:
"""The data type of the elements inside the tensor.
See documentation for ScalarType in executorch/runtime/core/portable_type/scalar_type.h
for the values these integers can take."""
...

def is_memory_planned(self) -> bool:
"""True if the tensor is already memory planned, meaning no allocation
needs to be provided. False otherwise"""
...

def nbytes(self) -> int:
"""Number of bytes in the tensor. Not the same as numel if the dtype is
larger than 1 byte wide"""
...

def __repr__(self) -> str: ...

@experimental("This API is experimental and subject to change without notice.")
class MethodMeta:
"""Metadata about a method such as the number of inputs and outputs.
.. warning::
This API is experimental and subject to change without notice.
"""

def name(self) -> str:
"""The name of the method, such as 'forward'"""
...

def num_inputs(self) -> int:
"""The number of user inputs to the method. This does not include any
internal buffers or weights, which don't need to be provided by the user"""
...

def num_outputs(self) -> int:
"""The number of outputs from the method. This does not include any mutated
internal buffers"""
...

def input_tensor_meta(self, index: int) -> TensorInfo:
"""The tensor info for the 'index'th input. Index must be in the interval
[0, num_inputs()). Raises an IndexError if the index is out of bounds"""
...

def output_tensor_meta(self, index: int) -> TensorInfo:
"""The tensor info for the 'index'th output. Index must be in the interval
[0, num_outputs()). Raises an IndexError if the index is out of bounds"""
...

def __repr__(self) -> str: ...

@experimental("This API is experimental and subject to change without notice.")
def _load_for_executorch(
path: str, enable_etdump: bool = False, debug_buffer_size: int = 0
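TensorInfo.dtype() returns the raw ScalarType integer (Float is 6, as the
test below asserts). A small helper for scripting convenience, assumed here
for illustration and not part of this commit, mapping common ScalarType
values back to torch dtypes:

    import torch

    # Hypothetical helper, not part of this commit: map ExecuTorch
    # ScalarType integers (see scalar_type.h) back to torch dtypes.
    # Covers only the common types.
    _SCALAR_TYPE_TO_TORCH_DTYPE = {
        0: torch.uint8,    # Byte
        1: torch.int8,     # Char
        2: torch.int16,    # Short
        3: torch.int32,    # Int
        4: torch.int64,    # Long
        5: torch.float16,  # Half
        6: torch.float32,  # Float
        7: torch.float64,  # Double
        11: torch.bool,    # Bool
    }

    def tensor_info_torch_dtype(info) -> torch.dtype:
        """Return the torch dtype corresponding to TensorInfo.dtype()."""
        return _SCALAR_TYPE_TO_TORCH_DTYPE[info.dtype()]
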
47 changes: 46 additions & 1 deletion extension/pybindings/test/make_test.py
@@ -295,8 +295,52 @@ def test_constant_output_not_memory_planned(tester):
# The test module returns the state. Check that its value is correct.
tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))

def test_method_meta(tester) -> None:
exported_program, inputs = create_program(ModuleAdd())

# Use pybindings to load the program and query its metadata.
executorch_module = load_fn(exported_program.buffer)
meta = executorch_module.method_meta("forward")

# Ensure that all these APIs work even if the module object is destroyed.
del executorch_module
tester.assertEqual(meta.name(), "forward")
tester.assertEqual(meta.num_inputs(), 2)
tester.assertEqual(meta.num_outputs(), 1)
# Common string for all these tensors.
tensor_info = "TensorInfo(sizes=[2, 2], dtype=Float, is_memory_planned=True, nbytes=16)"
float_dtype = 6
tester.assertEqual(
str(meta),
"MethodMeta(name='forward', num_inputs=2, "
f"input_tensor_meta=['{tensor_info}', '{tensor_info}'], "
f"num_outputs=1, output_tensor_meta=['{tensor_info}'])",
)

input_tensors = [meta.input_tensor_meta(i) for i in range(2)]
output_tensor = meta.output_tensor_meta(0)
# Check that accessing out of bounds raises IndexError.
with tester.assertRaises(IndexError):
meta.input_tensor_meta(2)
# Test that tensor metadata can outlive method metadata.
del meta
tester.assertEqual([t.sizes() for t in input_tensors], [(2, 2), (2, 2)])
tester.assertEqual(
[t.dtype() for t in input_tensors], [float_dtype, float_dtype]
)
tester.assertEqual(
[t.is_memory_planned() for t in input_tensors], [True, True]
)
tester.assertEqual([t.nbytes() for t in input_tensors], [16, 16])
tester.assertEqual(str(input_tensors), f"[{tensor_info}, {tensor_info}]")

tester.assertEqual(output_tensor.sizes(), (2, 2))
tester.assertEqual(output_tensor.dtype(), float_dtype)
tester.assertEqual(output_tensor.is_memory_planned(), True)
tester.assertEqual(output_tensor.nbytes(), 16)
tester.assertEqual(str(output_tensor), tensor_info)

######### RUN TEST CASES #########
test_e2e(tester)
test_multiple_entry(tester)
test_output_lifespan(tester)
@@ -305,5 +349,6 @@ def test_constant_output_not_memory_planned(tester):
test_stderr_redirect(tester)
test_quantized_ops(tester)
test_constant_output_not_memory_planned(tester)
test_method_meta(tester)

return wrapper
