Skip to content

Commit 6de8c1f

Browse files
tqchendhruvaray
authored andcommitted
[PY][FFI] Introduce PyNativeObject, enable runtime.String to subclass str (apache#5426)
To make runtime.String to work as naturally as possible in the python side, we make it sub-class the python's str object. Note that however, we cannot sub-class Object at the same time due to python's type layout constraint. We introduce a PyNativeObject class to handle this kind of object sub-classing and updated the FFI to handle PyNativeObject classes. [Relay][Frontend][TFLite] Add parser support for shape and range Signed-off-by: Dhruva Ray <dhruvaray@gmail.com>
1 parent 6c77195 commit 6de8c1f

File tree

13 files changed

+338
-112
lines changed

13 files changed

+338
-112
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ if(MSVC)
108108
endif()
109109
else(MSVC)
110110
if ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
111-
message("Build in Debug mode")
111+
message(STATUS "Build in Debug mode")
112112
set(CMAKE_C_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_C_FLAGS}")
113113
set(CMAKE_CXX_FLAGS "-O0 -g -Wall -fPIC ${CMAKE_CXX_FLAGS}")
114114
set(CMAKE_CUDA_FLAGS "-O0 -g -Xcompiler=-Wall -Xcompiler=-fPIC ${CMAKE_CUDA_FLAGS}")

python/tvm/_ffi/_ctypes/object.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def _return_object(x):
5050
tindex = ctypes.c_uint()
5151
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
5252
cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT)
53+
if issubclass(cls, PyNativeObject):
54+
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
55+
obj.handle = handle
56+
return cls.__from_tvm_object__(cls, obj)
5357
# Avoid calling __init__ of cls, instead directly call __new__
5458
# This allows child class to implement their own __init__
5559
obj = cls.__new__(cls)
@@ -64,6 +68,33 @@ def _return_object(x):
6468
_return_object, TypeCode.OBJECT_RVALUE_REF_ARG)
6569

6670

71+
class PyNativeObject:
72+
"""Base class of all TVM objects that also subclass python's builtin types."""
73+
__slots__ = []
74+
75+
def __init_tvm_object_by_constructor__(self, fconstructor, *args):
76+
"""Initialize the internal tvm_object by calling constructor function.
77+
78+
Parameters
79+
----------
80+
fconstructor : Function
81+
Constructor function.
82+
83+
args: list of objects
84+
The arguments to the constructor
85+
86+
Note
87+
----
88+
We have a special calling convention to call constructor functions.
89+
So the return object is directly set into the object
90+
"""
91+
# pylint: disable=assigning-non-slot
92+
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
93+
obj.__init_handle_by_constructor__(fconstructor, *args)
94+
self.__tvm_object__ = obj
95+
96+
97+
6798
class ObjectBase(object):
6899
"""Base object for all object types"""
69100
__slots__ = ["handle"]

python/tvm/_ffi/_ctypes/packed_func.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from .types import TVMValue, TypeCode
3030
from .types import TVMPackedCFunc, TVMCFuncFinalizer
3131
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
32-
from .object import ObjectBase, _set_class_object
32+
from .object import ObjectBase, PyNativeObject, _set_class_object
3333
from . import object as _object
3434

3535
PackedFuncHandle = ctypes.c_void_p
@@ -123,6 +123,9 @@ def _make_tvm_args(args, temp_args):
123123
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
124124
type_codes[i] = (TypeCode.NDARRAY_HANDLE
125125
if not arg.is_view else TypeCode.DLTENSOR_HANDLE)
126+
elif isinstance(arg, PyNativeObject):
127+
values[i].v_handle = arg.__tvm_object__.handle
128+
type_codes[i] = TypeCode.OBJECT_HANDLE
126129
elif isinstance(arg, _nd._TVM_COMPATS):
127130
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
128131
type_codes[i] = arg.__class__._tvm_tcode

python/tvm/_ffi/_cython/object.pxi

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,49 @@ cdef inline object make_ret_object(void* chandle):
3939
object_type = OBJECT_TYPE
4040
handle = ctypes_handle(chandle)
4141
CALL(TVMObjectGetTypeIndex(chandle, &tindex))
42+
4243
if tindex < len(OBJECT_TYPE):
4344
cls = OBJECT_TYPE[tindex]
4445
if cls is not None:
46+
if issubclass(cls, PyNativeObject):
47+
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
48+
(<ObjectBase>obj).chandle = chandle
49+
return cls.__from_tvm_object__(cls, obj)
4550
obj = cls.__new__(cls)
4651
else:
4752
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
4853
else:
4954
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
55+
5056
(<ObjectBase>obj).chandle = chandle
5157
return obj
5258

5359

60+
class PyNativeObject:
61+
"""Base class of all TVM objects that also subclass python's builtin types."""
62+
__slots__ = []
63+
64+
def __init_tvm_object_by_constructor__(self, fconstructor, *args):
65+
"""Initialize the internal tvm_object by calling constructor function.
66+
67+
Parameters
68+
----------
69+
fconstructor : Function
70+
Constructor function.
71+
72+
args: list of objects
73+
The arguments to the constructor
74+
75+
Note
76+
----
77+
We have a special calling convention to call constructor functions.
78+
So the return object is directly set into the object
79+
"""
80+
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
81+
obj.__init_handle_by_constructor__(fconstructor, *args)
82+
self.__tvm_object__ = obj
83+
84+
5485
cdef class ObjectBase:
5586
cdef void* chandle
5687

python/tvm/_ffi/_cython/packed_func.pxi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ cdef inline int make_arg(object arg,
109109
value[0].v_handle = (<NDArrayBase>arg).chandle
110110
tcode[0] = (kTVMNDArrayHandle if
111111
not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle)
112+
elif isinstance(arg, PyNativeObject):
113+
value[0].v_handle = (<ObjectBase>(arg.__tvm_object__)).chandle
114+
tcode[0] = kTVMObjectHandle
112115
elif isinstance(arg, _TVM_COMPATS):
113116
ptr = arg._tvm_handle
114117
value[0].v_handle = (<void*>ptr)

python/tvm/relay/frontend/tflite.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from .common import ExprTable
3333
from .common import infer_shape as _infer_shape
3434

35+
3536
__all__ = ['from_tflite']
3637

3738
class TensorWrapper(object):
@@ -105,6 +106,7 @@ def __init__(self, model, subgraph, exp_tab):
105106
'PAD': self.convert_pad,
106107
'POW': self.convert_pow,
107108
'PRELU': self.convert_prelu,
109+
'RANGE': self.convert_range,
108110
'REDUCE_ANY': self._convert_reduce_any,
109111
'REDUCE_MAX': self._convert_reduce_max,
110112
'REDUCE_MIN': self._convert_reduce_min,
@@ -115,6 +117,7 @@ def __init__(self, model, subgraph, exp_tab):
115117
'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
116118
'ROUND': self.convert_round,
117119
'RSQRT': self.convert_rsqrt,
120+
'SHAPE': self.convert_shape,
118121
'SIN': self.convert_sin,
119122
'SLICE': self.convert_slice,
120123
'SOFTMAX': self.convert_softmax,
@@ -552,6 +555,63 @@ def convert_tanh(self, op):
552555

553556
return out
554557

558+
def convert_range(self, op):
559+
"""Convert TFLite Range"""
560+
try:
561+
from tflite.Operator import Operator
562+
from tflite.TensorType import TensorType
563+
except ImportError:
564+
raise ImportError("The tflite package must be installed")
565+
566+
if self.is_quantized(op):
567+
raise tvm.error.OpNotImplemented(
568+
'TFlite quantized RANGE operator is not supported yet.')
569+
570+
assert isinstance(op, Operator)
571+
input_tensors = self.get_input_tensors(op)
572+
assert len(input_tensors) == 3, "input tensors length should be 3"
573+
574+
start, limit, delta = input_tensors[0], input_tensors[1], input_tensors[2]
575+
expressions = []
576+
577+
for t in [start, limit, delta]:
578+
if self.has_expr(t.tensor_idx):
579+
expressions.append(self.get_expr(t.tensor_idx))
580+
else:
581+
tensor_type = self.get_tensor_type_str(t.tensor.Type())
582+
tensor_value = self.get_tensor_value(t)
583+
expressions.append(self.exp_tab.new_const(tensor_value, dtype=tensor_type))
584+
585+
#out type inference
586+
if delta.tensor.Type() == TensorType.FLOAT32:
587+
out_type = self.get_tensor_type_str(delta.tensor.Type())
588+
else:
589+
out_type = self.get_tensor_type_str(start.tensor.Type())
590+
591+
#put type here form op
592+
out = _op.arange(expressions[0], expressions[1], expressions[2], out_type)
593+
594+
return out
595+
596+
def convert_shape(self, op):
597+
"""Convert TFLite Shape"""
598+
try:
599+
from tflite.Operator import Operator
600+
except ImportError:
601+
raise ImportError("The tflite package must be installed")
602+
603+
if self.is_quantized(op):
604+
raise tvm.error.OpNotImplemented(
605+
'TFlite quantized SHAPE operator is not supported yet.')
606+
607+
assert isinstance(op, Operator)
608+
input_tensors = self.get_input_tensors(op)
609+
assert len(input_tensors) == 1, "input tensors length should be 1"
610+
611+
out = _op.shape_of(self.get_expr(input_tensors[0].tensor_idx))
612+
613+
return out
614+
555615
def convert_relu(self, op):
556616
"""Convert TFLite ReLU"""
557617
input_tensors = self.get_input_tensors(op)

python/tvm/runtime/container.py

Lines changed: 23 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
# under the License.
1717
"""Runtime container structures."""
1818
import tvm._ffi
19-
from tvm._ffi.base import string_types
20-
from tvm.runtime import Object, ObjectTypes
21-
from tvm.runtime import _ffi_api
19+
from .object import Object, PyNativeObject
20+
from .object_generic import ObjectTypes
21+
from . import _ffi_api
22+
2223

2324
def getitem_helper(obj, elem_getter, length, idx):
2425
"""Helper function to implement a pythonic getitem function.
@@ -112,64 +113,26 @@ def tuple_object(fields=None):
112113

113114

114115
@tvm._ffi.register_object("runtime.String")
115-
class String(Object):
116-
"""The string object.
116+
class String(str, PyNativeObject):
117+
"""TVM runtime.String object, represented as a python str.
117118
118119
Parameters
119120
----------
120-
string : str
121-
The string used to construct a runtime String object
122-
123-
Returns
124-
-------
125-
ret : String
126-
The created object.
121+
content : str
122+
The content string used to construct the object.
127123
"""
128-
def __init__(self, string):
129-
self.__init_handle_by_constructor__(_ffi_api.String, string)
130-
131-
def __str__(self):
132-
return _ffi_api.GetStdString(self)
133-
134-
def __len__(self):
135-
return _ffi_api.GetStringSize(self)
136-
137-
def __hash__(self):
138-
return _ffi_api.StringHash(self)
139-
140-
def __eq__(self, other):
141-
if isinstance(other, string_types):
142-
return self.__str__() == other
143-
144-
if not isinstance(other, String):
145-
return False
146-
147-
return _ffi_api.CompareString(self, other) == 0
148-
149-
def __ne__(self, other):
150-
return not self.__eq__(other)
151-
152-
def __gt__(self, other):
153-
return _ffi_api.CompareString(self, other) > 0
154-
155-
def __lt__(self, other):
156-
return _ffi_api.CompareString(self, other) < 0
157-
158-
def __getitem__(self, key):
159-
return self.__str__()[key]
160-
161-
def startswith(self, string):
162-
"""Check if the runtime string starts with a given string
163-
164-
Parameters
165-
----------
166-
string : str
167-
The provided string
168-
169-
Returns
170-
-------
171-
ret : boolean
172-
Return true if the runtime string starts with the given string,
173-
otherwise, false.
174-
"""
175-
return self.__str__().startswith(string)
124+
__slots__ = ["__tvm_object__"]
125+
126+
def __new__(cls, content):
127+
"""Construct from string content."""
128+
val = str.__new__(cls, content)
129+
val.__init_tvm_object_by_constructor__(_ffi_api.String, content)
130+
return val
131+
132+
# pylint: disable=no-self-argument
133+
def __from_tvm_object__(cls, obj):
134+
"""Construct from a given tvm object."""
135+
content = _ffi_api.GetFFIString(obj)
136+
val = str.__new__(cls, content)
137+
val.__tvm_object__ = obj
138+
return val

python/tvm/runtime/object.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
if _FFI_MODE == "ctypes":
2828
raise ImportError()
2929
from tvm._ffi._cy3.core import _set_class_object, _set_class_object_generic
30-
from tvm._ffi._cy3.core import ObjectBase
30+
from tvm._ffi._cy3.core import ObjectBase, PyNativeObject
3131
except (RuntimeError, ImportError):
3232
# pylint: disable=wrong-import-position,unused-import
3333
from tvm._ffi._ctypes.packed_func import _set_class_object, _set_class_object_generic
34-
from tvm._ffi._ctypes.object import ObjectBase
34+
from tvm._ffi._ctypes.object import ObjectBase, PyNativeObject
3535

3636

3737
def _new_object(cls):
@@ -41,6 +41,7 @@ def _new_object(cls):
4141

4242
class Object(ObjectBase):
4343
"""Base class for all tvm's runtime objects."""
44+
__slots__ = []
4445
def __repr__(self):
4546
return _ffi_node_api.AsRepr(self)
4647

@@ -78,13 +79,10 @@ def __getstate__(self):
7879
def __setstate__(self, state):
7980
# pylint: disable=assigning-non-slot, assignment-from-no-return
8081
handle = state['handle']
82+
self.handle = None
8183
if handle is not None:
82-
json_str = handle
83-
other = _ffi_node_api.LoadJSON(json_str)
84-
self.handle = other.handle
85-
other.handle = None
86-
else:
87-
self.handle = None
84+
self.__init_handle_by_constructor__(
85+
_ffi_node_api.LoadJSON, handle)
8886

8987
def _move(self):
9088
"""Create an RValue reference to the object and mark the object as moved.

python/tvm/runtime/object_generic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from tvm._ffi.runtime_ctypes import ObjectRValueRef
2222

2323
from . import _ffi_node_api, _ffi_api
24-
from .object import ObjectBase, _set_class_object_generic
24+
from .object import ObjectBase, PyNativeObject, _set_class_object_generic
2525
from .ndarray import NDArrayBase
2626
from .packed_func import PackedFuncBase, convert_to_tvm_func
2727
from .module import Module
@@ -34,7 +34,7 @@ def asobject(self):
3434
raise NotImplementedError()
3535

3636

37-
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef)
37+
ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject)
3838

3939

4040
def convert_to_object(value):

0 commit comments

Comments
 (0)