Skip to content

Commit 07ef0d1

Browse files
yzhliuzhiics
authored andcommitted
[Arith] add SizeVar representing non-neg valued variable in a tensor shape (apache#4684)
* [arith] add ShapeVar representing non-neg valued variable in a tensor shape * bounder remover; deal with div in int_set differently * fix bounder_remover * migrate unittest to use shape_var * use tvm.shape_var in integration & relay tests * add test case; fix Var register * fix lint * fix lint again * add default ShapeVar visitor in Relay * fix override * fix ShapeVar visit bug * revert IntervalSet for shape_var * remove bound_remover * remove is_var; use constructor for shapevar/var instead * ShapeVar -> SizeVar; add constructor comments * shape_var -> size_var in doc * tindex -> size
1 parent b0139c6 commit 07ef0d1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+417
-267
lines changed

docs/api/python/tvm.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ The user facing API for computation declaration.
2424
tvm.load_json
2525
tvm.save_json
2626
tvm.var
27+
tvm.size_var
2728
tvm.const
2829
tvm.convert
2930
tvm.placeholder
@@ -49,6 +50,7 @@ The user facing API for computation declaration.
4950
.. autofunction:: tvm.load_json
5051
.. autofunction:: tvm.save_json
5152
.. autofunction:: tvm.var
53+
.. autofunction:: tvm.size_var
5254
.. autofunction:: tvm.const
5355
.. autofunction:: tvm.convert
5456
.. autofunction:: tvm.placeholder

include/tvm/expr.h

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,27 +65,33 @@ class Var;
6565
*/
6666
class VarNode : public PrimExprNode {
6767
public:
68+
/*! \brief constructor */
69+
VarNode() {}
70+
VarNode(DataType dtype, std::string name_hint);
71+
6872
/*!
6973
* \brief The hint to the variable name.
7074
* \note Each variable is uniquely identified by its address.
7175
*/
7276
std::string name_hint;
7377

74-
static Var make(DataType dtype, std::string name_hint);
75-
7678
void VisitAttrs(AttrVisitor* v) {
7779
v->Visit("dtype", &dtype);
7880
v->Visit("name", &name_hint);
7981
}
8082

8183
static constexpr const char* _type_key = "Variable";
82-
TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, PrimExprNode);
84+
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
8385
};
8486

8587
/*! \brief a named variable in TVM */
8688
class Var : public PrimExpr {
8789
public:
8890
explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
91+
/*! \brief constructor
92+
* \param name_hint variable name
93+
* \param t data type
94+
*/
8995
TVM_DLL explicit Var(std::string name_hint = "v",
9096
DataType t = DataType::Int(32));
9197
/*!
@@ -114,6 +120,53 @@ class Var : public PrimExpr {
114120
using ContainerType = VarNode;
115121
};
116122

123+
class SizeVar;
124+
/*!
125+
* \brief A variable node represent a tensor index size,
126+
* whose value must be non-negative.
127+
*/
128+
class SizeVarNode : public VarNode {
129+
public:
130+
/*! \brief constructor */
131+
SizeVarNode() {}
132+
/*! \brief constructor
133+
* \param dtype data type
134+
* \param name_hint variable name
135+
*/
136+
SizeVarNode(DataType dtype, std::string name_hint);
137+
138+
static constexpr const char* _type_key = "SizeVar";
139+
TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
140+
};
141+
142+
/*! \brief a named variable represents a tensor index size */
143+
class SizeVar : public Var {
144+
public:
145+
explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
146+
/*! \brief constructor
147+
* \param name_hint variable name
148+
* \param t data type
149+
*/
150+
TVM_DLL explicit SizeVar(std::string name_hint = "s",
151+
DataType t = DataType::Int(32));
152+
/*!
153+
* \brief Get pointer to the internal value.
154+
* \return the corresponding Variable.
155+
*/
156+
const SizeVarNode* operator->() const {
157+
return get();
158+
}
159+
/*!
160+
* \brief Get pointer to the internal value.
161+
* \return the corresponding Variable.
162+
*/
163+
const SizeVarNode* get() const {
164+
return static_cast<const SizeVarNode*>(data_.get());
165+
}
166+
/*! \brief type indicate the container type */
167+
using ContainerType = SizeVarNode;
168+
};
169+
117170
/*!
118171
* \brief Container of constant int that adds more constructors.
119172
*

include/tvm/ir.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ namespace ir {
3838
using IntImmNode = tvm::IntImmNode;
3939
using FloatImmNode = tvm::FloatImmNode;
4040
using VarNode = tvm::VarNode;
41+
using SizeVarNode = tvm::SizeVarNode;
4142

4243
/*! \brief String constants, only used in asserts. */
4344
class StringImmNode : public PrimExprNode {
@@ -679,7 +680,7 @@ class AnyNode : public PrimExprNode {
679680
void VisitAttrs(AttrVisitor* v) {}
680681
/*! \brief Convert to var. */
681682
Var ToVar() const {
682-
return VarNode::make(DataType::Int(32), "any_dim");
683+
return Var("any_dim", DataType::Int(32));
683684
}
684685

685686
TVM_DLL static PrimExpr make();

include/tvm/ir_functor_ext.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
133133
}
134134
// Functions that can be overriden by subclass
135135
virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
136+
virtual R VisitExpr_(const SizeVarNode* op, Args... args) {
137+
return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
138+
}
136139
virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
137140
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
138141
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
@@ -174,6 +177,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
174177
FType vtable;
175178
// Set dispatch
176179
IR_EXPR_FUNCTOR_DISPATCH(VarNode);
180+
IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
177181
IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
178182
IR_EXPR_FUNCTOR_DISPATCH(LetNode);
179183
IR_EXPR_FUNCTOR_DISPATCH(CallNode);
@@ -297,6 +301,7 @@ class TVM_DLL ExprVisitor :
297301
using ExprFunctor::VisitExpr;
298302
// list of functions to override.
299303
void VisitExpr_(const VarNode* op) override;
304+
void VisitExpr_(const SizeVarNode* op) override;
300305
void VisitExpr_(const LoadNode* op) override;
301306
void VisitExpr_(const LetNode* op) override;
302307
void VisitExpr_(const CallNode* op) override;
@@ -341,6 +346,7 @@ class TVM_DLL ExprMutator :
341346
using ExprFunctor::VisitExpr;
342347
// list of functions to override.
343348
PrimExpr VisitExpr_(const VarNode* op) override;
349+
PrimExpr VisitExpr_(const SizeVarNode* op) override;
344350
PrimExpr VisitExpr_(const LoadNode* op) override;
345351
PrimExpr VisitExpr_(const LetNode* op) override;
346352
PrimExpr VisitExpr_(const CallNode* op) override;

python/tvm/api.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,25 @@ def var(name="tindex", dtype=int32):
192192
return _api_internal._Var(name, dtype)
193193

194194

195+
def size_var(name="size", dtype=int32):
196+
"""Create a new variable represents a tensor shape size, which is non-negative.
197+
198+
Parameters
199+
----------
200+
name : str
201+
The name
202+
203+
dtype : str
204+
The data type
205+
206+
Returns
207+
-------
208+
var : SizeVar
209+
The result symbolic shape variable.
210+
"""
211+
return _api_internal._SizeVar(name, dtype)
212+
213+
195214
def any(*args):
196215
"""Create a new experssion of the union of all conditions in the arguments
197216

python/tvm/expr.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,25 @@ def __init__(self, name, dtype):
278278
_api_internal._Var, name, dtype)
279279

280280

281+
@register_object
282+
class SizeVar(Var):
283+
"""Symbolic variable to represent a tensor index size
284+
which is greater or equal to zero
285+
286+
Parameters
287+
----------
288+
name : str
289+
The name
290+
291+
dtype : int
292+
The data type
293+
"""
294+
# pylint: disable=super-init-not-called
295+
def __init__(self, name, dtype):
296+
self.__init_handle_by_constructor__(
297+
_api_internal._SizeVar, name, dtype)
298+
299+
281300
@register_object
282301
class Reduce(PrimExpr):
283302
"""Reduce node.

python/tvm/hybrid/preprocessor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def visit_Call(self, node):
6363
_internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \
6464
['range', 'max', 'min', 'len'] + \
6565
list(self.symbols.keys()), \
66-
"Function call id not in intrinsics' list")
66+
"Function call id " + func_id + " not in intrinsics' list")
6767
for elem in node.args:
6868
self.visit(elem)
6969

src/api/api_ir.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ namespace ir {
3333

3434
TVM_REGISTER_GLOBAL("_Var")
3535
.set_body_typed([](std::string s, DataType t) {
36-
return VarNode::make(t, s);
36+
return Var(s, t);
37+
});
38+
39+
TVM_REGISTER_GLOBAL("_SizeVar")
40+
.set_body_typed([](std::string s, DataType t) {
41+
return SizeVar(s, t);
3742
});
3843

3944
TVM_REGISTER_GLOBAL("make.abs")

src/arithmetic/bound_deducer.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class BoundDeducer: public ExprVisitor {
8686

8787
void VisitExpr(const PrimExpr& e) final {
8888
if (!success_) return;
89-
if (e.get() == path_[iter_++]) {
89+
if (iter_ < path_.size() && e.get() == path_[iter_++]) {
9090
ExprVisitor::VisitExpr(e);
9191
} else {
9292
success_ = false;
@@ -297,6 +297,7 @@ void BoundDeducer::Transform() {
297297
void BoundDeducer::Deduce() {
298298
Init();
299299
if (!success_) return;
300+
300301
Relax();
301302
if (!success_) return;
302303
// get the path

src/arithmetic/const_int_bound.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,16 @@ class ConstIntBoundAnalyzer::Impl :
284284
}
285285
}
286286

287+
Entry VisitExpr_(const SizeVarNode* op) final {
288+
SizeVar v = GetRef<SizeVar>(op);
289+
auto it = var_map_.find(v);
290+
if (it != var_map_.end()) {
291+
return it->second;
292+
} else {
293+
return MakeBound(0, kPosInf);
294+
}
295+
}
296+
287297
Entry VisitRightShift(const CallNode* op) {
288298
Entry a = VisitExpr(op->args[0]);
289299
Entry b = VisitExpr(op->args[1]);

0 commit comments

Comments
 (0)