@@ -102,10 +102,71 @@ TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name)
102
102
reg.reset_attr (attr_name);
103
103
});
104
104
105
- TVM_REGISTER_GLOBAL (" ir.RegisterOp" ).set_body_typed([](String op_name) {
105
+ TVM_REGISTER_GLOBAL (" ir.RegisterOp" ).set_body_typed([](String op_name, String descr ) {
106
106
const OpRegEntry* reg = OpRegistry::Global ()->Get (op_name);
107
107
ICHECK (reg == nullptr ) << " AttributeError: Operator " << op_name << " is registered before" ;
108
- OpRegistry::Global ()->RegisterOrGet (op_name).set_name ();
108
+ auto & op = OpRegistry::Global ()->RegisterOrGet (op_name).set_name ();
109
+ op.describe (descr);
110
+ });
111
+
112
+ // This is exposed FFI api for prototyping using in python.
113
+ // Note: it is not full of the C++ type relation,
114
+ // since in python side we don't have access to the type reporter,
115
+ // and cannot propagate constraints to the inputs, only to the output.
116
+ TVM_REGISTER_GLOBAL (" ir.OpAddTypeRel" )
117
+ .set_body_typed([](Op op, String rel_name, runtime::TVMArgValue value) {
118
+ auto & reg = OpRegistry::Global ()->RegisterOrGet (op->name ).set_name ();
119
+ if (value.type_code () == kTVMPackedFuncHandle ) {
120
+ // do an eager copy of the PackedFunc to avoid deleting function from frontend.
121
+ PackedFunc* fcopy = new PackedFunc (value.operator tvm::runtime::PackedFunc ());
122
+ auto f = [=](const Array<Type>& args, int num_inputs, const Attrs& attrs,
123
+ const TypeReporter& reporter) -> bool {
124
+ Array<Type> input_types (args.begin (), args.end () - 1 );
125
+ // call customized relation functions
126
+ // *fcopy's signature: function (args: List[Type], attrs: Attrs) -> Type
127
+ Type ret_type = (*fcopy)(input_types, attrs);
128
+ // when defined ret_type, inference of output type is ok, do type assign
129
+ // otherwise, inference failure happens
130
+ if (ret_type.defined ()) {
131
+ // the last argument is output
132
+ // TODO(xqdan): support multiple outputs
133
+ reporter->Assign (args.back (), ret_type);
134
+ return true ;
135
+ }
136
+ return false ;
137
+ };
138
+ // adjust function call to call conventions of relay type system with TypeReporter
139
+ auto type_rel = runtime::TypedPackedFunc<bool (const Array<Type>&, int , const Attrs&,
140
+ const TypeReporter&)>(f);
141
+ reg.add_type_rel (rel_name, type_rel);
142
+ } else if (value.type_code () == kTVMNullptr ) {
143
+ // Call relation functions of relay
144
+ auto func_name = std::string (" tvm.relay.type_relation." ) + rel_name;
145
+ auto * f = runtime::Registry::Get (func_name);
146
+ ICHECK (f != nullptr ) << " AddTypeRel error: no type_relation registered." ;
147
+ reg.add_type_rel (rel_name, *f);
148
+ }
149
+ });
150
+
151
+ TVM_REGISTER_GLOBAL (" ir.OpAddArgument" )
152
+ .set_body_typed([](Op op, String name, String type, String description) {
153
+ auto & reg = OpRegistry::Global ()->RegisterOrGet (op->name ).set_name ();
154
+ reg.add_argument (name, type, description);
155
+ });
156
+
157
+ TVM_REGISTER_GLOBAL (" ir.OpSetSupportLevel" ).set_body_typed([](Op op, int level) {
158
+ auto & reg = OpRegistry::Global ()->RegisterOrGet (op->name ).set_name ();
159
+ reg.set_support_level (level);
160
+ });
161
+
162
+ TVM_REGISTER_GLOBAL (" ir.OpSetNumInputs" ).set_body_typed([](Op op, int n) {
163
+ auto & reg = OpRegistry::Global ()->RegisterOrGet (op->name ).set_name ();
164
+ reg.set_num_inputs (n);
165
+ });
166
+
167
+ TVM_REGISTER_GLOBAL (" ir.OpSetAttrsTypeKey" ).set_body_typed([](Op op, String key) {
168
+ auto & reg = OpRegistry::Global ()->RegisterOrGet (op->name ).set_name ();
169
+ reg.set_attrs_type_key (key);
109
170
});
110
171
111
172
TVM_REGISTER_GLOBAL (" ir.RegisterOpAttr" )
0 commit comments