
Commit 22f7797

icemelon authored and kevinthesun committed
[Relay][AutoTVM] Relay op strategy (apache#4644)
* relay op strategy fix lint bitpack strategy bitserial_dense (#6)
* update strategy
* address comments fix a few topi test Dense strategy (#5)
* dense
* add biforst; remove comments
* address comment Refactor x86 conv2d_NCHWc (#4)
* Refactor x86 conv2d
* Add x86 depthwise_conv2d_NCHWc
* Add back topi x86 conv2d_nchw
* Merge x86 conv2d_nchw and conv2d_NCHWc
* Minor fix for x86 conv2d fix more strategy Add x86 conv2d_NCHWc_int8 strategy (#8)
* Add x86 conv2d_NCHWc_int8 strategy
* Remove contrib_conv2d_nchwc_int8
* Fix generic conv2d_NCHWc for int8
* Fix topi arm_cpu conv2d_NCHWc_int8 update x86 conv2d enable specify relay ops to be tuned for autotvm add cuda conv2d strategy add conv2d strategy for rocm add conv2d strategy for hls add conv2d strategy for arm cpu add conv2d strategy for mali add conv2d strategy for bifrost add conv2d strategy for intel graphics clean up and fix lint remove template keys from autotvm remove 2 in the func name address comments fix
* fix bugs
* lint
* address comments
* add name to op implement
* Modify topi tests (#9)
* Add pooling, reorg, softmax and vision
* Add lrn
* fix topi test
* fix more topi test
* lint
* address comments
* x
* fix more tests & bugs
* Modify more tests (#10)
* Modify tests for bitserial_conv2d, bitserial_dense, bitserial_conv2d_rasp and bnn
* Minor fix
* More minor fix
* fix more test
* try to update vta using strategy
* fix cpptest
* x
* fix rebase err
* Fix two tests (#11)
* change autotvm log format
* lint
* minor fix
* try fix vta test
* fix rebase err
* tweak
* tmp hack for vta pass
* fix tutorial
* fix
* fix more tutorials
* fix vta tutorial
* minor
* address comments
* fix
* address comments
* fix cpptest
* fix docs
* change data structure name and api
* address comments
* lint
* fix rebase err
* updates
* fix winograd test
* fix doc
* rebase
* upgrade tophub version number
* fix bug
* re-enable vta tsim test after tophub is upgraded
* fix vta test to use the correct args so the config can be found in tophub

Co-authored-by: Yao Wang <kevinthesunwy@gmail.com>
1 parent e087ccc commit 22f7797


270 files changed: +8466 −7067 lines


include/tvm/relay/op_attr_types.h

Lines changed: 20 additions & 11 deletions
@@ -29,6 +29,7 @@
 #include <tvm/relay/type.h>
 #include <tvm/relay/expr.h>
 #include <tvm/target/target.h>
+#include <tvm/target/generic_func.h>
 #include <tvm/tir/data_layout.h>
 #include <string>

@@ -105,9 +106,8 @@ using TShapeDataDependant = bool;
  */
 using FTVMCompute = runtime::TypedPackedFunc<
   Array<te::Tensor>(const Attrs& attrs,
-                    const Array<te::Tensor>& inputs,
-                    const Type& out_type,
-                    const Target& target)>;
+                    const Array<te::Tensor>& inputs,
+                    const Type& out_type)>;

 /*!
  * \brief Build the computation schedule for
@@ -120,8 +120,18 @@ using FTVMCompute = runtime::TypedPackedFunc<
  */
 using FTVMSchedule = runtime::TypedPackedFunc<
   te::Schedule(const Attrs& attrs,
-               const Array<te::Tensor>& outs,
-               const Target& target)>;
+               const Array<te::Tensor>& outs,
+               const Target& target)>;
+
+/*!
+ * \brief Generate the strategy of operators. This function is a generic
+ * function and can be re-defined for different targets.
+ *
+ * The function signature of generic function is:
+ * OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,
+ *            const Type& out_type, const Target& target)
+ */
+using FTVMStrategy = GenericFunc;

 /*!
  * \brief Alternate the layout of operators or replace the
@@ -136,7 +146,8 @@ using FTVMSchedule = runtime::TypedPackedFunc<
 using FTVMAlterOpLayout = runtime::TypedPackedFunc<
   Expr(const Attrs& attrs,
        const Array<Expr>& args,
-       const Array<te::Tensor>& tinfos)>;
+       const Array<te::Tensor>& tinfos,
+       const Type& out_type)>;

 /*!
  * \brief Convert the layout of operators or replace the
@@ -191,9 +202,7 @@ using FForwardRewrite = runtime::TypedPackedFunc<
  * \brief Gradient for a specific op.
  *
  * \param orig_call the original Expr.
- *
  * \param output_grad the gradient of the Expr.
- *
  * \return the gradient for each parameters.
  */
 using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
@@ -207,13 +216,13 @@ enum AnyCodegenStrategy {
   kVariableDimensions
 };

-/* \brief A runtime representation of shape. */
+/*! \brief A runtime representation of shape. */
 using Shape = Array<IndexExpr>;

 using FShapeFunc = runtime::TypedPackedFunc<
   Array<te::Tensor>(const Attrs& attrs,
-                    const Array<te::Tensor>& inputs,
-                    const Array<IndexExpr>& out_ndims)>;
+                    const Array<te::Tensor>& inputs,
+                    const Array<IndexExpr>& out_ndims)>;

 }  // namespace relay
 }  // namespace tvm
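For orientation, here is a minimal sketch (not part of this commit) of a compute function written against the new FTVMCompute signature: the Target parameter is gone, because per-target dispatch now goes through the FTVMStrategy generic function. The op, the topi::add call, and the include path are illustrative assumptions.

    // Sketch only: a compute function matching the new FTVMCompute signature.
    // The build target is no longer passed here; it is resolved when the
    // FTVMStrategy generic function picks an implementation for the target.
    #include <tvm/relay/op_attr_types.h>
    #include <topi/broadcast.h>  // assumed include path for topi::add at this revision

    namespace tvm {
    namespace relay {

    // Hypothetical element-wise op: add the two input tensors.
    Array<te::Tensor> MyAddCompute(const Attrs& attrs,
                                   const Array<te::Tensor>& inputs,
                                   const Type& out_type) {
      return {topi::add(inputs[0], inputs[1])};
    }

    // Wrap it in the packed-function type stored on the op attribute.
    FTVMCompute my_add_fcompute = MyAddCompute;

    }  // namespace relay
    }  // namespace tvm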

include/tvm/relay/op_strategy.h

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/op_strategy.h
 * \brief The Relay operator Strategy and related data structure.
 */

#ifndef TVM_RELAY_OP_STRATEGY_H_
#define TVM_RELAY_OP_STRATEGY_H_

#include <tvm/te/tensor.h>
#include <tvm/te/schedule.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/target/target.h>
#include <string>

namespace tvm {
namespace relay {

/*!
 * \brief Operator implementation that includes compute and schedule function.
 */
class OpImplementationNode : public Object {
 public:
  /*! \brief Compute function */
  FTVMCompute fcompute;
  /*! \brief Schedule function */
  FTVMSchedule fschedule;
  /*! \brief Name of the implementation */
  std::string name;
  /*! \brief Priority level */
  int plevel;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("plevel", &plevel);
  }

  static constexpr const char* _type_key = "relay.OpImplementation";
  TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementationNode, Object);
};

/*!
 * \brief Operator implementation class.
 */
class OpImplementation : public ObjectRef {
 public:
  /*!
   * \brief Invoke the operator compute function.
   * \param attrs The attribute of the primitive
   * \param inputs The input tensors.
   * \param out_type The output type information.
   * \return The output compute description of the operator.
   */
  TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs,
                                    const Array<te::Tensor>& inputs,
                                    const Type& out_type);
  /*!
   * \brief Build the computation schedule.
   * \param attrs The attribute of the node.
   * \param outs The output tensors.
   * \param target The build target.
   * \return The computation schedule.
   */
  TVM_DLL te::Schedule Schedule(const Attrs& attrs,
                                const Array<te::Tensor>& outs,
                                const Target& target);

  TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode);
};

/*!
 * \brief Specialized implementations for operators under certain conditions.
 */
class OpSpecializationNode : public Object {
 public:
  /*! \brief List of implementations. */
  Array<OpImplementation> implementations;
  /*! \brief Condition to enable the specialization.
   *    Could be undefined to represent generic case. */
  te::SpecializedCondition condition;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("condition", &condition);
    v->Visit("implementations", &implementations);
  }

  static constexpr const char* _type_key = "relay.OpSpecialization";
  TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, ExprNode);
};

/*!
 * \brief Operator specialization class.
 */
class OpSpecialization : public ObjectRef {
 public:
  /*!
   * \brief Add an implementation.
   * \param fcompute Compute function
   * \param fschedule Schedule function
   * \param name Name of the implementation
   * \param plevel Priority level of the implementation
   */
  TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule,
                                 std::string name, int plevel);

  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode);
};

/*!
 * \brief Operator strategy to choose implementation.
 */
class OpStrategyNode : public Object {
 public:
  /*! \brief List of operator specializations. */
  Array<OpSpecialization> specializations;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("specializations", &specializations);
  }

  static constexpr const char* _type_key = "relay.OpStrategy";
  TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode);
};

/*!
 * \brief Operator strategy class.
 */
class OpStrategy : public ObjectRef {
 public:
  /*!
   * \brief Add an implementation.
   * \param fcompute Compute function
   * \param fschedule Schedule function
   * \param name Name of the implementation
   * \param plevel Priority level of the implementation
   */
  TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule,
                                 std::string name, int plevel);

  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode);
};

}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_OP_STRATEGY_H_
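The classes above are thin references over their *Node objects. As a rough illustration (again not from the commit), a strategy for a hypothetical op could be assembled in C++ like this, reusing the compute sketch shown earlier; the function name, implementation name, and plevel value are made up.

    // Sketch only: assemble an OpStrategy for a hypothetical op.
    #include <tvm/relay/op_strategy.h>
    #include <tvm/runtime/memory.h>  // make_object

    tvm::relay::OpStrategy MakeMyOpStrategy(tvm::relay::FTVMCompute fcompute,
                                            tvm::relay::FTVMSchedule fschedule) {
      tvm::relay::OpStrategy strategy(
          tvm::runtime::make_object<tvm::relay::OpStrategyNode>());
      // Implementations with a higher plevel are preferred when several apply.
      strategy.AddImplementation(fcompute, fschedule, "myop.generic", 10);
      return strategy;
    }

In the rest of the patch, strategies are mostly built on the Python side and exposed to the compiler through the FTVMStrategy generic function, so these C++ classes mainly serve as the shared data structure.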

include/tvm/te/schedule.h

Lines changed: 49 additions & 0 deletions
@@ -28,6 +28,7 @@
 #include <tvm/tir/expr.h>
 #include <tvm/te/tensor.h>
 #include <tvm/te/tensor_intrin.h>
+#include <tvm/support/with.h>

 #include <string>
 #include <unordered_map>
@@ -742,6 +743,53 @@ class SingletonNode : public IterVarRelationNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
 };

+/*! \brief Container for specialization conditions. */
+class SpecializedConditionNode : public Object {
+ public:
+  /*!
+   * \brief List of conditions in conjunctive joint form (CNF).
+   * Each condition should be a simple expression, e.g., n > 16, m % 8 == 0, etc.,
+   * where n, m are tvm::Var that represents a dimension in the tensor shape.
+   */
+  Array<PrimExpr> clauses;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("clauses", &clauses);
+  }
+
+  static constexpr const char* _type_key = "SpecializedCondition";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object);
+};
+
+/*!
+ * \brief Specialized condition to enable op specialization
+ */
+class SpecializedCondition : public ObjectRef {
+ public:
+  /*!
+   * \brief construct from conditions
+   * \param conditions The clauses in the specialized condition.
+   */
+  TVM_DLL SpecializedCondition(Array<PrimExpr> conditions);  // NOLINT(*)
+
+  /*!
+   * \brief Get the current specialized condition.
+   * \return the current specialized condition.
+   */
+  TVM_DLL static SpecializedCondition Current();
+
+  TVM_DEFINE_OBJECT_REF_METHODS(SpecializedCondition, ObjectRef, SpecializedConditionNode);
+  class Internal;
+
+ private:
+  // enable with syntax.
+  friend class Internal;
+  friend class With<SpecializedCondition>;
+  /*! \brief Push a new specialized condition onto the thread local stack. */
+  TVM_DLL void EnterWithScope();
+  /*! \brief Pop a specialized condition off the thread local context stack. */
+  TVM_DLL void ExitWithScope();
+};

 // implementations
 inline const StageNode* Stage::operator->() const {
@@ -765,6 +813,7 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
 inline const IterVarAttrNode* IterVarAttr::operator->() const {
   return static_cast<const IterVarAttrNode*>(get());
 }
+
 }  // namespace te
 }  // namespace tvm
 #endif  // TVM_TE_SCHEDULE_H_
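To make the With<SpecializedCondition> plumbing concrete, here is a small sketch (not from the commit) of how a condition might be scoped; the variable m and the clause are assumptions, following the "n > 16" style example in the header comment above.

    // Sketch only: scope a specialized condition using the RAII With<> helper.
    #include <tvm/te/schedule.h>
    #include <tvm/tir/op.h>

    void BuildWithSpecialization(tvm::tir::Var m) {
      // One CNF clause: this specialization only applies when m > 16.
      tvm::te::SpecializedCondition cond({m > 16});
      {
        tvm::With<tvm::te::SpecializedCondition> scope(cond);
        // Within this block, te::SpecializedCondition::Current() returns `cond`,
        // so strategy code built here can attach implementations to it.
      }
      // Leaving the block pops `cond` off the thread-local stack again.
    }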

python/tvm/autotvm/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -41,8 +41,8 @@
 from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
     LocalBuilder, LocalRunner, RPCRunner
 from .tuner import callback
-from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
-    register_topi_compute, register_topi_schedule, \
+from .task import get_config, create, ConfigSpace, ConfigEntity, \
+    register_topi_compute, register_topi_schedule, register_customized_task, \
     DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
     ApplyGraphBest as apply_graph_best
 from .env import GLOBAL_SCOPE

python/tvm/autotvm/database.py

Lines changed: 4 additions & 1 deletion
@@ -125,7 +125,7 @@ def load(self, inp, get_all=False):
         current = self.get(measure_str_key(inp))
         if current is not None:
             records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
-            results = [rec[1] for rec in records]
+            results = [rec[1] for rec in records if rec is not None]
             if get_all:
                 return results
             return max(results, key=lambda result: result.timestamp)
@@ -167,9 +167,12 @@ def filter(self, func):
             current = self.get(key)
             try:
                 records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
+                records = [rec for rec in records if rec is not None]
             except TypeError:  # got a badly formatted/old format record
                 continue

+            if not records:
+                continue
             inps, results = zip(*records)
             inp = inps[0]
             if not func(inp, results):

python/tvm/autotvm/feature.py

Lines changed: 4 additions & 1 deletion
@@ -153,7 +153,10 @@ def get_flatten_name(fea):
        from .record import decode
        # flatten line to feature
        line = fea
-       inp, _ = decode(line)
+       ret = decode(line)
+       if ret is None:
+           raise ValueError("Unsupported AutoTVM log format")
+       inp, _ = ret
        target = _target.create(inp.target)
        with target:
            s, args = inp.template.instantiate(inp.config)
