Skip to content

Commit 9116de1

Browse files
tqchendhruvaray
authored andcommitted
[PY][FFI] Introduce PyNativeObject, enable runtime.String to subclass str (apache#5426)
To make runtime.String to work as naturally as possible in the python side, we make it sub-class the python's str object. Note that however, we cannot sub-class Object at the same time due to python's type layout constraint. We introduce a PyNativeObject class to handle this kind of object sub-classing and updated the FFI to handle PyNativeObject classes.
1 parent f0b5a9e commit 9116de1

File tree

11 files changed

+132
-89
lines changed

11 files changed

+132
-89
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ if(MSVC)
108108
endif()
109109
else(MSVC)
110110
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
111-
message("Build in Debug mode")
111+
message(STATUS "Build in Debug mode")
112112
set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}")
113113
set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}")
114114
set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")

python/tvm/_ffi/_ctypes/object.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def _return_object(x):
5050
tindex = ctypes.c_uint()
5151
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
5252
cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
53+
if issubclass(cls, PyNativeObject):
54+
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
55+
obj.handle = handle
56+
return cls.__from_tvm_object__(cls, obj)
5357
# Avoid calling __init__ of cls, instead directly call __new__
5458
# This allows child class to implement their own __init__
5559
obj = cls.__new__(cls)
@@ -64,6 +68,33 @@ def _return_object(x):
6468
_return_object, TypeCode.OBJECT_RVALUE_REF_ARG)
6569

6670

71+
class PyNativeObject:
72+
"""Base class of all TVM objects that also subclass python's builtin types."""
73+
__slots__ = []
74+
75+
def __init_tvm_object_by_constructor__(self, fconstructor, *args):
76+
"""Initialize the internal tvm_object by calling constructor function.
77+
78+
Parameters
79+
----------
80+
fconstructor : Function
81+
Constructor function.
82+
83+
args: list of objects
84+
The arguments to the constructor
85+
86+
Note
87+
----
88+
We have a special calling convention to call constructor functions.
89+
So the return object is directly set into the object
90+
"""
91+
# pylint: disable=assigning-non-slot
92+
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
93+
obj.__init_handle_by_constructor__(fconstructor, *args)
94+
self.__tvm_object__ = obj
95+
96+
97+
6798
class ObjectBase(object):
6899
"""Base object for all object types"""
69100
__slots__ = ["handle"]

python/tvm/_ffi/_ctypes/packed_func.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from .types import TVMValue, TypeCode
3030
from .types import TVMPackedCFunc, TVMCFuncFinalizer
3131
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
32-
from .object import ObjectBase, _set_class_object
32+
from .object import ObjectBase, PyNativeObject, _set_class_object
3333
from . import object as _object
3434

3535
PackedFuncHandle = ctypes.c_void_p
@@ -123,6 +123,9 @@ def _make_tvm_args(args, temp_args):
123123
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
124124
type_codes[i] = (TypeCode.NDARRAY_HANDLE
125125
if not arg.is_view else TypeCode.DLTENSOR_HANDLE)
126+
elif isinstance(arg, PyNativeObject):
127+
values[i].v_handle = arg.__tvm_object__.handle
128+
type_codes[i] = TypeCode.OBJECT_HANDLE
126129
elif isinstance(arg, _nd._TVM_COMPATS):
127130
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
128131
type_codes[i] = arg.__class__._tvm_tcode

python/tvm/_ffi/_cython/object.pxi

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,49 @@ cdef inline object make_ret_object(void* chandle):
3939
object_type = OBJECT_TYPE
4040
handle = ctypes_handle(chandle)
4141
CALL(TVMObjectGetTypeIndex(chandle, &tindex))
42+
4243
if tindex < len(OBJECT_TYPE):
4344
cls = OBJECT_TYPE[tindex]
4445
if cls is not None:
46+
if issubclass(cls, PyNativeObject):
47+
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
48+
(<ObjectBase>obj).chandle = chandle
49+
return cls.__from_tvm_object__(cls, obj)
4550
obj = cls.__new__(cls)
4651
else:
4752
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
4853
else:
4954
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
55+
5056
(<ObjectBase>obj).chandle = chandle
5157
return obj
5258

5359

60+
class PyNativeObject:
61+
"""Base class of all TVM objects that also subclass python's builtin types."""
62+
__slots__ = []
63+
64+
def __init_tvm_object_by_constructor__(self, fconstructor, *args):
65+
"""Initialize the internal tvm_object by calling constructor function.
66+
67+
Parameters
68+
----------
69+
fconstructor : Function
70+
Constructor function.
71+
72+
args: list of objects
73+
The arguments to the constructor
74+
75+
Note
76+
----
77+
We have a special calling convention to call constructor functions.
78+
So the return object is directly set into the object
79+
"""
80+
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
81+
obj.__init_handle_by_constructor__(fconstructor, *args)
82+
self.__tvm_object__ = obj
83+
84+
5485
cdef class ObjectBase:
5586
cdef void* chandle
5687

python/tvm/_ffi/_cython/packed_func.pxi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ cdef inline int make_arg(object arg,
109109
value[0].v_handle = (<NDArrayBase>arg).chandle
110110
tcode[0] = (kTVMNDArrayHandle if
111111
not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle)
112+
elif isinstance(arg, PyNativeObject):
113+
value[0].v_handle = (<ObjectBase>(arg.__tvm_object__)).chandle
114+
tcode[0] = kTVMObjectHandle
112115
elif isinstance(arg, _TVM_COMPATS):
113116
ptr = arg._tvm_handle
114117
value[0].v_handle = (<void*>ptr)

python/tvm/runtime/container.py

Lines changed: 23 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
# under the License.
1717
"""Runtime container structures."""
1818
import tvm._ffi
19-
from tvm._ffi.base import string_types
20-
from tvm.runtime import Object, ObjectTypes
21-
from tvm.runtime import _ffi_api
19+
from .object import Object, PyNativeObject
20+
from .object_generic import ObjectTypes
21+
from . import _ffi_api
22+
2223

2324
def getitem_helper(obj, elem_getter, length, idx):
2425
"""Helper function to implement a pythonic getitem function.
@@ -112,64 +113,26 @@ def tuple_object(fields=None):
112113

113114

114115
@tvm._ffi.register_object("runtime.String")
115-
class String(Object):
116-
"""The string object.
116+
class String(str, PyNativeObject):
117+
"""TVM runtime.String object, represented as a python str.
117118
118119
Parameters
119120
----------
120-
string : str
121-
The string used to construct a runtime String object
122-
123-
Returns
124-
-------
125-
ret : String
126-
The created object.
121+
content : str
122+
The content string used to construct the object.
127123
"""
128-
def __init__(self, string):
129-
self.__init_handle_by_constructor__(_ffi_api.String, string)
130-
131-
def __str__(self):
132-
return _ffi_api.GetStdString(self)
133-
134-
def __len__(self):
135-
return _ffi_api.GetStringSize(self)
136-
137-
def __hash__(self):
138-
return _ffi_api.StringHash(self)
139-
140-
def __eq__(self, other):
141-
if isinstance(other, string_types):
142-
return self.__str__() == other
143-
144-
if not isinstance(other, String):
145-
return False
146-
147-
return _ffi_api.CompareString(self, other) == 0
148-
149-
def __ne__(self, other):
150-
return not self.__eq__(other)
151-
152-
def __gt__(self, other):
153-
return _ffi_api.CompareString(self, other) > 0
154-
155-
def __lt__(self, other):
156-
return _ffi_api.CompareString(self, other) < 0
157-
158-
def __getitem__(self, key):
159-
return self.__str__()[key]
160-
161-
def startswith(self, string):
162-
"""Check if the runtime string starts with a given string
163-
164-
Parameters
165-
----------
166-
string : str
167-
The provided string
168-
169-
Returns
170-
-------
171-
ret : boolean
172-
Return true if the runtime string starts with the given string,
173-
otherwise, false.
174-
"""
175-
return self.__str__().startswith(string)
124+
__slots__ = ["__tvm_object__"]
125+
126+
def __new__(cls, content):
127+
"""Construct from string content."""
128+
val = str.__new__(cls, content)
129+
val.__init_tvm_object_by_constructor__(_ffi_api.String, content)
130+
return val
131+
132+
# pylint: disable=no-self-argument
133+
def __from_tvm_object__(cls, obj):
134+
"""Construct from a given tvm object."""
135+
content = _ffi_api.GetFFIString(obj)
136+
val = str.__new__(cls, content)
137+
val.__tvm_object__ = obj
138+
return val

python/tvm/runtime/object.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
if _FFI_MODE == "ctypes":
2828
raise ImportError()
2929
from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic
30-
from tvm._ffi._cy3.core import ObjectBase
30+
from tvm._ffi._cy3.core import ObjectBase, PyNativeObject
3131
except (RuntimeError, ImportError):
3232
# pylint: disable=wrong-import-position,unused-import
3333
from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic
34-
from tvm._ffi._ctypes.object import ObjectBase
34+
from tvm._ffi._ctypes.object import ObjectBase, PyNativeObject
3535

3636

3737
def _new_object(cls):
@@ -41,6 +41,7 @@ def _new_object(cls):
4141

4242
class Object(ObjectBase):
4343
"""Base class for all tvm's runtime objects."""
44+
__slots__ = []
4445
def __repr__(self):
4546
return _ffi_node_api.AsRepr(self)
4647

@@ -78,13 +79,10 @@ def __getstate__(self):
7879
def __setstate__(self, state):
7980
# pylint: disable=assigning-non-slot, assignment-from-no-return
8081
handle = state['handle']
82+
self.handle = None
8183
if handle is not None:
82-
json_str = handle
83-
other = _ffi_node_api.LoadJSON(json_str)
84-
self.handle = other.handle
85-
other.handle = None
86-
else:
87-
self.handle = None
84+
self.__init_handle_by_constructor__(
85+
_ffi_node_api.LoadJSON, handle)
8886

8987
def _move(self):
9088
"""Create an RValue reference to the object and mark the object as moved.

python/tvm/runtime/object_generic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tvm._ffi.runtime_ctypes import ObjectRValueRef
2222

2323
from . import _ffi_node_api, _ffi_api
24-
from .object import ObjectBase, _set_class_object_generic
24+
from .object import ObjectBase, PyNativeObject, _set_class_object_generic
2525
from .ndarray import NDArrayBase
2626
from .packed_func import PackedFuncBase, convert_to_tvm_func
2727
from .module import Module
@@ -34,7 +34,7 @@ def asobject(self):
3434
raise NotImplementedError()
3535

3636

37-
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef)
37+
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject)
3838

3939

4040
def convert_to_object(value):

src/runtime/container.cc

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
/*!
2121
* \file src/runtime/container.cc
22-
* \brief Implementations of common plain old data (POD) containers.
22+
* \brief Implementations of common containers.
2323
*/
2424
#include <tvm/runtime/container.h>
2525
#include <tvm/runtime/memory.h>
@@ -81,26 +81,11 @@ TVM_REGISTER_GLOBAL("runtime.String")
8181
return String(std::move(str));
8282
});
8383

84-
TVM_REGISTER_GLOBAL("runtime.GetStringSize")
85-
.set_body_typed([](String str) {
86-
return static_cast<int64_t>(str.size());
87-
});
88-
89-
TVM_REGISTER_GLOBAL("runtime.GetStdString")
84+
TVM_REGISTER_GLOBAL("runtime.GetFFIString")
9085
.set_body_typed([](String str) {
9186
return std::string(str);
9287
});
9388

94-
TVM_REGISTER_GLOBAL("runtime.CompareString")
95-
.set_body_typed([](String lhs, String rhs) {
96-
return lhs.compare(rhs);
97-
});
98-
99-
TVM_REGISTER_GLOBAL("runtime.StringHash")
100-
.set_body_typed([](String str) {
101-
return static_cast<int64_t>(std::hash<String>()(str));
102-
});
103-
10489
TVM_REGISTER_OBJECT_TYPE(ADTObj);
10590
TVM_REGISTER_OBJECT_TYPE(StringObj);
10691
TVM_REGISTER_OBJECT_TYPE(ClosureObj);

src/support/ffi_testing.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ TVM_REGISTER_GLOBAL("testing.nop")
5858
.set_body([](TVMArgs args, TVMRetValue *ret) {
5959
});
6060

61+
TVM_REGISTER_GLOBAL("testing.echo")
62+
.set_body([](TVMArgs args, TVMRetValue *ret) {
63+
*ret = args[0];
64+
});
65+
6166
TVM_REGISTER_GLOBAL("testing.test_wrap_callback")
6267
.set_body([](TVMArgs args, TVMRetValue *ret) {
6368
PackedFunc pf = args[0];

0 commit comments

Comments
 (0)