Commit d42c343

relay op strategy
fix lint
bitpack strategy
bitserial_dense (#6)
* update strategy
* address comments
fix a few topi test
Dense strategy (#5)
* dense
* add biforst; remove comments
* address comment
Refactor x86 conv2d_NCHWc (#4)
* Refactor x86 conv2d
* Add x86 depthwise_conv2d_NCHWc
* Add back topi x86 conv2d_nchw
* Merge x86 conv2d_nchw and conv2d_NCHWc
* Minor fix for x86 conv2d
fix more strategy
Add x86 conv2d_NCHWc_int8 strategy (#8)
* Add x86 conv2d_NCHWc_int8 strategy
* Remove contrib_conv2d_nchwc_int8
* Fix generic conv2d_NCHWc for int8
* Fix topi arm_cpu conv2d_NCHWc_int8
update x86 conv2d
1 parent 703ed9b commit d42c343

File tree

164 files changed: +5104 −3553 lines changed

include/tvm/relay/op_attr_types.h

Lines changed: 182 additions & 6 deletions
@@ -101,8 +101,7 @@ using TShapeDataDependant = bool;
 using FTVMCompute = runtime::TypedPackedFunc<
   Array<top::Tensor>(const Attrs& attrs,
                      const Array<top::Tensor>& inputs,
-                     const Type& out_type,
-                     const Target& target)>;
+                     const Type& out_type)>;

 /*!
  * \brief Build the computation schedule for
@@ -118,6 +117,16 @@ using FTVMSchedule = runtime::TypedPackedFunc<
                 const Array<top::Tensor>& outs,
                 const Target& target)>;

+/*!
+ * \brief Generate the strategy of operators. This function is a generic
+ * function and can be re-defined for different targets.
+ *
+ * The function signature of generic function is:
+ * OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,
+ *            const Type& out_type, const Target& target)
+ */
+using FTVMStrategy = GenericFunc;
+
 /*!
  * \brief Alternate the layout of operators or replace the
  * operator with other expressions. This function will be invoked
@@ -131,7 +140,8 @@ using FTVMSchedule = runtime::TypedPackedFunc<
 using FTVMAlterOpLayout = runtime::TypedPackedFunc<
   Expr(const Attrs& attrs,
        const Array<Expr>& args,
-       const Array<top::Tensor>& tinfos)>;
+       const Array<top::Tensor>& tinfos,
+       const Type& out_type)>;

 /*!
  * \brief Convert the layout of operators or replace the
@@ -186,9 +196,7 @@ using FForwardRewrite = runtime::TypedPackedFunc<
  * \brief Gradient for a specific op.
  *
  * \param orig_call the original Expr.
- *
  * \param output_grad the gradient of the Expr.
- *
  * \return the gradient for each parameters.
  */
 using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
@@ -202,14 +210,182 @@ enum AnyCodegenStrategy {
   kVariableDimensions
 };

-/* \brief A runtime representation of shape. */
+/*! \brief A runtime representation of shape. */
 using Shape = Array<IndexExpr>;

 using FShapeFunc = runtime::TypedPackedFunc<
   Array<top::Tensor>(const Attrs& attrs,
                      const Array<top::Tensor>& inputs,
                      const Array<IndexExpr>& out_ndims)>;

+/*!
+ * \brief Operator implementation in TVM.
+ */
+class OpImplementNode : public Object {
+ public:
+  /*! \brief Compute function */
+  FTVMCompute fcompute;
+  /*! \brief Schedule function */
+  FTVMSchedule fschedule;
+  /*! \brief Priority level */
+  Integer plevel;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("plevel", &plevel);
+  }
+
+  static constexpr const char* _type_key = "relay.OpImplement";
+  TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementNode, Object);
+};
+
+/*!
+ * \brief Operator implementation class.
+ */
+class OpImplement : public ObjectRef {
+ public:
+  /*! \brief default constructor */
+  OpImplement() {}
+  /*! \brief constructor from node pointer */
+  explicit OpImplement(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const OpImplementNode* operator->() const;
+  /*!
+   * \brief Invoke the operator compute function.
+   * \param attrs The attribute of the primitive
+   * \param inputs The input tensors.
+   * \param out_type The output type information.
+   * \return The output compute description of the operator.
+   */
+  Array<top::Tensor> Compute(const Attrs& attrs,
+                             const Array<top::Tensor>& inputs,
+                             const Type& out_type);
+  /*!
+   * \brief Build the computation schedule.
+   * \param attrs The attribute of the node.
+   * \param outs The output tensors.
+   * \param target The build target.
+   * \return The computation schedule.
+   */
+  top::Schedule Schedule(const Attrs& attrs,
+                         const Array<top::Tensor>& outs,
+                         const Target& target);
+};
+
+/*!
+ * \brief Specialized implementations for operators under certain conditions.
+ */
+class OpSpecializationNode : public Object {
+ public:
+  /*! \brief List of implementations. */
+  Array<OpImplement> implements;
+  /*! \brief Condition to enable the specialization.
+   *  Could be undefined to represent generic case. */
+  top::SpecializedCondition condition;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("condition", &condition);
+    v->Visit("implements", &implements);
+  }
+
+  static constexpr const char* _type_key = "relay.OpSpecialization";
+  TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, ExprNode);
+};
+
+/*!
+ * \brief Operator specialization class.
+ */
+class OpSpecialization : public ObjectRef {
+ public:
+  OpSpecialization() {}
+  explicit OpSpecialization(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const OpSpecializationNode* operator->() const;
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline OpSpecializationNode* operator->();
+  /*!
+   * \brief Add an implementation.
+   * \param compute Compute function
+   * \param schedule Schedule function
+   * \param plevel Priority level of this implemntation.
+   */
+  void AddImplement(FTVMCompute fcompute, FTVMSchedule fschedule,
+                    int plevel);
+};
+
+/*!
+ * \brief Operator strategy to choose implementation.
+ */
+class OpStrategyNode : public Object {
+ public:
+  /*! \brief List of operator specializations. */
+  Array<OpSpecialization> specializations;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("specializations", &specializations);
+  }
+
+  static constexpr const char* _type_key = "relay.OpStrategy";
+  TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode);
+};
+
+/*!
+ * \brief Operator strategy class.
+ */
+class OpStrategy : public ObjectRef {
+ public:
+  /*! \brief default constructor */
+  OpStrategy() {}
+  /*! \brief constructor from node pointer */
+  explicit OpStrategy(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const OpStrategyNode* operator->() const;
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline OpStrategyNode* operator->();
+  /*!
+   * \brief Add an implementation.
+   * \param compute Compute function
+   * \param schedule Schedule function
+   * \param plevel Priority level of this implementation.
+   */
+  void AddImplement(FTVMCompute fcompute, FTVMSchedule fschedule, int plevel);
+};
+
+// implementations
+inline const OpImplementNode* OpImplement::operator->() const {
+  return static_cast<const OpImplementNode*>(get());
+}
+
+inline const OpSpecializationNode* OpSpecialization::operator->() const {
+  return static_cast<const OpSpecializationNode*>(get());
+}
+
+inline OpSpecializationNode* OpSpecialization::operator->() {
+  return static_cast<OpSpecializationNode*>(get_mutable());
+}
+
+inline const OpStrategyNode* OpStrategy::operator->() const {
+  return static_cast<const OpStrategyNode*>(get());
+}
+
+inline OpStrategyNode* OpStrategy::operator->() {
+  return static_cast<OpStrategyNode*>(get_mutable());
+}
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_ATTR_TYPES_H_
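The header above only declares the strategy containers. As a rough orientation, and not part of this commit, the sketch below shows how a target-specific function matching the documented FTVMStrategy signature could assemble an OpStrategy via OpStrategy::AddImplement; the function name ExampleDenseStrategy, the lambdas, and the plevel value are illustrative placeholders rather than real TOPI compute and schedule code.

// Illustrative sketch only (not part of this commit): a target-specific
// strategy function matching the FTVMStrategy signature documented above.
// The lambdas and the plevel value are placeholders; a real strategy would
// call into TOPI compute/schedule code instead.
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/memory.h>  // assumed location of make_object at this revision

namespace tvm {
namespace relay {

OpStrategy ExampleDenseStrategy(const Attrs& attrs,
                                const Array<top::Tensor>& inputs,
                                const Type& out_type,
                                const Target& target) {
  OpStrategy strategy(runtime::make_object<OpStrategyNode>());
  FTVMCompute fcompute = [](const Attrs& attrs,
                            const Array<top::Tensor>& inputs,
                            const Type& out_type) -> Array<top::Tensor> {
    return inputs;  // placeholder: return the real compute description here
  };
  FTVMSchedule fschedule = [](const Attrs& attrs,
                              const Array<top::Tensor>& outs,
                              const Target& target) -> top::Schedule {
    return top::Schedule();  // placeholder: build a target-specific schedule here
  };
  // Attach one implementation; plevel ranks candidate implementations when
  // more than one is applicable.
  strategy.AddImplement(fcompute, fschedule, /*plevel=*/10);
  return strategy;
}

}  // namespace relay
}  // namespace tvm

Presumably AddImplement files the implementation under the specialization whose SpecializedCondition (see include/tvm/top/schedule.h below) is currently in scope, falling back to the generic, unconditioned specialization otherwise.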

include/tvm/top/schedule.h

Lines changed: 55 additions & 1 deletion
@@ -28,7 +28,7 @@
 #include <tvm/expr.h>
 #include <tvm/top/tensor.h>
 #include <tvm/top/tensor_intrin.h>
-
+#include <tvm/support/with.h>

 #include <string>
 #include <unordered_map>
@@ -744,6 +744,55 @@ class SingletonNode : public IterVarRelationNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
 };

+class SpecializedConditionNode;
+
+/*!
+ * \brief Specialized condition to enable op specialization
+ */
+class SpecializedCondition : public ObjectRef {
+ public:
+  SpecializedCondition() {}
+  explicit SpecializedCondition(ObjectPtr<Object> n) : ObjectRef(n) {}
+  /*!
+   * \brief Get the current specialized condition.
+   * \return The current specialized condition.
+   */
+  TVM_DLL static SpecializedCondition Current();
+
+  const SpecializedConditionNode* operator->() const;
+
+  using ContainerType = SpecializedConditionNode;
+  class Internal;
+ private:
+  // enable with syntax.
+  friend class Internal;
+  friend class With<SpecializedCondition>;
+  /*! \brief Push a new specialized condition onto the thread local stack. */
+  TVM_DLL void EnterWithScope();
+  /*! \brief Pop a specialized condition off the thread local context stack. */
+  TVM_DLL void ExitWithScope();
+};
+
+/*! \brief Container for specialization conditions. */
+class SpecializedConditionNode : public Object {
+ public:
+  /*!
+   * \brief List of conditions in conjunctive joint form (CNF).
+   * Each condition should be a simple expression, e.g., n > 16, m % 8 == 0, etc.,
+   * where n, m are tvm::Var that represents a dimension in the tensor shape.
+   */
+  Array<PrimExpr> clauses;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("clauses", &clauses);
+  }
+
+  static SpecializedCondition make(Array<PrimExpr> conditions);
+
+  static constexpr const char* _type_key = "SpecializedCondition";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object);
+};
+

 // implementations
 inline const StageNode* Stage::operator->() const {
@@ -767,6 +816,11 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
 inline const IterVarAttrNode* IterVarAttr::operator->() const {
   return static_cast<const IterVarAttrNode*>(get());
 }
+
+inline const SpecializedConditionNode* SpecializedCondition::operator->() const {
+  return static_cast<const SpecializedConditionNode*>(get());
+}
+
 }  // namespace top
 }  // namespace tvm
 #endif  // TVM_TOP_SCHEDULE_H_
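For context, here is a hypothetical sketch, not part of this diff, of how the new SpecializedCondition is meant to be used: SpecializedConditionNode::make builds a condition from CNF clauses, With<SpecializedCondition> pushes it onto the thread-local stack through the private EnterWithScope/ExitWithScope hooks, and Current() reads it back inside the scope. The variable m, the clause m > 16, and the tvm/expr_operator.h include are assumptions for illustration.

// Hypothetical sketch (not part of this diff): creating and scoping a
// SpecializedCondition. The variable `m` and the clause `m > 16` mirror the
// example given in the doc comment above; the expr_operator.h include is an
// assumption about where operator> for PrimExpr lives at this revision.
#include <tvm/top/schedule.h>
#include <tvm/expr_operator.h>

void ExampleSpecializedScope() {
  using tvm::top::SpecializedCondition;
  using tvm::top::SpecializedConditionNode;

  tvm::Var m("m");
  // A single CNF clause: this specialization applies when m > 16.
  SpecializedCondition cond = SpecializedConditionNode::make({m > 16});
  {
    // Pushes the condition onto the thread-local stack via EnterWithScope().
    tvm::With<SpecializedCondition> scope(cond);
    CHECK(SpecializedCondition::Current().same_as(cond));
  }
  // Leaving the scope calls ExitWithScope(), so Current() no longer returns cond.
}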

python/tvm/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@
 from .tensor_intrin import decl_tensor_intrin
 from .object import register_object
 from .ndarray import register_extension
-from .schedule import create_schedule
+from .schedule import create_schedule, current_specialization
 from .build_module import build, lower, build_config
 from .tag import tag_scope


python/tvm/autotvm/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@
 from .tuner import callback
 from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
     register_topi_compute, register_topi_schedule, \
+    register_topi_compute2, register_topi_schedule2, register_customized_task, \
     DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
     ApplyGraphBest as apply_graph_best
 from .env import GLOBAL_SCOPE

python/tvm/autotvm/graph_tuner/base_graph_tuner.py

Lines changed: 14 additions & 8 deletions
@@ -25,7 +25,7 @@
 import tvm
 from tvm import autotvm, relay
 from tvm.autotvm.task import get_config
-from tvm.autotvm.task.topi_integration import deserialize_args, serialize_args
+from tvm.autotvm.task.topi_integration import serialize_args
 from tvm.autotvm.record import encode, load_from_file
 from tvm.autotvm.measure import MeasureResult, MeasureInput

@@ -42,11 +42,17 @@
     "topi_nn_depthwise_conv2d_nchw": topi.nn.depthwise_conv2d_infer_layout,
 }

+def get_infer_layout(task_name):
+    if task_name.startswith("conv2d"):
+        return topi.nn.conv2d_infer_layout
+    elif task_name.startswith("depthwise_conv2d"):
+        return topi.nn.depthwise_conv2d_infer_layout
+    else:
+        raise ValueError("Cannot find infer layout for task %s" % task_name)

-@autotvm.template
+@autotvm.register_customized_task("layout_transform")
 def layout_transform(*args):
     """Autotvm layout transform template."""
-    args = deserialize_args(args)
     cfg = get_config()
     cfg.add_flop(-1)
     data = args[0]
@@ -212,7 +218,7 @@ def _fetch_cfg(self):
                 node_entry["record_candidates"] = cache_dict[workload]
                 continue
             record_candidates = []
-            infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
+            infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
             layout_tracking_dict = {}
             for record in cfg_dict[workload]:
                 in_measure, out_measure = record
@@ -264,7 +270,7 @@ def _iterate_layout_transform(self, callback):

             if node_entry["op"] in self._target_ops:
                 o_idx = key
-                o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
+                o_infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
                 o_wkl = node_entry["workloads"][0]
                 i_topi_op = in_node_entry["topi_op"][0]
                 i_wkl = in_node_entry["workloads"][0]
@@ -273,14 +279,14 @@ def _iterate_layout_transform(self, callback):
                         pivot += 1
                         i_topi_op = in_node_entry["topi_op"][pivot]
                         i_wkl = in_node_entry["workloads"][pivot]
-                    i_infer_layout_func = OP2LAYOUT[i_topi_op]
+                    i_infer_layout_func = get_infer_layout(i_topi_op)
             else:
                 o_idx = target_input_idx
                 if i <= target_input_pos:
                     continue
-                o_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][0]]
+                o_infer_layout_func = get_infer_layout(node_entry["topi_op"][0])
                 o_wkl = node_entry["workloads"][target_input_pos]
-                i_infer_layout_func = OP2LAYOUT[node_entry["topi_op"][i]]
+                i_infer_layout_func = get_infer_layout(node_entry["topi_op"][i])
                 i_wkl = node_entry["workloads"][i]

             if (i_idx, o_idx) in pair_tracker:

python/tvm/autotvm/graph_tuner/utils/traverse_graph.py

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@ def expr2graph(expr, target_ops, node_dict, node_list):
                              % op_name)
         topi_funcs += OP2COMPUTE[op_name]
     env.reset(topi_funcs)
+    # TODO(@kevinthesun, @icemelon9): Currently graph tuning pass relies on the fact
+    # that # autotvm tasks == # ops. But this won't be true after having relay op
+    # strategy. We need to find a solution to fix this.
     with env:
         _expr2graph_impl(expr, target_ops, node_dict, node_list)
     task_pos = 0
