Commit 62656e3

relay op strategy

fix lint

bitpack strategy bitserial_dense (apache#6)
* update strategy
* address comments

fix a few topi tests

Dense strategy (apache#5)
* dense
* add bifrost; remove comments
* address comment

Refactor x86 conv2d_NCHWc (apache#4)
* Refactor x86 conv2d
* Add x86 depthwise_conv2d_NCHWc
* Add back topi x86 conv2d_nchw
* Merge x86 conv2d_nchw and conv2d_NCHWc
* Minor fix for x86 conv2d

fix more strategy

Add x86 conv2d_NCHWc_int8 strategy (apache#8)
* Add x86 conv2d_NCHWc_int8 strategy
* Remove contrib_conv2d_nchwc_int8
* Fix generic conv2d_NCHWc for int8
* Fix topi arm_cpu conv2d_NCHWc_int8

update x86 conv2d
enable specifying relay ops to be tuned for autotvm
add cuda conv2d strategy
add conv2d strategy for rocm
add conv2d strategy for hls
add conv2d strategy for arm cpu
add conv2d strategy for mali
add conv2d strategy for bifrost
add conv2d strategy for intel graphics
clean up and fix lint
remove template keys from autotvm
remove 2 in the func name
address comments
fix

1 parent ee2d3cc commit 62656e3

200 files changed, +6909 -5787 lines

include/tvm/relay/op_attr_types.h

Lines changed: 184 additions & 7 deletions
@@ -29,6 +29,7 @@
 #include <tvm/relay/type.h>
 #include <tvm/relay/expr.h>
 #include <tvm/target/target.h>
+#include <tvm/target/generic_func.h>
 #include <tvm/tir/data_layout.h>
 #include <string>

@@ -105,9 +106,8 @@ using TShapeDataDependant = bool;
  */
 using FTVMCompute = runtime::TypedPackedFunc<
   Array<te::Tensor>(const Attrs& attrs,
-                    const Array<te::Tensor>& inputs,
-                    const Type& out_type,
-                    const Target& target)>;
+                    const Array<te::Tensor>& inputs,
+                    const Type& out_type)>;

 /*!
  * \brief Build the computation schedule for
@@ -123,6 +123,16 @@ using FTVMSchedule = runtime::TypedPackedFunc<
                const Array<te::Tensor>& outs,
                const Target& target)>;

+/*!
+ * \brief Generate the strategy of operators. This function is a generic
+ * function and can be re-defined for different targets.
+ *
+ * The function signature of generic function is:
+ * OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,
+ *            const Type& out_type, const Target& target)
+ */
+using FTVMStrategy = GenericFunc;
+
 /*!
  * \brief Alternate the layout of operators or replace the
  * operator with other expressions. This function will be invoked
@@ -136,7 +146,8 @@ using FTVMSchedule = runtime::TypedPackedFunc<
 using FTVMAlterOpLayout = runtime::TypedPackedFunc<
   Expr(const Attrs& attrs,
        const Array<Expr>& args,
-       const Array<te::Tensor>& tinfos)>;
+       const Array<te::Tensor>& tinfos,
+       const Type& out_type)>;

 /*!
  * \brief Convert the layout of operators or replace the
@@ -191,9 +202,7 @@ using FForwardRewrite = runtime::TypedPackedFunc<
  * \brief Gradient for a specific op.
  *
  * \param orig_call the original Expr.
- *
  * \param output_grad the gradient of the Expr.
- *
  * \return the gradient for each parameters.
  */
 using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
@@ -207,14 +216,182 @@ enum AnyCodegenStrategy {
   kVariableDimensions
 };

-/* \brief A runtime representation of shape. */
+/*! \brief A runtime representation of shape. */
 using Shape = Array<IndexExpr>;

 using FShapeFunc = runtime::TypedPackedFunc<
   Array<te::Tensor>(const Attrs& attrs,
                     const Array<te::Tensor>& inputs,
                     const Array<IndexExpr>& out_ndims)>;

+/*!
+ * \brief Operator implementation in TVM.
+ */
+class OpImplementNode : public Object {
+ public:
+  /*! \brief Compute function */
+  FTVMCompute fcompute;
+  /*! \brief Schedule function */
+  FTVMSchedule fschedule;
+  /*! \brief Priority level */
+  Integer plevel;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("plevel", &plevel);
+  }
+
+  static constexpr const char* _type_key = "relay.OpImplement";
+  TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementNode, Object);
+};
+
+/*!
+ * \brief Operator implementation class.
+ */
+class OpImplement : public ObjectRef {
+ public:
+  /*! \brief default constructor */
+  OpImplement() {}
+  /*! \brief constructor from node pointer */
+  explicit OpImplement(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const OpImplementNode* operator->() const;
+  /*!
+   * \brief Invoke the operator compute function.
+   * \param attrs The attribute of the primitive
+   * \param inputs The input tensors.
+   * \param out_type The output type information.
+   * \return The output compute description of the operator.
+   */
+  Array<te::Tensor> Compute(const Attrs& attrs,
+                            const Array<te::Tensor>& inputs,
+                            const Type& out_type);
+  /*!
+   * \brief Build the computation schedule.
+   * \param attrs The attribute of the node.
+   * \param outs The output tensors.
+   * \param target The build target.
+   * \return The computation schedule.
+   */
+  te::Schedule Schedule(const Attrs& attrs,
+                        const Array<te::Tensor>& outs,
+                        const Target& target);
+};
+
+/*!
+ * \brief Specialized implementations for operators under certain conditions.
+ */
+class OpSpecializationNode : public Object {
+ public:
+  /*! \brief List of implementations. */
+  Array<OpImplement> implements;
+  /*! \brief Condition to enable the specialization.
+   *  Could be undefined to represent generic case. */
+  te::SpecializedCondition condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("condition", &condition);
+    v->Visit("implements", &implements);
+  }
+
+  static constexpr const char* _type_key = "relay.OpSpecialization";
+  TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, ExprNode);
+};
+
+/*!
+ * \brief Operator specialization class.
+ */
+class OpSpecialization : public ObjectRef {
+ public:
+  OpSpecialization() {}
+  explicit OpSpecialization(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const OpSpecializationNode* operator->() const;
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline OpSpecializationNode* operator->();
+  /*!
+   * \brief Add an implementation.
+   * \param compute Compute function
+   * \param schedule Schedule function
+   * \param plevel Priority level of this implementation.
+   */
+  void AddImplement(FTVMCompute fcompute, FTVMSchedule fschedule,
+                    int plevel);
+};
+
+/*!
+ * \brief Operator strategy to choose implementation.
+ */
+class OpStrategyNode : public Object {
+ public:
+  /*! \brief List of operator specializations. */
+  Array<OpSpecialization> specializations;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("specializations", &specializations);
+  }
+
+  static constexpr const char* _type_key = "relay.OpStrategy";
+  TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode);
+};
+
+/*!
+ * \brief Operator strategy class.
+ */
+class OpStrategy : public ObjectRef {
+ public:
+  /*! \brief default constructor */
+  OpStrategy() {}
+  /*! \brief constructor from node pointer */
+  explicit OpStrategy(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const OpStrategyNode* operator->() const;
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline OpStrategyNode* operator->();
+  /*!
+   * \brief Add an implementation.
+   * \param compute Compute function
+   * \param schedule Schedule function
+   * \param plevel Priority level of this implementation.
+   */
+  void AddImplement(FTVMCompute fcompute, FTVMSchedule fschedule, int plevel);
+};
+
+// implementations
+inline const OpImplementNode* OpImplement::operator->() const {
+  return static_cast<const OpImplementNode*>(get());
+}
+
+inline const OpSpecializationNode* OpSpecialization::operator->() const {
+  return static_cast<const OpSpecializationNode*>(get());
+}
+
+inline OpSpecializationNode* OpSpecialization::operator->() {
+  return static_cast<OpSpecializationNode*>(get_mutable());
+}
+
+inline const OpStrategyNode* OpStrategy::operator->() const {
+  return static_cast<const OpStrategyNode*>(get());
+}
+
+inline OpStrategyNode* OpStrategy::operator->() {
+  return static_cast<OpStrategyNode*>(get_mutable());
+}
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_ATTR_TYPES_H_
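
For orientation, here is a minimal sketch (not part of the commit) of how a target-specific FTVMStrategy implementation might assemble an OpStrategy from the classes declared above. The helper name MakeDenseStrategy, its packed-function arguments, and the use of runtime::make_object are illustrative assumptions; a real strategy would wrap actual TOPI compute and schedule functions and would be registered on the FTVMStrategy generic function so that each target can return its own strategy.

// Hypothetical sketch only; MakeDenseStrategy and its arguments are
// illustrative and do not appear in this commit.
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/memory.h>

namespace tvm {
namespace relay {

// Build a strategy holding a single generic implementation. The FTVMCompute
// and FTVMSchedule packed functions are assumed to wrap TOPI kernels.
OpStrategy MakeDenseStrategy(FTVMCompute dense_compute,
                             FTVMSchedule dense_schedule) {
  // Constructing the node with make_object follows the usual TVM object
  // pattern; the header above only declares the public interface.
  OpStrategy strategy(runtime::make_object<OpStrategyNode>());
  // plevel decides which implementation wins when several are applicable.
  strategy.AddImplement(dense_compute, dense_schedule, /*plevel=*/10);
  return strategy;
}

}  // namespace relay
}  // namespace tvm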

include/tvm/te/schedule.h

Lines changed: 55 additions & 0 deletions
@@ -28,6 +28,7 @@
 #include <tvm/tir/expr.h>
 #include <tvm/te/tensor.h>
 #include <tvm/te/tensor_intrin.h>
+#include <tvm/support/with.h>

 #include <string>
 #include <unordered_map>
@@ -742,6 +743,55 @@ class SingletonNode : public IterVarRelationNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
 };

+class SpecializedConditionNode;
+
+/*!
+ * \brief Specialized condition to enable op specialization
+ */
+class SpecializedCondition : public ObjectRef {
+ public:
+  SpecializedCondition() {}
+  explicit SpecializedCondition(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief Get the current specialized condition.
+   * \return The current specialized condition.
+   */
+  TVM_DLL static SpecializedCondition Current();
+
+  const SpecializedConditionNode* operator->() const;
+
+  using ContainerType = SpecializedConditionNode;
+  class Internal;
+ private:
+  // enable with syntax.
+  friend class Internal;
+  friend class With<SpecializedCondition>;
+  /*! \brief Push a new specialized condition onto the thread local stack. */
+  TVM_DLL void EnterWithScope();
+  /*! \brief Pop a specialized condition off the thread local context stack. */
+  TVM_DLL void ExitWithScope();
+};
+
+/*! \brief Container for specialization conditions. */
+class SpecializedConditionNode : public Object {
+ public:
+  /*!
+   * \brief List of conditions in conjunctive normal form (CNF).
+   *  Each condition should be a simple expression, e.g., n > 16, m % 8 == 0, etc.,
+   *  where n, m are tvm::Var that represent a dimension in the tensor shape.
+   */
+  Array<PrimExpr> clauses;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("clauses", &clauses);
+  }
+
+  static SpecializedCondition make(Array<PrimExpr> conditions);
+
+  static constexpr const char* _type_key = "SpecializedCondition";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object);
+};
+

 // implementations
 inline const StageNode* Stage::operator->() const {
@@ -765,6 +815,11 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
 inline const IterVarAttrNode* IterVarAttr::operator->() const {
   return static_cast<const IterVarAttrNode*>(get());
 }
+
+inline const SpecializedConditionNode* SpecializedCondition::operator->() const {
+  return static_cast<const SpecializedConditionNode*>(get());
+}
+
 }  // namespace te
 }  // namespace tvm
 #endif  // TVM_TE_SCHEDULE_H_
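
As a quick illustration (again not from the commit), the RAII helper With<SpecializedCondition> referenced above could be used roughly as follows. The shape variable n, the clause n > 16, and the extra include for the PrimExpr comparison operators are assumptions made for the example; they mirror the CNF clauses described in SpecializedConditionNode.

// Hypothetical usage sketch; error handling and a real schedule are omitted.
#include <tvm/te/schedule.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>  // assumed location of the PrimExpr comparison operators

void ExampleSpecializedScope() {
  using namespace tvm;

  tir::Var n("n");
  // One CNF clause: this specialization only applies when n > 16.
  te::SpecializedCondition cond = te::SpecializedConditionNode::make({n > 16});
  {
    // Push the condition onto the thread-local stack for this scope.
    With<te::SpecializedCondition> scope(cond);
    // Schedules or strategies built here can query the active condition.
    te::SpecializedCondition current = te::SpecializedCondition::Current();
    (void)current;
  }  // the condition is popped automatically when the scope ends
}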

python/tvm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@
 from .api import *
 from .intrin import *
 from .tensor_intrin import decl_tensor_intrin
-from .schedule import create_schedule
+from .schedule import create_schedule, current_specialization
 from .build_module import build, lower, build_config
 from .tag import tag_scope

python/tvm/autotvm/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -41,8 +41,8 @@
 from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
     LocalBuilder, LocalRunner, RPCRunner
 from .tuner import callback
-from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
-    register_topi_compute, register_topi_schedule, \
+from .task import get_config, create, ConfigSpace, ConfigEntity, \
+    register_topi_compute, register_topi_schedule, register_customized_task, \
     DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
     ApplyGraphBest as apply_graph_best
 from .env import GLOBAL_SCOPE
