Skip to content

Commit

Permalink
check stmt in
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Oct 23, 2016
1 parent dac6b52 commit 151707e
Show file tree
Hide file tree
Showing 8 changed files with 224 additions and 45 deletions.
26 changes: 24 additions & 2 deletions include/tvm/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,28 @@ class NodeRef;
class UnaryOp;
class BinaryOp;

/*! \brief pointer type mask */
const int kPtrTypeMask = 16;

/*! \brief list of all supported data types */
enum DataType : int {
kUnknown = 0,
kInt32 = 1,
kFloat32 = 2
kFloat32 = 2,
kInt32Buffer = kInt32 | kPtrTypeMask,
kFloat32Buffer = kFloat32 | kPtrTypeMask
};

/*!
* \brief convert pointer type to data type
* \param ptr_type The pointer type.
* \return The corresponding data type.
*/
inline DataType Ptr2DataType(DataType ptr_type) {
CHECK_GE(ptr_type, kPtrTypeMask);
return static_cast<DataType>(ptr_type & (kPtrTypeMask -1));
}

/*!
* \brief List of subset node types used for quick runtime switch.
*
Expand All @@ -45,6 +60,7 @@ enum NodeType {
kBinaryOpNode,
kReduceNode,
kTensorReadNode,
kBufferReadNode,
// stmt nodes
kStoreNode,
kForRangeNode,
Expand Down Expand Up @@ -157,6 +173,8 @@ class NodeRef {
inline bool operator!=(const NodeRef& other) const;
/*! \return the hash function for NodeRef */
inline size_t hash() const;
/*! \return the raw internal pointer of the node */
inline Node* node_ptr() const;

protected:
template<typename T, typename>
Expand Down Expand Up @@ -217,7 +235,11 @@ inline bool NodeRef::operator!=(const NodeRef& other) const {
}

inline size_t NodeRef::hash() const {
return std::hash<Node*>()(node_.get());
return std::hash<Node*>()(node_ptr());
}

inline Node* NodeRef::node_ptr() const {
return node_.get();
}

} // namespace tvm
Expand Down
60 changes: 31 additions & 29 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
#include "./base.h"

namespace tvm {
// forward declare Expr
// Forward declare Expr
class Expr;
class Var;

/*!
* \brief create a constant expression
Expand All @@ -23,35 +24,34 @@ template<typename T,
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type >
inline Expr constant(T value);

/*!
* \brief create a integer expression
* \param value The value to the expression
* \return the expression.
*/
Expr IntConstant(int64_t value);

/*!
* \brief create a float expression.
* \param value The value to the expression
* \return the expression.
*/
Expr FloatConstant(double value);

/*!
* \brief create a float expression.
* \param value The value to the expression
* \return the expression.
*/
Expr BufferRead(Var buffer, Expr offset);

/*!
* \brief a expression type, holds a ref to root of an AST
*/
class Expr : public NodeRef {
public:
/*! \brief default constructor */
Expr() = default;
/*!
* \brief copy constructor
* \param other the input
*/
Expr(const Expr& other) = default;
/*!
* \brief move constructor
* \param other the input
*/
Expr(Expr&& other) = default;
/*!
* \brief assign operator.
* \param other the input.
* \return reference to self
*/
Expr& operator=(const Expr& other) = default;
/*!
* \brief assign move operator.
* \param other the input.
* \return reference to self
*/
Expr& operator=(Expr&& other) = default;
Expr() {}
/*!
* \brief constructor from constant value
* \param value the constant value
Expand Down Expand Up @@ -82,23 +82,25 @@ class Expr : public NodeRef {
void Print(std::ostream& os) const; // NOLINT(*)
};

/*! \brief Variable class */
/*!
* \brief Variable class to represent the symbolic placeholder
* in the DSL, internally it is a VarNode.
*
* The Variable is uniquely identified by the address of VarNode.
*/
class Var : public Expr {
public:
Var(std::string name="", DataType dtype=kInt32); // NOLINT(*)
};

Expr IntConstant(int64_t value);
Expr FloatConstant(double value);

/*! \brief base of expression node */
class ExprNode : public Node {
public:
/*! \brief type of data stored in expression */
DataType dtype_{kUnknown};
};

// inline implementations
// implementations
inline DataType Expr::dtype() const {
return static_cast<const ExprNode*>(node_.get())->dtype_;
}
Expand Down
43 changes: 31 additions & 12 deletions include/tvm/expr_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@
#include "./expr.h"

namespace tvm {

/*! \brief variable node for symbolic variables */
class VarNode : public ExprNode {
public:
struct VarNode : public ExprNode {
/*! \brief hint name of the variable */
std::string name;
/*! \brief constructor */
Expand All @@ -32,7 +30,7 @@ class VarNode : public ExprNode {
};

/*! \brief integer constant node */
class IntNode : public ExprNode {
struct IntNode : public ExprNode {
public:
/*! \brief the value field */
int64_t value;
Expand All @@ -51,8 +49,7 @@ class IntNode : public ExprNode {
};

/*! \brief float constant node */
class FloatNode : public ExprNode {
public:
struct FloatNode : public ExprNode {
/*! \brief the value field */
double value;
/*! \brief constructor */
Expand All @@ -61,7 +58,7 @@ class FloatNode : public ExprNode {
dtype_ = kFloat32;
}
const char* type_key() const override {
return "IntNode";
return "FloatNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("value", &value);
Expand All @@ -70,8 +67,7 @@ class FloatNode : public ExprNode {
};

/*! \brief Unary mapping operator */
class UnaryOpNode : public ExprNode {
public:
struct UnaryOpNode : public ExprNode {
/*! \brief The operator */
const UnaryOp* op;
/*! \brief The source expression */
Expand Down Expand Up @@ -105,7 +101,6 @@ class UnaryOpNode : public ExprNode {

/*! \brief Binary mapping operator */
struct BinaryOpNode : public ExprNode {
public:
/*! \brief The operator */
const BinaryOp* op;
/*! \brief The left operand */
Expand Down Expand Up @@ -143,7 +138,6 @@ struct BinaryOpNode : public ExprNode {

/*! \brief Reduction operator operator */
struct ReduceNode : public ExprNode {
public:
/*! \brief The operator */
const BinaryOp* op;
/*! \brief The source operand */
Expand Down Expand Up @@ -180,7 +174,6 @@ struct ReduceNode : public ExprNode {

/*! \brief Tensor read operator */
struct TensorReadNode : public ExprNode {
public:
/*! \brief The tensor to be read from */
Tensor tensor;
/*! \brief The indices of read */
Expand Down Expand Up @@ -215,6 +208,32 @@ struct TensorReadNode : public ExprNode {
}
};

/*! \brief Buffer read node */
struct BufferReadNode : public ExprNode {
/*! \brief The buffer variable to be read from */
Var buffer;
/*! \brief The offset to be read from */
Expr offset;
/*! \brief constructor, do not use constructor */
BufferReadNode() {
node_type_ = kBufferReadNode;
}
const char* type_key() const override {
return "BufferReadNode";
}
void Verify() const override {
CHECK_EQ(dtype_, Ptr2DataType(buffer.dtype()));
CHECK_EQ(offset.dtype(), kInt32);
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("dtype", &dtype_);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("buffer", &buffer);
fvisit("offset", &offset);
}
};

} // namespace tvm

#endif // TVM_EXPR_NODE_H_
57 changes: 57 additions & 0 deletions include/tvm/stmt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*!
* Copyright (c) 2016 by Contributors
* \file stmt.h
* \brief The statement creation functions.
* The underlying container are defined in stmt_node.h
*/
#ifndef TVM_STMT_H_
#define TVM_STMT_H_

#include <type_traits>
#include "./base.h"
#include "./domain.h"

namespace tvm {

/*!
* \brief a expression type, holds a ref to root of an AST
*/
class Stmt : public NodeRef {
public:
/*! \brief default constructor */
Stmt() {}
/*!
* \brief constructor from node pointer
* \param nptr Another node shared pointer
*/
explicit Stmt(std::shared_ptr<Node>&& nptr) : NodeRef(std::move(nptr)) {
CHECK(node_.get() != nullptr);
}
};

/*!
* \brief construct Store Stmt.
* \param buffer The variable representing the buffer.
* \param offset The offset in the buffer
* \param src The source expression.
*/
Stmt Store(Var buffer, Expr offset, Expr src);

/*!
* \brief construct ForRange Stmt
* \param loop_var The loop variable
* \param range The loop range
* \param body The loop body
*/
Stmt ForRange(Var loop_var, Range range, Stmt body);

/*!
* \brief construct a IfThenElse
* \param cond The condition.
* \param then_body The body to go to in then condition.
* \param else_body The body to go to in else condition.
*/
Stmt IfThenElse(Expr cond, Stmt then_body, Stmt else_body);

} // namespace tvm
#endif // TVM_STMT_H_
Loading

0 comments on commit 151707e

Please sign in to comment.