Skip to content

Commit 94a73f4

Browse files
edlopertensorflower-gardener
authored andcommitted
Replace the mechanism used to register & look up Python types from c code in tensorflow/python/util.h with one that supports non-type symbols as well.
PiperOrigin-RevId: 314996128 Change-Id: I8edca2552d5d45cf74a1f5fb00bc88996af033da
1 parent 9d572b8 commit 94a73f4

File tree

3 files changed

+54
-19
lines changed

3 files changed

+54
-19
lines changed

tensorflow/python/util/util.cc

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,25 @@ limitations under the License.
2929
namespace tensorflow {
3030
namespace swig {
3131

32-
std::unordered_map<string, PyObject*>* PythonTypesMap() {
32+
namespace {
33+
string PyObjectToString(PyObject* o);
34+
} // namespace
35+
36+
std::unordered_map<string, PyObject*>* RegisteredPyObjectMap() {
3337
static auto* m = new std::unordered_map<string, PyObject*>();
3438
return m;
3539
}
3640

37-
PyObject* GetRegisteredType(const string& key) {
38-
auto* m = PythonTypesMap();
39-
auto it = m->find(key);
40-
if (it == m->end()) return nullptr;
41+
PyObject* GetRegisteredPyObject(const string& name) {
42+
const auto* m = RegisteredPyObjectMap();
43+
auto it = m->find(name);
44+
if (it == m->end()) {
45+
PyErr_SetString(PyExc_TypeError,
46+
tensorflow::strings::StrCat("No object with name ", name,
47+
" has been registered.")
48+
.c_str());
49+
return nullptr;
50+
}
4151
return it->second;
4252
}
4353

@@ -49,26 +59,35 @@ PyObject* RegisterType(PyObject* type_name, PyObject* type) {
4959
.c_str());
5060
return nullptr;
5161
}
62+
return RegisterPyObject(type_name, type);
63+
}
5264

65+
PyObject* RegisterPyObject(PyObject* name, PyObject* value) {
5366
string key;
54-
if (PyBytes_Check(type_name)) {
55-
key = PyBytes_AsString(type_name);
56-
}
67+
if (PyBytes_Check(name)) {
68+
key = PyBytes_AsString(name);
5769
#if PY_MAJOR_VERSION >= 3
58-
if (PyUnicode_Check(type_name)) {
59-
key = PyUnicode_AsUTF8(type_name);
60-
}
70+
} else if (PyUnicode_Check(name)) {
71+
key = PyUnicode_AsUTF8(name);
6172
#endif
73+
} else {
74+
PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
75+
"Expected name to be a str, got",
76+
PyObjectToString(name))
77+
.c_str());
78+
return nullptr;
79+
}
6280

63-
if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) {
81+
auto* m = RegisteredPyObjectMap();
82+
if (m->find(key) != m->end()) {
6483
PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
65-
"Type already registered for ", key)
84+
"Value already registered for ", key)
6685
.c_str());
6786
return nullptr;
6887
}
6988

70-
Py_INCREF(type);
71-
PythonTypesMap()->emplace(key, type);
89+
Py_INCREF(value);
90+
m->emplace(key, value);
7291

7392
Py_RETURN_NONE;
7493
}
@@ -196,7 +215,7 @@ class CachedTypeCheck {
196215
// Returns 0 otherwise.
197216
// Returns -1 if an error occurred (e.g., if 'type_name' is not registered.)
198217
int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) {
199-
PyObject* type_obj = GetRegisteredType(type_name);
218+
PyObject* type_obj = GetRegisteredPyObject(type_name);
200219
if (TF_PREDICT_FALSE(type_obj == nullptr)) {
201220
PyErr_SetString(PyExc_RuntimeError,
202221
tensorflow::strings::StrCat(
@@ -513,7 +532,8 @@ class AttrsValueIterator : public ValueIterator {
513532
};
514533

515534
bool IsSparseTensorValueType(PyObject* o) {
516-
PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue");
535+
PyObject* sparse_tensor_value_type =
536+
GetRegisteredPyObject("SparseTensorValue");
517537
if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
518538
return false;
519539
}

tensorflow/python/util/util.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ limitations under the License.
1919

2020
#include <Python.h>
2121

22+
#include <string>
23+
2224
namespace tensorflow {
2325
namespace swig {
2426

@@ -270,11 +272,20 @@ PyObject* FlattenForData(PyObject* nested);
270272
PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
271273
bool check_types);
272274

273-
// RegisterType is used to pass PyTypeObject (which is defined in python) for an
274-
// arbitrary identifier `type_name` into C++.
275+
// Registers a Python object so it can be looked up from c++. The set of
276+
// valid names, and the expected values for those names, are listed in
277+
// the documentation for `RegisteredPyObjects`. Returns PyNone.
278+
PyObject* RegisterPyObject(PyObject* name, PyObject* value);
279+
280+
// Variant of RegisterPyObject that requires the object's value to be a type.
275281
PyObject* RegisterType(PyObject* type_name, PyObject* type);
276282

277283
} // namespace swig
284+
285+
// Returns a borrowed reference to an object that was registered with
286+
// RegisterPyObject. (Do not call PY_DECREF on the result).
287+
PyObject* GetRegisteredPyObject(const std::string& name);
288+
278289
} // namespace tensorflow
279290

280291
#endif // TENSORFLOW_PYTHON_UTIL_UTIL_H_

tensorflow/python/util/util_wrapper.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ PYBIND11_MODULE(_pywrap_utils, m) {
3030
return tensorflow::PyoOrThrow(
3131
tensorflow::swig::RegisterType(type_name.ptr(), type.ptr()));
3232
});
33+
m.def("RegisterPyObject", [](const py::handle& name, const py::handle& type) {
34+
return tensorflow::PyoOrThrow(
35+
tensorflow::swig::RegisterPyObject(name.ptr(), type.ptr()));
36+
});
3337
m.def(
3438
"IsTensor",
3539
[](const py::handle& o) {

0 commit comments

Comments
 (0)