Skip to content

Commit 02c1e11

Browse files
authored
[RUNTIME] Refactor object python FFI to new protocol. (#4128)
* [RUNTIME] Refactor object python FFI to new protocol. This is a pre-req to bring the Node system under object protocol. Most of the code reflects the current code in the Node system. - Use new instead of init so subclass can define their own constructors - Allow register via name, besides type idnex - Introduce necessary runtime C API functions - Refactored Tensor and Datatype to directly use constructor. * address review comments
1 parent e3fbdc8 commit 02c1e11

File tree

23 files changed

+440
-252
lines changed

23 files changed

+440
-252
lines changed

include/tvm/runtime/c_runtime_api.h

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ typedef enum {
104104
kStr = 11U,
105105
kBytes = 12U,
106106
kNDArrayContainer = 13U,
107-
kObjectCell = 14U,
107+
kObjectHandle = 14U,
108108
// Extension codes for other frameworks to integrate TVM PackedFunc.
109109
// To make sure each framework's id do not conflict, use first and
110110
// last sections to mark ranges.
@@ -549,13 +549,31 @@ TVM_DLL int TVMStreamStreamSynchronize(int device_type,
549549
TVMStreamHandle dst);
550550

551551
/*!
552-
* \brief Get the tag from an object.
552+
* \brief Get the type_index from an object.
553553
*
554554
* \param obj The object handle.
555-
* \param tag The tag of object.
555+
* \param out_tindex the output type index.
556556
* \return 0 when success, -1 when failure happens
557557
*/
558-
TVM_DLL int TVMGetObjectTag(TVMObjectHandle obj, int* tag);
558+
TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);
559+
560+
/*!
561+
* \brief Convert type key to type index.
562+
* \param type_key The key of the type.
563+
* \param out_tindex the corresponding type index.
564+
* \return 0 when success, -1 when failure happens
565+
*/
566+
TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
567+
568+
/*!
569+
* \brief Free the object.
570+
*
571+
* \param obj The object handle.
572+
* \note Internally we decrease the reference counter of the object.
573+
* The object will be freed when every reference to the object are removed.
574+
* \return 0 when success, -1 when failure happens
575+
*/
576+
TVM_DLL int TVMObjectFree(TVMObjectHandle obj);
559577

560578
#ifdef __cplusplus
561579
} // TVM_EXTERN_C

include/tvm/runtime/object.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ class Object {
253253
template<typename>
254254
friend class ObjectPtr;
255255
friend class TVMRetValue;
256+
friend class TVMObjectCAPI;
256257
};
257258

258259
/*!

include/tvm/runtime/packed_func.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ class TVMPODValue_ {
491491
}
492492
operator ObjectRef() const {
493493
if (type_code_ == kNull) return ObjectRef(ObjectPtr<Object>(nullptr));
494-
TVM_CHECK_TYPE_CODE(type_code_, kObjectCell);
494+
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
495495
return ObjectRef(ObjectPtr<Object>(static_cast<Object*>(value_.v_handle)));
496496
}
497497
operator TVMContext() const {
@@ -761,7 +761,7 @@ class TVMRetValue : public TVMPODValue_ {
761761
}
762762
TVMRetValue& operator=(ObjectRef other) {
763763
this->Clear();
764-
type_code_ = kObjectCell;
764+
type_code_ = kObjectHandle;
765765
// move the handle out
766766
value_.v_handle = other.data_.data_;
767767
other.data_.data_ = nullptr;
@@ -862,7 +862,7 @@ class TVMRetValue : public TVMPODValue_ {
862862
kNodeHandle, *other.template ptr<NodePtr<Node> >());
863863
break;
864864
}
865-
case kObjectCell: {
865+
case kObjectHandle: {
866866
*this = other.operator ObjectRef();
867867
break;
868868
}
@@ -913,7 +913,7 @@ class TVMRetValue : public TVMPODValue_ {
913913
static_cast<NDArray::Container*>(value_.v_handle)->DecRef();
914914
break;
915915
}
916-
case kObjectCell: {
916+
case kObjectHandle: {
917917
static_cast<Object*>(value_.v_handle)->DecRef();
918918
break;
919919
}
@@ -946,7 +946,7 @@ inline const char* TypeCode2Str(int type_code) {
946946
case kFuncHandle: return "FunctionHandle";
947947
case kModuleHandle: return "ModuleHandle";
948948
case kNDArrayContainer: return "NDArrayContainer";
949-
case kObjectCell: return "ObjectCell";
949+
case kObjectHandle: return "ObjectCell";
950950
default: LOG(FATAL) << "unknown type_code="
951951
<< static_cast<int>(type_code); return "";
952952
}
@@ -1164,7 +1164,7 @@ class TVMArgsSetter {
11641164
}
11651165
void operator()(size_t i, const ObjectRef& value) const { // NOLINT(*)
11661166
values_[i].v_handle = value.data_.data_;
1167-
type_codes_[i] = kObjectCell;
1167+
type_codes_[i] = kObjectHandle;
11681168
}
11691169
void operator()(size_t i, const TVMRetValue& value) const { // NOLINT(*)
11701170
if (value.type_code() == kStr) {

python/tvm/_ffi/_ctypes/function.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from .types import TVMPackedCFunc, TVMCFuncFinalizer
3434
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func, _ctx_to_int64
3535
from .node import NodeBase
36+
from . import object as _object
3637
from . import node as _node
3738

3839
FunctionHandle = ctypes.c_void_p
@@ -165,7 +166,7 @@ def _make_tvm_args(args, temp_args):
165166
temp_args.append(arg)
166167
elif isinstance(arg, _CLASS_OBJECT):
167168
values[i].v_handle = arg.handle
168-
type_codes[i] = TypeCode.OBJECT_CELL
169+
type_codes[i] = TypeCode.OBJECT_HANDLE
169170
else:
170171
raise TypeError("Don't know how to handle type %s" % type(arg))
171172
return values, type_codes, num_args
@@ -225,7 +226,7 @@ def __init_handle_by_constructor__(fconstructor, args):
225226
raise get_last_ffi_error()
226227
_ = temp_args
227228
_ = args
228-
assert ret_tcode.value == TypeCode.NODE_HANDLE
229+
assert ret_tcode.value in (TypeCode.NODE_HANDLE, TypeCode.OBJECT_HANDLE)
229230
handle = ret_val.v_handle
230231
return handle
231232

@@ -247,6 +248,7 @@ def _handle_return_func(x):
247248

248249
# setup return handle for function type
249250
_node.__init_by_constructor__ = __init_handle_by_constructor__
251+
_object.__init_by_constructor__ = __init_handle_by_constructor__
250252
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
251253
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
252254
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True)

python/tvm/_ffi/_ctypes/object.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name
18+
"""Runtime Object api"""
19+
from __future__ import absolute_import
20+
21+
import ctypes
22+
from ..base import _LIB, check_call
23+
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
24+
25+
26+
ObjectHandle = ctypes.c_void_p
27+
__init_by_constructor__ = None
28+
29+
"""Maps object type to its constructor"""
30+
OBJECT_TYPE = {}
31+
32+
def _register_object(index, cls):
33+
"""register object class"""
34+
OBJECT_TYPE[index] = cls
35+
36+
37+
def _return_object(x):
38+
handle = x.v_handle
39+
if not isinstance(handle, ObjectHandle):
40+
handle = ObjectHandle(handle)
41+
tindex = ctypes.c_uint()
42+
check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex)))
43+
cls = OBJECT_TYPE.get(tindex.value, ObjectBase)
44+
# Avoid calling __init__ of cls, instead directly call __new__
45+
# This allows child class to implement their own __init__
46+
obj = cls.__new__(cls)
47+
obj.handle = handle
48+
return obj
49+
50+
RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
51+
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
52+
_return_object, TypeCode.OBJECT_HANDLE)
53+
54+
55+
class ObjectBase(object):
56+
"""Base object for all object types"""
57+
__slots__ = ["handle"]
58+
59+
def __del__(self):
60+
if _LIB is not None:
61+
check_call(_LIB.TVMObjectFree(self.handle))
62+
63+
def __init_handle_by_constructor__(self, fconstructor, *args):
64+
"""Initialize the handle by calling constructor function.
65+
66+
Parameters
67+
----------
68+
fconstructor : Function
69+
Constructor function.
70+
71+
args: list of objects
72+
The arguments to the constructor
73+
74+
Note
75+
----
76+
We have a special calling convention to call constructor functions.
77+
So the return handle is directly set into the Node object
78+
instead of creating a new Node.
79+
"""
80+
# assign handle first to avoid error raising
81+
self.handle = None
82+
handle = __init_by_constructor__(fconstructor, args)
83+
if not isinstance(handle, ObjectHandle):
84+
handle = ObjectHandle(handle)
85+
self.handle = handle

python/tvm/_ffi/_ctypes/vmobj.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

python/tvm/_ffi/_cython/base.pxi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ cdef enum TVMTypeCode:
3737
kStr = 11
3838
kBytes = 12
3939
kNDArrayContainer = 13
40-
kObjectCell = 14
40+
kObjectHandle = 14
4141
kExtBegin = 15
4242

4343
cdef extern from "tvm/runtime/c_runtime_api.h":
@@ -130,7 +130,9 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
130130
int TVMArrayToDLPack(DLTensorHandle arr_from,
131131
DLManagedTensor** out)
132132
void TVMDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
133-
int TVMGetObjectTag(ObjectHandle obj, int* tag)
133+
int TVMObjectFree(ObjectHandle obj)
134+
int TVMObjectGetTypeIndex(ObjectHandle obj, unsigned* out_index)
135+
134136

135137
cdef extern from "tvm/c_dsl_api.h":
136138
int TVMNodeFree(NodeHandle handle)

python/tvm/_ffi/_cython/core.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
# under the License.
1717

1818
include "./base.pxi"
19+
include "./object.pxi"
1920
include "./node.pxi"
2021
include "./function.pxi"
2122
include "./ndarray.pxi"
22-
include "./vmobj.pxi"
23+

python/tvm/_ffi/_cython/function.pxi

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ cdef int tvm_callback(TVMValue* args,
4444
if (tcode == kNodeHandle or
4545
tcode == kFuncHandle or
4646
tcode == kModuleHandle or
47-
tcode == kObjectCell or
47+
tcode == kObjectHandle or
4848
tcode > kExtBegin):
4949
CALL(TVMCbArgToReturn(&value, tcode))
5050

@@ -155,12 +155,12 @@ cdef inline int make_arg(object arg,
155155
value[0].v_handle = (<NodeBase>arg).chandle
156156
tcode[0] = kNodeHandle
157157
temp_args.append(arg)
158+
elif isinstance(arg, _CLASS_OBJECT):
159+
value[0].v_handle = (<ObjectBase>arg).chandle
160+
tcode[0] = kObjectHandle
158161
elif isinstance(arg, _CLASS_MODULE):
159162
value[0].v_handle = c_handle(arg.handle)
160163
tcode[0] = kModuleHandle
161-
elif isinstance(arg, _CLASS_OBJECT):
162-
value[0].v_handle = c_handle(arg.handle)
163-
tcode[0] = kObjectCell
164164
elif isinstance(arg, FunctionBase):
165165
value[0].v_handle = (<FunctionBase>arg).chandle
166166
tcode[0] = kFuncHandle
@@ -190,6 +190,8 @@ cdef inline object make_ret(TVMValue value, int tcode):
190190
"""convert result to return value."""
191191
if tcode == kNodeHandle:
192192
return make_ret_node(value.v_handle)
193+
elif tcode == kObjectHandle:
194+
return make_ret_object(value.v_handle)
193195
elif tcode == kNull:
194196
return None
195197
elif tcode == kInt:
@@ -212,8 +214,6 @@ cdef inline object make_ret(TVMValue value, int tcode):
212214
fobj = _CLASS_FUNCTION(None, False)
213215
(<FunctionBase>fobj).chandle = value.v_handle
214216
return fobj
215-
elif tcode == kObjectCell:
216-
return make_ret_object(value.v_handle)
217217
elif tcode in _TVM_EXT_RET:
218218
return _TVM_EXT_RET[tcode](ctypes_handle(value.v_handle))
219219

0 commit comments

Comments
 (0)