Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/python/tvm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The user facing API for computation declaration.
tvm.load_json
tvm.save_json
tvm.var
tvm.size_var
tvm.const
tvm.convert
tvm.placeholder
Expand All @@ -49,6 +50,7 @@ The user facing API for computation declaration.
.. autofunction:: tvm.load_json
.. autofunction:: tvm.save_json
.. autofunction:: tvm.var
.. autofunction:: tvm.size_var
.. autofunction:: tvm.const
.. autofunction:: tvm.convert
.. autofunction:: tvm.placeholder
Expand Down
59 changes: 56 additions & 3 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,33 @@ class Var;
*/
class VarNode : public PrimExprNode {
public:
/*! \brief constructor */
VarNode() {}
VarNode(DataType dtype, std::string name_hint);

/*!
* \brief The hint to the variable name.
* \note Each variable is uniquely identified by its address.
*/
std::string name_hint;

static Var make(DataType dtype, std::string name_hint);

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("name", &name_hint);
}

static constexpr const char* _type_key = "Variable";
TVM_DECLARE_FINAL_OBJECT_INFO(VarNode, PrimExprNode);
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
};

/*! \brief a named variable in TVM */
class Var : public PrimExpr {
public:
explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
/*! \brief constructor
* \param name_hint variable name
* \param t data type
*/
TVM_DLL explicit Var(std::string name_hint = "v",
DataType t = DataType::Int(32));
/*!
Expand Down Expand Up @@ -114,6 +120,53 @@ class Var : public PrimExpr {
using ContainerType = VarNode;
};

class SizeVar;
/*!
* \brief A variable node represent a tensor index size,
* whose value must be non-negative.
*/
class SizeVarNode : public VarNode {
public:
/*! \brief constructor */
SizeVarNode() {}
/*! \brief constructor
* \param dtype data type
* \param name_hint variable name
*/
SizeVarNode(DataType dtype, std::string name_hint);

static constexpr const char* _type_key = "SizeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
};

/*! \brief a named variable represents a tensor index size */
class SizeVar : public Var {
public:
explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
/*! \brief constructor
* \param name_hint variable name
* \param t data type
*/
TVM_DLL explicit SizeVar(std::string name_hint = "s",
DataType t = DataType::Int(32));
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const SizeVarNode* operator->() const {
return get();
}
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const SizeVarNode* get() const {
return static_cast<const SizeVarNode*>(data_.get());
}
/*! \brief type indicate the container type */
using ContainerType = SizeVarNode;
};

/*!
* \brief Container of constant int that adds more constructors.
*
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace ir {
using IntImmNode = tvm::IntImmNode;
using FloatImmNode = tvm::FloatImmNode;
using VarNode = tvm::VarNode;
using SizeVarNode = tvm::SizeVarNode;

/*! \brief String constants, only used in asserts. */
class StringImmNode : public PrimExprNode {
Expand Down Expand Up @@ -679,7 +680,7 @@ class AnyNode : public PrimExprNode {
void VisitAttrs(AttrVisitor* v) {}
/*! \brief Convert to var. */
Var ToVar() const {
return VarNode::make(DataType::Int(32), "any_dim");
return Var("any_dim", DataType::Int(32));
}

TVM_DLL static PrimExpr make();
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/ir_functor_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
}
// Functions that can be overriden by subclass
virtual R VisitExpr_(const VarNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const SizeVarNode* op, Args... args) {
return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
}
virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -174,6 +177,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
FType vtable;
// Set dispatch
IR_EXPR_FUNCTOR_DISPATCH(VarNode);
IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
IR_EXPR_FUNCTOR_DISPATCH(LetNode);
IR_EXPR_FUNCTOR_DISPATCH(CallNode);
Expand Down Expand Up @@ -297,6 +301,7 @@ class TVM_DLL ExprVisitor :
using ExprFunctor::VisitExpr;
// list of functions to override.
void VisitExpr_(const VarNode* op) override;
void VisitExpr_(const SizeVarNode* op) override;
void VisitExpr_(const LoadNode* op) override;
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const CallNode* op) override;
Expand Down Expand Up @@ -341,6 +346,7 @@ class TVM_DLL ExprMutator :
using ExprFunctor::VisitExpr;
// list of functions to override.
PrimExpr VisitExpr_(const VarNode* op) override;
PrimExpr VisitExpr_(const SizeVarNode* op) override;
PrimExpr VisitExpr_(const LoadNode* op) override;
PrimExpr VisitExpr_(const LetNode* op) override;
PrimExpr VisitExpr_(const CallNode* op) override;
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,25 @@ def var(name="tindex", dtype=int32):
return _api_internal._Var(name, dtype)


def size_var(name="size", dtype=int32):
"""Create a new variable represents a tensor shape size, which is non-negative.
Parameters
----------
name : str
The name
dtype : str
The data type
Returns
-------
var : SizeVar
The result symbolic shape variable.
"""
return _api_internal._SizeVar(name, dtype)


def any(*args):
"""Create a new experssion of the union of all conditions in the arguments
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,25 @@ def __init__(self, name, dtype):
_api_internal._Var, name, dtype)


@register_object
class SizeVar(Var):
"""Symbolic variable to represent a tensor index size
which is greater or equal to zero
Parameters
----------
name : str
The name
dtype : int
The data type
"""
# pylint: disable=super-init-not-called
def __init__(self, name, dtype):
self.__init_handle_by_constructor__(
_api_internal._SizeVar, name, dtype)


@register_object
class Reduce(PrimExpr):
"""Reduce node.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/hybrid/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def visit_Call(self, node):
_internal_assert(func_id in list(HYBRID_GLOBALS.keys()) + \
['range', 'max', 'min', 'len'] + \
list(self.symbols.keys()), \
"Function call id not in intrinsics' list")
"Function call id " + func_id + " not in intrinsics' list")
for elem in node.args:
self.visit(elem)

Expand Down
7 changes: 6 additions & 1 deletion src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ namespace ir {

TVM_REGISTER_GLOBAL("_Var")
.set_body_typed([](std::string s, DataType t) {
return VarNode::make(t, s);
return Var(s, t);
});

TVM_REGISTER_GLOBAL("_SizeVar")
.set_body_typed([](std::string s, DataType t) {
return SizeVar(s, t);
});

TVM_REGISTER_GLOBAL("make.abs")
Expand Down
3 changes: 2 additions & 1 deletion src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class BoundDeducer: public ExprVisitor {

void VisitExpr(const PrimExpr& e) final {
if (!success_) return;
if (e.get() == path_[iter_++]) {
if (iter_ < path_.size() && e.get() == path_[iter_++]) {
ExprVisitor::VisitExpr(e);
} else {
success_ = false;
Expand Down Expand Up @@ -297,6 +297,7 @@ void BoundDeducer::Transform() {
void BoundDeducer::Deduce() {
Init();
if (!success_) return;

Relax();
if (!success_) return;
// get the path
Expand Down
10 changes: 10 additions & 0 deletions src/arithmetic/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,16 @@ class ConstIntBoundAnalyzer::Impl :
}
}

Entry VisitExpr_(const SizeVarNode* op) final {
SizeVar v = GetRef<SizeVar>(op);
auto it = var_map_.find(v);
if (it != var_map_.end()) {
return it->second;
} else {
return MakeBound(0, kPosInf);
}
}

Entry VisitRightShift(const CallNode* op) {
Entry a = VisitExpr(op->args[0]);
Entry b = VisitExpr(op->args[1]);
Expand Down
1 change: 1 addition & 0 deletions src/arithmetic/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ class IntervalSetEvaluator :
}
}


IntervalSet VisitExpr_(const AddNode* op) final {
return VisitBinaryExpr_(op);
}
Expand Down
4 changes: 4 additions & 0 deletions src/ir/attr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
virtual R VisitAttr_(const ir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
// deep comparison of symbolic integer expressions.
virtual R VisitAttr_(const VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const SizeVarNode* op, Args... args) {
return VisitAttr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
}
virtual R VisitAttr_(const ir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
Expand Down Expand Up @@ -115,6 +118,7 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH(FloatImmNode);
ATTR_FUNCTOR_DISPATCH(StringImmNode);
ATTR_FUNCTOR_DISPATCH(VarNode);
ATTR_FUNCTOR_DISPATCH(SizeVarNode);
ATTR_FUNCTOR_DISPATCH(AddNode);
ATTR_FUNCTOR_DISPATCH(SubNode);
ATTR_FUNCTOR_DISPATCH(MulNode);
Expand Down
2 changes: 1 addition & 1 deletion src/ir/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& ot
bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) {
if (const auto* rhs = other.as<ArrayNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) {
for (size_t i = 0; i < lhs->data.size(); ++i) {
if (!Equal(lhs->data[i], rhs->data[i])) return false;
}
}
Expand Down
16 changes: 10 additions & 6 deletions src/lang/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,19 @@ PrimExpr::PrimExpr(std::string str)
: PrimExpr(ir::StringImmNode::make(str)) {}

Var::Var(std::string name_hint, DataType t)
: Var(VarNode::make(t, name_hint)) {}
: Var(make_object<VarNode>(t, name_hint)) {}

Var VarNode::make(DataType t, std::string name_hint) {
ObjectPtr<VarNode> node = make_object<VarNode>();
node->dtype = t;
node->name_hint = std::move(name_hint);
return Var(node);
VarNode::VarNode(DataType t, std::string name_hint) {
this->dtype = t;
this->name_hint = std::move(name_hint);
}

SizeVar::SizeVar(std::string name_hint, DataType t)
: SizeVar(make_object<SizeVarNode>(t, name_hint)) {}

SizeVarNode::SizeVarNode(DataType t, std::string name_hint)
: VarNode(t, std::move(name_hint)) {}

Range::Range(PrimExpr begin, PrimExpr end)
: Range(make_object<RangeNode>(
begin,
Expand Down
5 changes: 5 additions & 0 deletions src/lang/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,10 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
// stream << op->name << "." << op->type;
p->stream << op->name_hint;
})
.set_dispatch<SizeVarNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const SizeVarNode*>(node.get());
p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}";
})
.set_dispatch<AddNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const AddNode*>(node.get());
p->stream << '(';
Expand Down Expand Up @@ -1143,6 +1147,7 @@ TVM_REGISTER_NODE_TYPE(IntImmNode);
TVM_REGISTER_NODE_TYPE(StringImmNode);
TVM_REGISTER_NODE_TYPE(CastNode);
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_NODE_TYPE(SizeVarNode);
TVM_REGISTER_NODE_TYPE(AddNode);
TVM_REGISTER_NODE_TYPE(SubNode);
TVM_REGISTER_NODE_TYPE(MulNode);
Expand Down
8 changes: 8 additions & 0 deletions src/pass/ir_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ void StmtVisitor::VisitStmt_(const EvaluateNode* op) {

void ExprVisitor::VisitExpr_(const VarNode* op) {}

void ExprVisitor::VisitExpr_(const SizeVarNode* op) {
this->VisitExpr_(static_cast<const VarNode*>(op));
}

void ExprVisitor::VisitExpr_(const LoadNode* op) {
this->VisitExpr(op->index);
this->VisitExpr(op->predicate);
Expand Down Expand Up @@ -596,6 +600,10 @@ PrimExpr ExprMutator::VisitExpr_(const VarNode* op) {
return GetRef<PrimExpr>(op);
}

PrimExpr ExprMutator::VisitExpr_(const SizeVarNode* op) {
return this->VisitExpr_(static_cast<const VarNode*>(op));
}

PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {
PrimExpr index = this->VisitExpr(op->index);
PrimExpr predicate = this->VisitExpr(op->predicate);
Expand Down
Loading