Skip to content

Commit 0429c63

Browse files
xqdanxiaoqiang.dan
andauthored
Complete register op from python (#8079)
* Complete register op from python * fix lint * fix lint * fix lint * fix comments * fix * fix * fix comments * fix lint * fix lint * add comments * fix build * fix * add exception case * fix * fix comments * fix * fix * fix * fix * fix * fix * fix Co-authored-by: xiaoqiang.dan <xiaoqiang.dan@streamcoputing.com>
1 parent c7f1b45 commit 0429c63

File tree

7 files changed

+318
-17
lines changed

7 files changed

+318
-17
lines changed

include/tvm/ir/op.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,18 @@ class OpRegEntry {
244244
runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>
245245
type_rel_func);
246246
/*!
247-
* \brief Set the the attrs type key and index to be AttrsType.
247+
* \brief Set the attrs type key and index to be AttrsType.
248248
* \tparam AttrsType the attribute type to be set.
249249
* \return reference to self.
250250
*/
251251
template <typename AttrsType>
252252
inline OpRegEntry& set_attrs_type();
253+
/*!
254+
* \brief Set the attrs type key and index to be AttrsType.
255+
* \param key The attribute type key to be set.
256+
* \return reference to self.
257+
*/
258+
inline OpRegEntry& set_attrs_type_key(const String& key);
253259
/*!
254260
* \brief Set the num_inputs
255261
* \param n The number of inputs to be set.
@@ -454,6 +460,12 @@ inline OpRegEntry& OpRegEntry::set_attrs_type() { // NOLINT(*)
454460
return *this;
455461
}
456462

463+
inline OpRegEntry& OpRegEntry::set_attrs_type_key(const String& key) { // NOLINT(*)
464+
get()->attrs_type_key = key;
465+
get()->attrs_type_index = Object::TypeKey2Index(key);
466+
return *this;
467+
}
468+
457469
inline OpRegEntry& OpRegEntry::set_support_level(int32_t n) { // NOLINT(*)
458470
get()->support_level = n;
459471
return *this;

python/tvm/ir/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from .tensor_type import TensorType
2424
from .type_relation import TypeCall, TypeRelation
2525
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
26-
from .op import Op, register_op, register_op_attr, register_intrin_lowering
26+
from .op import Op, register_op_attr, register_intrin_lowering
2727
from .function import CallingConv, BaseFunc
2828
from .adt import Constructor, TypeData
2929
from .module import IRModule

python/tvm/ir/op.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,71 @@ def reset_attr(self, attr_name):
8585
"""
8686
_ffi_api.OpResetAttr(self, attr_name)
8787

88+
def add_type_rel(self, rel_name, type_rel_func=None):
89+
"""Attach the type function corresponding to the return type.
8890
89-
def register_op(op_name):
90-
"""Register an operator by name
91+
Parameters
92+
----------
93+
rel_name : str
94+
The type relation name to register.
95+
96+
type_rel_func : Optional[function (args: List[Type], attrs: Attrs) -> Type]
97+
The backing relation function which can solve an arbitrary relation on variables.
98+
Differences with type_rel_func in C++:
99+
1, when type_rel_func is not None:
100+
1) OpAddTypeRel on C++ side will adjust type_rel_func with TypeReporter to
101+
calling convention of relay type system.
102+
2) type_rel_func returns output argument's type, return None means can't
103+
infer output's type.
104+
3) only support single output operators for now, the last argument is output tensor.
105+
2, when type_rel_func is None, will call predefined type_rel_funcs in relay
106+
according to `tvm.relay.type_relation.` + rel_name.
107+
"""
108+
_ffi_api.OpAddTypeRel(self, rel_name, type_rel_func)
91109

92-
Parameters
93-
----------
94-
op_name : str
95-
The name of new operator
96-
"""
110+
def add_argument(self, name, type, description): # pylint: disable=redefined-builtin
111+
"""Add arguments information to the function.
97112
98-
_ffi_api.RegisterOp(op_name)
113+
Parameters
114+
----------
115+
name : str
116+
The argument name.
117+
type : str
118+
The argument type.
119+
description : str
120+
The argument description.
121+
"""
122+
_ffi_api.OpAddArgument(self, name, type, description)
123+
124+
def set_support_level(self, level):
125+
"""Set the support level of op.
126+
127+
Parameters
128+
----------
129+
level : int
130+
The support level.
131+
"""
132+
_ffi_api.OpSetSupportLevel(self, level)
133+
134+
def set_num_inputs(self, n):
135+
"""Set the support level of op.
136+
137+
Parameters
138+
----------
139+
n : int
140+
The input number.
141+
"""
142+
_ffi_api.OpSetNumInputs(self, n)
143+
144+
def set_attrs_type_key(self, key):
145+
"""Set the attribute type key of op.
146+
147+
Parameters
148+
----------
149+
key : str
150+
The type key.
151+
"""
152+
_ffi_api.OpSetAttrsTypeKey(self, key)
99153

100154

101155
def register_op_attr(op_name, attr_key, value=None, level=10):

python/tvm/relay/op/op.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from tvm.driver import lower, build
2222
from tvm.target import get_native_generic_func, GenericFunc
2323
from tvm.runtime import Object
24+
import tvm.ir._ffi_api
2425
from . import _make
2526

2627

@@ -40,6 +41,40 @@ def get(op_name):
4041
return tvm.ir.Op.get(op_name)
4142

4243

44+
def register(op_name, describe=""):
45+
"""Get the Op for a given name.
46+
when the op_name is not registered, create a new empty op with the given name.
47+
when the op_name has been registered, abort with an error message.
48+
49+
Parameters
50+
----------
51+
op_name : str
52+
The operator name
53+
54+
describe : Optional[str]
55+
The operator description
56+
"""
57+
58+
tvm.ir._ffi_api.RegisterOp(op_name, describe)
59+
60+
61+
def register_stateful(op_name, stateful, level=10):
62+
"""Register operator pattern for an op.
63+
64+
Parameters
65+
----------
66+
op_name : str
67+
The name of the op.
68+
69+
stateful : bool
70+
The stateful flag.
71+
72+
level : int
73+
The priority level
74+
"""
75+
tvm.ir.register_op_attr(op_name, "TOpIsStateful", stateful, level)
76+
77+
4378
class OpPattern(object):
4479
"""Operator generic patterns
4580

src/ir/op.cc

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,71 @@ TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name)
102102
reg.reset_attr(attr_name);
103103
});
104104

105-
TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name) {
105+
TVM_REGISTER_GLOBAL("ir.RegisterOp").set_body_typed([](String op_name, String descr) {
106106
const OpRegEntry* reg = OpRegistry::Global()->Get(op_name);
107107
ICHECK(reg == nullptr) << "AttributeError: Operator " << op_name << " is registered before";
108-
OpRegistry::Global()->RegisterOrGet(op_name).set_name();
108+
auto& op = OpRegistry::Global()->RegisterOrGet(op_name).set_name();
109+
op.describe(descr);
110+
});
111+
112+
// This is an FFI API exposed for prototyping from Python.
113+
// Note: it does not cover the full C++ type relation machinery,
114+
// since on the Python side we don't have access to the type reporter,
115+
// and cannot propagate constraints to the inputs, only to the output.
116+
TVM_REGISTER_GLOBAL("ir.OpAddTypeRel")
117+
.set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) {
118+
auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
119+
if (value.type_code() == kTVMPackedFuncHandle) {
120+
// do an eager copy of the PackedFunc to avoid deleting function from frontend.
121+
PackedFunc* fcopy = new PackedFunc(value.operator tvm::runtime::PackedFunc());
122+
auto f = [=](const Array<Type>& args, int num_inputs, const Attrs& attrs,
123+
const TypeReporter& reporter) -> bool {
124+
Array<Type> input_types(args.begin(), args.end() - 1);
125+
// call customized relation functions
126+
// *fcopy's signature: function (args: List[Type], attrs: Attrs) -> Type
127+
Type ret_type = (*fcopy)(input_types, attrs);
128+
// when defined ret_type, inference of output type is ok, do type assign
129+
// otherwise, inference failure happens
130+
if (ret_type.defined()) {
131+
// the last argument is output
132+
// TODO(xqdan): support multiple outputs
133+
reporter->Assign(args.back(), ret_type);
134+
return true;
135+
}
136+
return false;
137+
};
138+
// adjust function call to call conventions of relay type system with TypeReporter
139+
auto type_rel = runtime::TypedPackedFunc<bool(const Array<Type>&, int, const Attrs&,
140+
const TypeReporter&)>(f);
141+
reg.add_type_rel(rel_name, type_rel);
142+
} else if (value.type_code() == kTVMNullptr) {
143+
// Call relation functions of relay
144+
auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
145+
auto* f = runtime::Registry::Get(func_name);
146+
ICHECK(f != nullptr) << "AddTypeRel error: no type_relation registered.";
147+
reg.add_type_rel(rel_name, *f);
148+
}
149+
});
150+
151+
TVM_REGISTER_GLOBAL("ir.OpAddArgument")
152+
.set_body_typed([](Op op, String name, String type, String description) {
153+
auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
154+
reg.add_argument(name, type, description);
155+
});
156+
157+
TVM_REGISTER_GLOBAL("ir.OpSetSupportLevel").set_body_typed([](Op op, int level) {
158+
auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
159+
reg.set_support_level(level);
160+
});
161+
162+
TVM_REGISTER_GLOBAL("ir.OpSetNumInputs").set_body_typed([](Op op, int n) {
163+
auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
164+
reg.set_num_inputs(n);
165+
});
166+
167+
TVM_REGISTER_GLOBAL("ir.OpSetAttrsTypeKey").set_body_typed([](Op op, String key) {
168+
auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name();
169+
reg.set_attrs_type_key(key);
109170
});
110171

111172
TVM_REGISTER_GLOBAL("ir.RegisterOpAttr")

tests/python/relay/test_ir_op.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import tvm
1818
from tvm import relay
1919
from tvm.relay.testing.temp_op_attr import TempOpAttr
20+
from tvm.relay.op import op as _op
2021

2122

2223
def test_op_attr():
@@ -103,11 +104,20 @@ def test_op_register():
103104
"""Tests register_op functionality."""
104105
op_name = "custom_op"
105106

106-
tvm.ir.register_op(op_name)
107-
tvm.ir.register_op_attr(op_name, "num_inputs", 2, 256)
108-
109-
assert tvm.ir.Op.get(op_name).name == op_name
110-
assert tvm.ir.Op.get(op_name).num_inputs == 2
107+
_op.register(op_name, r"code(Add two tensor with inner broadcasting.)code")
108+
_op.get(op_name).set_num_inputs(2)
109+
_op.get(op_name).add_argument("data_0", "Tensor", "The input data tensor.")
110+
_op.get(op_name).add_argument("data_1", "Tensor", "The input data tensor.")
111+
# call default relation functions
112+
_op.get(op_name).add_type_rel("Identity")
113+
_op.get(op_name).set_support_level(1)
114+
_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)
115+
_op.register_stateful(op_name, False)
116+
117+
assert _op.get(op_name).name == op_name
118+
assert _op.get(op_name).num_inputs == 2
119+
assert _op.get(op_name).get_attr("TOpPattern") == _op.OpPattern.ELEMWISE
120+
assert _op.get(op_name).get_attr("TOpIsStateful") == False
111121

112122

113123
if __name__ == "__main__":

0 commit comments

Comments
 (0)