diff --git a/include/tvm/base.h b/include/tvm/base.h index a024e70865ff..17049dcb4f4d 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -37,6 +37,7 @@ enum DataType : int { * \note kOtherNodes could mean more than one node type. */ enum NodeType { + // expr nodes kVarNode, kIntNode, kFloatNode, @@ -44,6 +45,10 @@ enum NodeType { kBinaryOpNode, kReduceNode, kTensorReadNode, + // stmt nodes + kStoreNode, + kForRangeNode, + kIfThenElseNode, kOtherNodes }; diff --git a/include/tvm/codegen.h b/include/tvm/codegen.h deleted file mode 100644 index 3baa284e935a..000000000000 --- a/include/tvm/codegen.h +++ /dev/null @@ -1,39 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file codegen.h - * \brief Common data structure for codegen - */ -#ifndef TVM_CODEGEN_H_ -#define TVM_CODEGEN_H_ - -namespace tvm { - -// incomplete spec. -struct Assign : public Node { - Expr src; - Expr offset; - Var ptr; -}; - -struct Assign : public Node { - Expr src; - Expr offset; - Var ptr; -}; - -struct Loop : public Node { - Expr init; - Expr cond; - Stmt body; -}; - -struct IfThenElse : public Node { - Expr cond; - Expr then_; - Stmt else_; -}; - - -} // namespace tvm - -#endif // TVM_CODEGEN_H_ diff --git a/include/tvm/stmt_node.h b/include/tvm/stmt_node.h new file mode 100644 index 000000000000..7e6e1824d965 --- /dev/null +++ b/include/tvm/stmt_node.h @@ -0,0 +1,73 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file stmt.h + * \brief Common data structure for codegen + */ +#ifndef TVM_STMT_NODE_H_ +#define TVM_STMT_NODE_H_ + +namespace tvm { + +struct StmtNode : public Node { +}; + +/*! \brief Store data into buffer */ +struct StoreNode : public StmtNode { + /*! \brief the variable representing the buffer */ + Var buffer; + /*! \brief the buffer offset */ + Expr offset; + /*! \brief The source expression*/ + Expr src; + /*! \brief constructor */ + StoreNode() { + node_type_ = kStoreNode; + } + void VisitNodeRefFields(FNodeRefVisit fvisit) override { + fvisit("buffer", &buffer); + fvisit("offset", &offset); + fvisit("src", &src); + } +}; + +/*! \brief for loop in range */ +struct ForRangeNode : public StmtNode { + /*! \brief loop variable */ + Var loop_var; + /*! \brief The loop range */ + Range range; + /*! \brief body of the loop */ + Stmt body; + /*! \brief constructor */ + ForRangeNode() { + node_type_ = kForRangeNode; + } + void VisitNodeRefFields(FNodeRefVisit fvisit) override { + fvisit("loop_var", &loop_var); + fvisit("range", &range); + fvisit("body", &body); + } +}; + +/*! \brief conditional expression */ +struct IfThenElseNode : public StmtNode { + /*! \brief The condition */ + Expr cond; + /*! \brief The statement in then */ + Stmt then_body; + /*! \brief The statement in else */ + Stmt else_body; + /*! \brief constructor */ + IfThenElseNode() { + node_type_ = kIfThenElseNode; + } + void VisitNodeRefFields(FNodeRefVisit fvisit) override { + fvisit("cond", &cond); + fvisit("then_body", &then_body); + fvisit("else_body", &else_body); + } +}; + +} // namespace tvm + +#endif // TVM_CODEGEN_H_