
Commit 7eed94a

dulinriley authored and facebook-github-bot committed
Add MethodMeta object for python visibility (#5571)
Add MethodMeta object for python visibility (#5571) Summary: Pull Request resolved: #5571. Some clients and consumers of ExecuTorch program files (.pte) were requesting ways to access metadata such as the sizes of tensors and the number of bytes they need. When I told them how to access it in C++, they requested Python wrappers, since their processing scripts are written in Python. This change adds Python implementations of the MethodMeta and TensorInfo methods. Note that these are more expensive than their C++ counterparts because they must allocate Python objects, but I doubt they are used in performance-sensitive applications anyway. And since managing the lifetimes of mixed C++/Python objects is complex, I favored simple lifetimes: each wrapper keeps the owning Module alive. Differential Revision: D63288433
1 parent 3e79ea4 commit 7eed94a
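For context, a minimal usage sketch of the new API (not part of this diff; it assumes the extension is importable as executorch.extension.pybindings.portable_lib — the actual module name is set by EXECUTORCH_PYTHON_MODULE_NAME at build time — and "model.pte" is a hypothetical path):

# Sketch only: the import path and "model.pte" are assumptions.
from executorch.extension.pybindings.portable_lib import _load_for_executorch

module = _load_for_executorch("model.pte")
meta = module.method_meta("forward")  # new in this change

print(meta.name())        # "forward"
print(meta.num_inputs())  # user inputs only, not weights/buffers
for i in range(meta.num_inputs()):
    info = meta.input_tensor_meta(i)
    print(info.sizes(), info.dtype(), info.nbytes(), info.is_memory_planned())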

File tree

3 files changed (+249 −2 lines)


extension/pybindings/pybindings.cpp
Lines changed: 150 additions & 1 deletion

@@ -24,6 +24,7 @@
 #include <executorch/extension/data_loader/mmap_data_loader.h>
 #include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
 #include <executorch/runtime/core/data_loader.h>
+#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
 #include <executorch/runtime/executor/method.h>
 #include <executorch/runtime/executor/program.h>
 #include <executorch/runtime/kernel/operator_registry.h>
@@ -448,6 +449,123 @@ struct PyBundledModule final {
   size_t program_len_;
 };
 
+/// Expose a subset of TensorInfo information to python.
+struct PyTensorInfo final {
+  explicit PyTensorInfo(
+      std::shared_ptr<Module> module,
+      torch::executor::TensorInfo info)
+      : module_(std::move(module)), info_(info) {}
+
+  py::tuple sizes() const {
+    const auto shape = info_.sizes();
+    py::tuple tup(shape.size());
+    for (size_t i = 0; i < shape.size(); ++i) {
+      tup[i] = py::cast(shape[i]);
+    }
+    return tup;
+  }
+
+  int8_t dtype() const {
+    return static_cast<std::underlying_type<torch::executor::ScalarType>::type>(
+        info_.scalar_type());
+  }
+
+  bool is_memory_planned() const {
+    return info_.is_memory_planned();
+  }
+
+  size_t nbytes() const {
+    return info_.nbytes();
+  }
+
+  std::string repr() const {
+    std::string size_str = "[";
+    for (const auto& d : info_.sizes()) {
+      size_str.append(std::to_string(d));
+      size_str.append(", ");
+    }
+    if (size_str.length() >= 2) {
+      // Pop the last two characters (comma and space) and add close bracket.
+      size_str.pop_back();
+      size_str.pop_back();
+    }
+    size_str.append("]");
+    return "TensorInfo(sizes=" + size_str + ", dtype=" +
+        std::string(executorch::runtime::toString(info_.scalar_type())) +
+        ", is_memory_planned=" + std::to_string(info_.is_memory_planned()) +
+        ", nbytes=" + std::to_string(info_.nbytes()) + ")";
+  }
+
+ private:
+  // TensorInfo relies on module to be alive.
+  std::shared_ptr<Module> module_;
+  torch::executor::TensorInfo info_;
+};
+
+/// Expose a subset of MethodMeta information to python.
+struct PyMethodMeta final {
+  explicit PyMethodMeta(
+      std::shared_ptr<Module> module,
+      torch::executor::MethodMeta meta)
+      : module_(std::move(module)), meta_(meta) {}
+
+  const char* name() const {
+    return meta_.name();
+  }
+
+  size_t num_inputs() const {
+    return meta_.num_inputs();
+  }
+
+  std::unique_ptr<PyTensorInfo> input_tensor_meta(size_t index) const {
+    const auto result = meta_.input_tensor_meta(index);
+    THROW_IF_ERROR(
+        result.error(),
+        "Cannot get input tensor meta at %zu: 0x%" PRIx32,
+        index,
+        static_cast<uint32_t>(result.error()));
+    return std::make_unique<PyTensorInfo>(module_, result.get());
+  }
+
+  size_t num_outputs() const {
+    return meta_.num_outputs();
+  }
+
+  std::unique_ptr<PyTensorInfo> output_tensor_meta(size_t index) const {
+    const auto result = meta_.output_tensor_meta(index);
+    THROW_IF_ERROR(
+        result.error(),
+        "Cannot get output tensor meta at %zu: 0x%" PRIx32,
+        index,
+        static_cast<uint32_t>(result.error()));
+    return std::make_unique<PyTensorInfo>(module_, result.get());
+  }
+
+  py::str repr() const {
+    py::list input_meta_strs;
+    for (size_t i = 0; i < meta_.num_inputs(); ++i) {
+      input_meta_strs.append(py::str(input_tensor_meta(i)->repr()));
+    }
+    py::list output_meta_strs;
+    for (size_t i = 0; i < meta_.num_outputs(); ++i) {
+      output_meta_strs.append(py::str(output_tensor_meta(i)->repr()));
+    }
+    py::str format =
+        "MethodMeta(name={}, num_inputs={}, input_tensor_meta={}, num_outputs={}, output_tensor_meta={})";
+    return format.format(
+        std::string(meta_.name()),
+        std::to_string(meta_.num_inputs()),
+        input_meta_strs,
+        std::to_string(meta_.num_outputs()),
+        output_meta_strs);
+  }
+
+ private:
+  // Must keep the Module object alive or else the meta object is invalidated.
+  std::shared_ptr<Module> module_;
+  torch::executor::MethodMeta meta_;
+};
+
 struct PyModule final {
   explicit PyModule(
       const py::bytes& buffer,
@@ -751,8 +869,13 @@ struct PyModule final {
     return list;
   }
 
+  std::unique_ptr<PyMethodMeta> method_meta(const std::string method_name) {
+    auto& method = module_->get_method(method_name);
+    return std::make_unique<PyMethodMeta>(module_, method.method_meta());
+  }
+
  private:
-  std::unique_ptr<Module> module_;
+  std::shared_ptr<Module> module_;
   // Need to keep-alive output storages until they can be compared in case of
   // bundled programs.
   std::vector<std::vector<uint8_t>> output_storages_;
@@ -866,6 +989,11 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
           py::arg("method_name"),
           py::arg("clone_outputs") = true,
           call_guard)
+      .def(
+          "method_meta",
+          &PyModule::method_meta,
+          py::arg("method_name"),
+          call_guard)
       .def(
           "run_method",
           &PyModule::run_method,
@@ -900,6 +1028,27 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
       call_guard);
 
   py::class_<PyBundledModule>(m, "BundledModule");
+  py::class_<PyTensorInfo>(m, "TensorInfo")
+      .def("sizes", &PyTensorInfo::sizes, call_guard)
+      .def("dtype", &PyTensorInfo::dtype, call_guard)
+      .def("is_memory_planned", &PyTensorInfo::is_memory_planned, call_guard)
+      .def("nbytes", &PyTensorInfo::nbytes, call_guard)
+      .def("__repr__", &PyTensorInfo::repr, call_guard);
+  py::class_<PyMethodMeta>(m, "MethodMeta")
+      .def("name", &PyMethodMeta::name, call_guard)
+      .def("num_inputs", &PyMethodMeta::num_inputs, call_guard)
+      .def("num_outputs", &PyMethodMeta::num_outputs, call_guard)
+      .def(
+          "input_tensor_meta",
+          &PyMethodMeta::input_tensor_meta,
+          py::arg("index"),
+          call_guard)
+      .def(
+          "output_tensor_meta",
+          &PyMethodMeta::output_tensor_meta,
+          py::arg("index"),
+          call_guard)
+      .def("__repr__", &PyMethodMeta::repr, call_guard);
 }
 
 } // namespace pybindings
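The switch from std::unique_ptr<Module> to std::shared_ptr<Module> above is what keeps the lifetimes simple: PyMethodMeta and PyTensorInfo each co-own the Module, so the metadata stays valid even after the Python-side module object is dropped. A minimal sketch of the behavior this enables (same assumed import path and hypothetical model.pte as above):

# The metadata wrappers hold a shared_ptr to the Module, so they
# remain valid after the ExecuTorchModule that created them is gone.
module = _load_for_executorch("model.pte")
meta = module.method_meta("forward")
del module                         # meta still co-owns the Module
info = meta.input_tensor_meta(0)
del meta                           # info still co-owns the Module
print(info)                        # TensorInfo(sizes=[...], dtype=..., ...)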

extension/pybindings/pybindings.pyi
Lines changed: 69 additions & 0 deletions

@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
+from __future__ import annotations
+
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 from executorch.exir._warnings import experimental
@@ -43,6 +45,7 @@ class ExecuTorchModule:
     def write_etdump_result_to_file(
         self, path: str, debug_buffer_path: Optional[str] = None
     ) -> None: ...
+    def method_meta(self, method_name: str) -> MethodMeta: ...
 
 @experimental("This API is experimental and subject to change without notice.")
 class BundledModule:
@@ -54,6 +57,72 @@ class BundledModule:
 
     ...
 
+@experimental("This API is experimental and subject to change without notice.")
+class TensorInfo:
+    """Metadata about a tensor such as the shape and dtype.
+
+    .. warning::
+
+        This API is experimental and subject to change without notice.
+    """
+
+    def sizes(self) -> Tuple[int, ...]:
+        """Shape of the tensor as a tuple"""
+        ...
+
+    def dtype(self) -> int:
+        """The data type of the elements inside the tensor.
+        See documentation for ScalarType in executorch/runtime/core/portable_type/scalar_type.h
+        for the values these integers can take."""
+        ...
+
+    def is_memory_planned(self) -> bool:
+        """True if the tensor is already memory planned, meaning no allocation
+        needs to be provided. False otherwise"""
+        ...
+
+    def nbytes(self) -> int:
+        """Number of bytes in the tensor. Not the same as numel if the dtype is
+        larger than 1 byte wide"""
+        ...
+
+    def __repr__(self) -> str: ...
+
+@experimental("This API is experimental and subject to change without notice.")
+class MethodMeta:
+    """Metadata about a method such as the number of inputs and outputs.
+
+    .. warning::
+
+        This API is experimental and subject to change without notice.
+    """
+
+    def name(self) -> str:
+        """The name of the method, such as 'forward'"""
+        ...
+
+    def num_inputs(self) -> int:
+        """The number of user inputs to the method. This does not include any
+        internal buffers or weights, which don't need to be provided by the user"""
+        ...
+
+    def num_outputs(self) -> int:
+        """The number of outputs from the method. This does not include any mutated
+        internal buffers"""
+        ...
+
+    def input_tensor_meta(self, index: int) -> TensorInfo:
+        """The tensor info for the 'index'th input. Index must be in the interval
+        [0, num_inputs())"""
+        ...
+
+    def output_tensor_meta(self, index: int) -> TensorInfo:
+        """The tensor info for the 'index'th output. Index must be in the interval
+        [0, num_outputs())"""
+        ...
+
+    def __repr__(self) -> str: ...
+
 @experimental("This API is experimental and subject to change without notice.")
 def _load_for_executorch(
     path: str, enable_etdump: bool = False, debug_buffer_size: int = 0
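Because dtype() returns the raw ScalarType integer rather than a framework dtype, callers typically decode it themselves. A hedged helper sketch, assuming the ScalarType codes in scalar_type.h mirror PyTorch's at::ScalarType numbering (the test below depends on code 6 being float32):

import torch

# Assumed partial mapping from ScalarType codes to torch dtypes;
# extend as needed for other element types.
_SCALAR_TYPE_TO_TORCH = {
    0: torch.uint8,    # Byte
    1: torch.int8,     # Char
    2: torch.int16,    # Short
    3: torch.int32,    # Int
    4: torch.int64,    # Long
    5: torch.float16,  # Half
    6: torch.float32,  # Float
    7: torch.float64,  # Double
    11: torch.bool,    # Bool
}

def decode_dtype(info) -> torch.dtype:
    """Map TensorInfo.dtype() back to a torch.dtype."""
    return _SCALAR_TYPE_TO_TORCH[info.dtype()]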

extension/pybindings/test/make_test.py
Lines changed: 30 additions & 1 deletion

@@ -295,8 +295,36 @@ def test_constant_output_not_memory_planned(tester):
         # The test module returns the state. Check that its value is correct.
         tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))
 
-    ######### RUN TEST CASES #########
+    def test_method_meta(tester) -> None:
+        exported_program, inputs = create_program(ModuleAdd())
 
+        # Use pybindings to load the program and query its metadata.
+        executorch_module = load_fn(exported_program.buffer)
+        meta = executorch_module.method_meta("forward")
+
+        # Ensure that all these APIs work even if the module object is destroyed.
+        del executorch_module
+        tester.assertEqual(meta.name(), "forward")
+        tester.assertEqual(meta.num_inputs(), 2)
+        tester.assertEqual(meta.num_outputs(), 1)
+
+        input_tensors = [meta.input_tensor_meta(i) for i in range(2)]
+        output_tensor = meta.output_tensor_meta(0)
+        # Test that tensor metadata can outlive method metadata.
+        del meta
+        tester.assertEqual([t.sizes() for t in input_tensors], [(2, 2), (2, 2)])
+        tester.assertEqual([t.dtype() for t in input_tensors], [6, 6])
+        tester.assertEqual(
+            [t.is_memory_planned() for t in input_tensors], [True, True]
+        )
+        tester.assertEqual([t.nbytes() for t in input_tensors], [16, 16])
+
+        tester.assertEqual(output_tensor.sizes(), (2, 2))
+        tester.assertEqual(output_tensor.dtype(), 6)
+        tester.assertEqual(output_tensor.is_memory_planned(), True)
+        tester.assertEqual(output_tensor.nbytes(), 16)
+
+    ######### RUN TEST CASES #########
     test_e2e(tester)
     test_multiple_entry(tester)
     test_output_lifespan(tester)
@@ -305,5 +333,6 @@ def test_constant_output_not_memory_planned(tester):
     test_stderr_redirect(tester)
     test_quantized_ops(tester)
     test_constant_output_not_memory_planned(tester)
+    test_method_meta(tester)
 
     return wrapper
