#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
#include <tvm/target/target.h>
+ #include <tvm/target/generic_func.h>
#include <tvm/tir/data_layout.h>
#include <string>
@@ -105,9 +106,8 @@ using TShapeDataDependant = bool;
 */
using FTVMCompute = runtime::TypedPackedFunc<
  Array<te::Tensor>(const Attrs& attrs,
-                   const Array<te::Tensor>& inputs,
-                   const Type& out_type,
-                   const Target& target)>;
+                   const Array<te::Tensor>& inputs,
+                   const Type& out_type)>;
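For illustration, a compute function matching the new target-free signature could look like the sketch below; the topi::add body and its include are assumptions for the example, not part of this change.

// Sketch only: an FTVMCompute under the new signature. Assumes
// #include <topi/broadcast.h> for topi::add.
FTVMCompute fcompute = [](const Attrs& attrs,
                          const Array<te::Tensor>& inputs,
                          const Type& out_type) -> Array<te::Tensor> {
  // No Target parameter anymore: compute is target-independent, and
  // target-specific choices move into the FTVMStrategy added below.
  return {topi::add(inputs[0], inputs[1])};
};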
/*!
 * \brief Build the computation schedule for
@@ -123,6 +123,16 @@ using FTVMSchedule = runtime::TypedPackedFunc<
    const Array<te::Tensor>& outs,
    const Target& target)>;
+ /*!
+  * \brief Generate the strategy of operators. This function is a generic
+  * function and can be re-defined for different targets.
+  *
+  * The function signature of the generic function is:
+  *   OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,
+  *              const Type& out_type, const Target& target)
+  */
+ using FTVMStrategy = GenericFunc;
+
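As a sketch of how this generic function could be wired up (the name strategy_example, the lambda body, and the default-only registration are assumptions; this diff only declares the alias):

// Sketch only: declare a strategy GenericFunc and give it a default,
// using the API from the newly included tvm/target/generic_func.h.
using StrategyFn = runtime::TypedPackedFunc<OpStrategy(
    const Attrs&, const Array<te::Tensor>&, const Type&, const Target&)>;

TVM_REGISTER_GENERIC_FUNC(strategy_example)
    .set_default(StrategyFn(
        [](const Attrs& attrs, const Array<te::Tensor>& inputs,
           const Type& out_type, const Target& target) {
          // Generic fallback; targets would override it via
          // register_func({"cuda"}, ...) and similar calls.
          return OpStrategy(runtime::make_object<OpStrategyNode>());
        }));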
/*!
 * \brief Alternate the layout of operators or replace the
 * operator with other expressions. This function will be invoked
@@ -136,7 +146,8 @@ using FTVMSchedule = runtime::TypedPackedFunc<
using FTVMAlterOpLayout = runtime::TypedPackedFunc<
  Expr(const Attrs& attrs,
       const Array<Expr>& args,
-      const Array<te::Tensor>& tinfos)>;
+      const Array<te::Tensor>& tinfos,
+      const Type& out_type)>;
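A minimal sketch of a hook with the extended signature (the body is illustrative; returning a null Expr to mean "keep the original operator" follows the existing alter-layout convention):

// Sketch only: an alter-layout hook that can now also see out_type.
FTVMAlterOpLayout falter = [](const Attrs& attrs,
                              const Array<Expr>& args,
                              const Array<te::Tensor>& tinfos,
                              const Type& out_type) -> Expr {
  // The inferred output type is available alongside the input tensor
  // placeholders in tinfos; a real hook would branch on them here.
  return Expr();  // null Expr: no rewrite
};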
/*!
 * \brief Convert the layout of operators or replace the
@@ -191,9 +202,7 @@ using FForwardRewrite = runtime::TypedPackedFunc<
 * \brief Gradient for a specific op.
 *
 * \param orig_call the original Expr.
- *
 * \param output_grad the gradient of the Expr.
- *
 * \return the gradient for each parameter.
 */
using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
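The hunk cuts the typedef off after its first parameter; assuming the second parameter is the output_grad Expr described above, a trivial gradient for a single-input, identity-like op would be:

// Sketch only: assumes the full signature is
// tvm::Array<Expr>(const Expr& orig_call, const Expr& output_grad).
FPrimalGradient fgrad = [](const Expr& orig_call,
                           const Expr& output_grad) -> tvm::Array<Expr> {
  // One gradient per input; an identity-like op passes the incoming
  // gradient straight through.
  return {output_grad};
};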
@@ -207,14 +216,182 @@ enum AnyCodegenStrategy {
  kVariableDimensions
};

- /* \brief A runtime representation of shape. */
+ /*! \brief A runtime representation of shape. */
using Shape = Array<IndexExpr>;

using FShapeFunc = runtime::TypedPackedFunc<
  Array<te::Tensor>(const Attrs& attrs,
                    const Array<te::Tensor>& inputs,
                    const Array<IndexExpr>& out_ndims)>;
+ /*!
+  * \brief Operator implementation in TVM.
+  */
+ class OpImplementNode : public Object {
+  public:
+   /*! \brief Compute function */
+   FTVMCompute fcompute;
+   /*! \brief Schedule function */
+   FTVMSchedule fschedule;
+   /*! \brief Priority level */
+   Integer plevel;
+
+   void VisitAttrs(tvm::AttrVisitor* v) {
+     v->Visit("plevel", &plevel);
+   }
+
+   static constexpr const char* _type_key = "relay.OpImplement";
+   TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementNode, Object);
+ };
+
+ /*!
+  * \brief Operator implementation class.
+  */
+ class OpImplement : public ObjectRef {
+  public:
+   /*! \brief default constructor */
+   OpImplement() {}
+   /*! \brief constructor from node pointer */
+   explicit OpImplement(ObjectPtr<Object> n) : ObjectRef(n) {}
+   /*!
+    * \brief access the internal node container
+    * \return the pointer to the internal node container
+    */
+   inline const OpImplementNode* operator->() const;
+   /*!
+    * \brief Invoke the operator compute function.
+    * \param attrs The attribute of the primitive
+    * \param inputs The input tensors.
+    * \param out_type The output type information.
+    * \return The output compute description of the operator.
+    */
+   Array<te::Tensor> Compute(const Attrs& attrs,
+                             const Array<te::Tensor>& inputs,
+                             const Type& out_type);
+   /*!
+    * \brief Build the computation schedule.
+    * \param attrs The attribute of the node.
+    * \param outs The output tensors.
+    * \param target The build target.
+    * \return The computation schedule.
+    */
+   te::Schedule Schedule(const Attrs& attrs,
+                         const Array<te::Tensor>& outs,
+                         const Target& target);
+ };
+
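Compute and Schedule together give the per-implementation lowering flow; a hypothetical helper (LowerWithImplement is not part of this diff) would chain them like this:

// Sketch only: how an OpImplement's two hooks compose during lowering.
te::Schedule LowerWithImplement(OpImplement impl, const Attrs& attrs,
                                const Array<te::Tensor>& inputs,
                                const Type& out_type, const Target& target) {
  // Compute emits the output tensor descriptions...
  Array<te::Tensor> outs = impl.Compute(attrs, inputs, out_type);
  // ...and Schedule builds a target-specific schedule over them.
  return impl.Schedule(attrs, outs, target);
}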
+ /*!
+  * \brief Specialized implementations for operators under certain conditions.
+  */
+ class OpSpecializationNode : public Object {
+  public:
+   /*! \brief List of implementations. */
+   Array<OpImplement> implements;
+   /*! \brief Condition to enable the specialization.
+    *  Could be undefined to represent generic case. */
+   te::SpecializedCondition condition;
+
+   void VisitAttrs(tvm::AttrVisitor* v) {
+     v->Visit("condition", &condition);
+     v->Visit("implements", &implements);
+   }
+
+   static constexpr const char* _type_key = "relay.OpSpecialization";
+   TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, Object);
+ };
+
+ /*!
+  * \brief Operator specialization class.
+  */
+ class OpSpecialization : public ObjectRef {
+  public:
+   OpSpecialization() {}
+   explicit OpSpecialization(ObjectPtr<Object> n) : ObjectRef(n) {}
+   /*!
+    * \brief access the internal node container
+    * \return the pointer to the internal node container
+    */
+   inline const OpSpecializationNode* operator->() const;
+   /*!
+    * \brief access the internal node container
+    * \return the pointer to the internal node container
+    */
+   inline OpSpecializationNode* operator->();
+   /*!
+    * \brief Add an implementation.
+    * \param fcompute Compute function
+    * \param fschedule Schedule function
+    * \param plevel Priority level of this implementation.
+    */
+   void AddImplement(FTVMCompute fcompute, FTVMSchedule fschedule,
+                     int plevel);
+ };
+
+ /*!
+  * \brief Operator strategy to choose implementation.
+  */
+ class OpStrategyNode : public Object {
+  public:
+   /*! \brief List of operator specializations. */
+   Array<OpSpecialization> specializations;
+
+   void VisitAttrs(tvm::AttrVisitor* v) {
+     v->Visit("specializations", &specializations);
+   }
+
+   static constexpr const char* _type_key = "relay.OpStrategy";
+   TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, Object);
+ };
+
+ /*!
+  * \brief Operator strategy class.
+  */
+ class OpStrategy : public ObjectRef {
+  public:
+   /*! \brief default constructor */
+   OpStrategy() {}
+   /*! \brief constructor from node pointer */
+   explicit OpStrategy(ObjectPtr<Object> n) : ObjectRef(n) {}
+   /*!
+    * \brief access the internal node container
+    * \return the pointer to the internal node container
+    */
+   inline const OpStrategyNode* operator->() const;
+   /*!
+    * \brief access the internal node container
+    * \return the pointer to the internal node container
+    */
+   inline OpStrategyNode* operator->();
+   /*!
+    * \brief Add an implementation.
+    * \param fcompute Compute function
+    * \param fschedule Schedule function
+    * \param plevel Priority level of this implementation.
+    */
+   void AddImplement(FTVMCompute fcompute, FTVMSchedule fschedule, int plevel);
+ };
+
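A sketch of populating a strategy with one generic implementation (MakeExampleStrategy, the trivial schedule body, and the plevel value are assumptions for illustration):

// Sketch only: build a strategy holding a single compute/schedule pair.
OpStrategy MakeExampleStrategy(FTVMCompute fcompute) {
  FTVMSchedule fschedule = [](const Attrs& attrs,
                              const Array<te::Tensor>& outs,
                              const Target& target) -> te::Schedule {
    // A trivial schedule over the output operations.
    Array<te::Operation> ops;
    for (const auto& t : outs) ops.push_back(t->op);
    return te::create_schedule(ops);
  };
  OpStrategy strategy(runtime::make_object<OpStrategyNode>());
  strategy.AddImplement(fcompute, fschedule, /*plevel=*/10);
  return strategy;
}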
+ // implementations
+ inline const OpImplementNode* OpImplement::operator->() const {
+   return static_cast<const OpImplementNode*>(get());
+ }
+
+ inline const OpSpecializationNode* OpSpecialization::operator->() const {
+   return static_cast<const OpSpecializationNode*>(get());
+ }
+
+ inline OpSpecializationNode* OpSpecialization::operator->() {
+   return static_cast<OpSpecializationNode*>(get_mutable());
+ }
+
+ inline const OpStrategyNode* OpStrategy::operator->() const {
+   return static_cast<const OpStrategyNode*>(get());
+ }
+
+ inline OpStrategyNode* OpStrategy::operator->() {
+   return static_cast<OpStrategyNode*>(get_mutable());
+ }
+
} // namespace relay
} // namespace tvm
#endif  // TVM_RELAY_OP_ATTR_TYPES_H_