@@ -29,15 +29,25 @@ limitations under the License.
29
29
namespace tensorflow {
30
30
namespace swig {
31
31
32
- std::unordered_map<string, PyObject*>* PythonTypesMap () {
32
+ namespace {
33
+ string PyObjectToString (PyObject* o);
34
+ } // namespace
35
+
36
+ std::unordered_map<string, PyObject*>* RegisteredPyObjectMap () {
33
37
static auto * m = new std::unordered_map<string, PyObject*>();
34
38
return m;
35
39
}
36
40
37
- PyObject* GetRegisteredType (const string& key) {
38
- auto * m = PythonTypesMap ();
39
- auto it = m->find (key);
40
- if (it == m->end ()) return nullptr ;
41
+ PyObject* GetRegisteredPyObject (const string& name) {
42
+ const auto * m = RegisteredPyObjectMap ();
43
+ auto it = m->find (name);
44
+ if (it == m->end ()) {
45
+ PyErr_SetString (PyExc_TypeError,
46
+ tensorflow::strings::StrCat (" No object with name " , name,
47
+ " has been registered." )
48
+ .c_str ());
49
+ return nullptr ;
50
+ }
41
51
return it->second ;
42
52
}
43
53
@@ -49,26 +59,35 @@ PyObject* RegisterType(PyObject* type_name, PyObject* type) {
49
59
.c_str ());
50
60
return nullptr ;
51
61
}
62
+ return RegisterPyObject (type_name, type);
63
+ }
52
64
65
+ PyObject* RegisterPyObject (PyObject* name, PyObject* value) {
53
66
string key;
54
- if (PyBytes_Check (type_name)) {
55
- key = PyBytes_AsString (type_name);
56
- }
67
+ if (PyBytes_Check (name)) {
68
+ key = PyBytes_AsString (name);
57
69
#if PY_MAJOR_VERSION >= 3
58
- if (PyUnicode_Check (type_name)) {
59
- key = PyUnicode_AsUTF8 (type_name);
60
- }
70
+ } else if (PyUnicode_Check (name)) {
71
+ key = PyUnicode_AsUTF8 (name);
61
72
#endif
73
+ } else {
74
+ PyErr_SetString (PyExc_TypeError, tensorflow::strings::StrCat (
75
+ " Expected name to be a str, got" ,
76
+ PyObjectToString (name))
77
+ .c_str ());
78
+ return nullptr ;
79
+ }
62
80
63
- if (PythonTypesMap ()->find (key) != PythonTypesMap ()->end ()) {
81
+ auto * m = RegisteredPyObjectMap ();
82
+ if (m->find (key) != m->end ()) {
64
83
PyErr_SetString (PyExc_TypeError, tensorflow::strings::StrCat (
65
- " Type already registered for " , key)
84
+ " Value already registered for " , key)
66
85
.c_str ());
67
86
return nullptr ;
68
87
}
69
88
70
- Py_INCREF (type );
71
- PythonTypesMap () ->emplace (key, type );
89
+ Py_INCREF (value );
90
+ m ->emplace (key, value );
72
91
73
92
Py_RETURN_NONE;
74
93
}
@@ -196,7 +215,7 @@ class CachedTypeCheck {
196
215
// Returns 0 otherwise.
197
216
// Returns -1 if an error occurred (e.g., if 'type_name' is not registered.)
198
217
int IsInstanceOfRegisteredType (PyObject* obj, const char * type_name) {
199
- PyObject* type_obj = GetRegisteredType (type_name);
218
+ PyObject* type_obj = GetRegisteredPyObject (type_name);
200
219
if (TF_PREDICT_FALSE (type_obj == nullptr )) {
201
220
PyErr_SetString (PyExc_RuntimeError,
202
221
tensorflow::strings::StrCat (
@@ -513,7 +532,8 @@ class AttrsValueIterator : public ValueIterator {
513
532
};
514
533
515
534
bool IsSparseTensorValueType (PyObject* o) {
516
- PyObject* sparse_tensor_value_type = GetRegisteredType (" SparseTensorValue" );
535
+ PyObject* sparse_tensor_value_type =
536
+ GetRegisteredPyObject (" SparseTensorValue" );
517
537
if (TF_PREDICT_FALSE (sparse_tensor_value_type == nullptr )) {
518
538
return false ;
519
539
}
0 commit comments