forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
641 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -88,4 +88,6 @@ ENV/ | |
# Rope project settings | ||
.ropeproject | ||
*~ | ||
*.pyc | ||
*.pyc | ||
*~ | ||
build |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "dmlc-core"] | ||
path = dmlc-core | ||
url = https://github.com/dmlc/dmlc-core |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
export LDFLAGS = -pthread -lm | ||
export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\ | ||
-Iinclude -Idmlc-core/include -fPIC | ||
|
||
# specify tensor path | ||
.PHONY: clean all | ||
|
||
all: lib/libtvm.a | ||
SRC = $(wildcard src/*.cc src/*/*.cc) | ||
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) | ||
ALL_DEP = $(ALL_OBJ) | ||
|
||
build/%.o: src/%.cc | ||
@mkdir -p $(@D) | ||
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d | ||
$(CXX) -c $(CFLAGS) -c $< -o $@ | ||
|
||
|
||
lib/libtvm.a: $(ALL_DEP) | ||
@mkdir -p $(@D) | ||
ar crv $@ $(filter %.o, $?) | ||
|
||
lint: | ||
python2 dmlc-core/scripts/lint.py tvm cpp include src | ||
|
||
clean: | ||
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o | ||
|
||
-include build/*.d | ||
-include build/*/*.d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
/*! | ||
* Copyright (c) 2016 by Contributors | ||
* \file base.h | ||
* \brief Defines the base data structure | ||
*/ | ||
#ifndef TVM_BASE_H_ | ||
#define TVM_BASE_H_ | ||
|
||
#include <dmlc/logging.h> | ||
#include <dmlc/registry.h> | ||
#include <string> | ||
#include <memory> | ||
#include <functional> | ||
#include <typeinfo> | ||
|
||
|
||
namespace tvm { | ||
|
||
// forward declaration | ||
class Node; | ||
class NodeRef; | ||
class UnaryOp; | ||
class BinaryOp; | ||
|
||
/*! \brief list of all supported data types */ | ||
enum DataType { | ||
kUnknown, | ||
kInt32, | ||
kFloat32 | ||
}; | ||
|
||
/*! | ||
* \brief List of subset node types used for quick runtime switch. | ||
* | ||
* \note The value of NodeType is not used for serialization type_key is used instead. | ||
* \note is_type and type_key can be used to do type checking for all types | ||
* \note kOtherNodes could mean more than one node type. | ||
*/ | ||
enum NodeType { | ||
kVarNode, | ||
kIntNode, | ||
kFloatNode, | ||
kUnaryOpNode, | ||
kBinaryOpNode, | ||
kReduceNode, | ||
kTensorReadNode, | ||
kOtherNodes | ||
}; | ||
|
||
/*! | ||
* \brief Visitor class to each node content. | ||
* The content is going to be called for each field. | ||
*/ | ||
class AttrVisitor { | ||
public: | ||
//! \cond Doxygen_Suppress | ||
virtual void Visit(const char* key, double* value) = 0; | ||
virtual void Visit(const char* key, int64_t* value) = 0; | ||
virtual void Visit(const char* key, DataType* value) = 0; | ||
virtual void Visit(const char* key, std::string* value) = 0; | ||
virtual void Visit(const char* key, const UnaryOp** value) = 0; | ||
virtual void Visit(const char* key, const BinaryOp** value) = 0; | ||
//! \endcond | ||
}; | ||
|
||
/*! | ||
* \brief A function to be applied when visit each NodeRef Field. | ||
* \param ref The child to be visited. | ||
*/ | ||
using FNodeRefVisit = std::function<void(const char* key, NodeRef* ref)>; | ||
|
||
/*! | ||
* \brief base class of node container in DSL AST. | ||
* All object's internal is stored as std::shared_ptr<Node> | ||
*/ | ||
class Node { | ||
public: | ||
/*! \brief virtual destructor */ | ||
virtual ~Node(); | ||
/*! \return The unique type key of the node */ | ||
virtual const char* type_key() const = 0; | ||
/*! \brief verify the correctness of node struct after it get mutated by visitor */ | ||
virtual void Verify() const {} | ||
/*! | ||
* \brief Apply visitor to each field of the Node | ||
* Visitor could mutate the content of the node. | ||
* override if Node contains attribute fields. | ||
* \param visitor The visitor | ||
*/ | ||
virtual void VisitAttrs(AttrVisitor* visitor) {} | ||
/*! | ||
* \brief Apply visitor to each field of the Node | ||
* Visitor could mutate the content of the node. | ||
* override if Node contains NodeRefFields. | ||
* \param visitor The visitor | ||
*/ | ||
virtual void VisitNodeRefFields(FNodeRefVisit visitor) {} | ||
/*! | ||
* \tparam NodeType the type to be checked. | ||
* \return whether the stored type is node type | ||
*/ | ||
template<typename TNode> | ||
inline bool is_type() const; | ||
/*! \return the node type */ | ||
inline NodeType node_type() const; | ||
|
||
protected: | ||
// node ref can see this | ||
friend class NodeRef; | ||
/*! \brief the node type enum */ | ||
NodeType node_type_{kOtherNodes}; | ||
}; | ||
|
||
/*! \brief base class of all node reference object */ | ||
class NodeRef { | ||
public: | ||
/*! | ||
* \return typed pointer of the node | ||
* \tparam TNode the type of the node. | ||
*/ | ||
template<typename TNode> | ||
inline const TNode* Get() const; | ||
/*! \return wheyjer the expression is null */ | ||
inline bool is_null() const; | ||
|
||
protected: | ||
NodeRef() = default; | ||
explicit NodeRef(std::shared_ptr<Node> node) : node_(node) {} | ||
/*! \brief the internal node */ | ||
std::shared_ptr<Node> node_; | ||
}; | ||
|
||
/*! \brief typedef the factory function of data iterator */ | ||
using NodeFactory = std::function<std::shared_ptr<Node> ()>; | ||
|
||
/*! | ||
* \brief Registry entry for DataIterator factory functions. | ||
*/ | ||
struct NodeFactoryReg | ||
: public dmlc::FunctionRegEntryBase<NodeFactoryReg, | ||
NodeFactory> { | ||
}; | ||
|
||
#define TVM_REGISTER_NODE_TYPE(TypeName) \ | ||
DMLC_REGISTRY_REGISTER(::tvm::NodeFactoryReg, NodeFactoryReg, TypeName) \ | ||
.set_body([]() { return std::make_shared<TypeName>(); }) | ||
|
||
// implementations of inline functions after this | ||
inline NodeType Node::node_type() const { | ||
return node_type_; | ||
} | ||
|
||
template<typename TNode> | ||
inline bool Node::is_type() const { | ||
const std::type_info& tinfo = typeid(*this); | ||
if (&typeid(TNode) == &tinfo) return true; | ||
return typeid(TNode) == tinfo; | ||
} | ||
|
||
template<typename TNode> | ||
inline const TNode* NodeRef::Get() const { | ||
CHECK(node_->is_type<TNode>()) | ||
<< " type inconsistent, expected " << typeid(TNode).name() | ||
<< " given " << typeid(*this).name(); | ||
return static_cast<const TNode*>(node_.get()); | ||
} | ||
|
||
inline bool NodeRef::is_null() const { | ||
return node_.get() == nullptr; | ||
} | ||
|
||
} // namespace tvm | ||
#endif // TVM_BASE_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
/*! | ||
* Copyright (c) 2016 by Contributors | ||
* \file c_api.h | ||
* \brief C API of TVM DSL | ||
*/ | ||
#ifndef TVM_C_API_H_ | ||
#define TVM_C_API_H_ | ||
|
||
#ifdef __cplusplus | ||
#define TVM_EXTERN_C extern "C" | ||
#else | ||
#define TVM_EXTERN_C | ||
#endif | ||
|
||
/*! \brief TVM_DLL prefix for windows */ | ||
#ifdef _WIN32 | ||
#ifdef TVM_EXPORTS | ||
#define TVM_DLL TVM_EXTERN_C __declspec(dllexport) | ||
#else | ||
#define TVM_DLL TVM_EXTERN_C __declspec(dllimport) | ||
#endif | ||
#else | ||
#define TVM_DLL TVM_EXTERN_C | ||
#endif | ||
|
||
/*! \brief handle to node creator */ | ||
typedef void* NodeCreatorHandle; | ||
/*! \brief handle to node */ | ||
typedef void* NodeHandle; | ||
|
||
TVM_DLL int TVMNodeCreatorGet(const char* node_type, | ||
NodeCreatorHandle *handle); | ||
|
||
TVM_DLL int TVMNodeCreate(NodeCreatorHandle handle, | ||
int num_child_ref, | ||
const char* child_ref_keys, | ||
NodeHandle* child_node_refs, | ||
int num_attrs, | ||
const char* attr_keys, | ||
const char* attr_vals, | ||
NodeHandle* handle); | ||
|
||
TVM_DLL int TVMNodeGetAttr(const char* key, | ||
const char** value); | ||
|
||
TVM_DLL int TVMNodeGetChildNodeRef(const char* key, | ||
NodeHandle* out); | ||
|
||
#endif // TVM_C_API_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
/*! | ||
* Copyright (c) 2016 by Contributors | ||
* \file domain.h | ||
* \brief Defines the AST | ||
*/ | ||
#ifndef TVM_DOMAIN_H_ | ||
#define TVM_DOMAIN_H_ | ||
|
||
#include <memory> | ||
|
||
namespace tvm { | ||
class RDom { | ||
}; | ||
|
||
} // namespace tvm | ||
|
||
#endif // TVM_DOMAIN_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/*! | ||
* Copyright (c) 2016 by Contributors | ||
* \file expr.h | ||
* \brief Defines the expressions in AST. | ||
*/ | ||
#ifndef TVM_EXPR_H_ | ||
#define TVM_EXPR_H_ | ||
|
||
#include <type_traits> | ||
#include "./base.h" | ||
|
||
namespace tvm { | ||
// forward declare Expr | ||
class Expr; | ||
|
||
/*! | ||
* \brief create a constant expression | ||
* \tparam T the value type | ||
* \param value The value to the constant. | ||
* \return The created expression | ||
*/ | ||
template<typename T, | ||
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type > | ||
inline Expr constant(T value); | ||
|
||
/*! | ||
* \brief a expression type, holds a ref to root of an AST | ||
*/ | ||
class Expr : public NodeRef { | ||
public: | ||
/*! \brief default constructor */ | ||
Expr() = default; | ||
/*! | ||
* \brief copy constructor | ||
* \param other the input | ||
*/ | ||
Expr(const Expr& other) = default; // NOLINT(*) | ||
/*! | ||
* \brief move constructor | ||
* \param other the input | ||
*/ | ||
Expr(Expr&& other) = default; // NOLINT(*) | ||
/*! | ||
* \brief assign operator. | ||
* \param other the input. | ||
* \return reference to self | ||
*/ | ||
Expr& operator=(const Expr& other) = default; | ||
/*! | ||
* \brief assign move operator. | ||
* \param other the input. | ||
* \return reference to self | ||
*/ | ||
Expr& operator=(Expr&& other) = default; | ||
/*! | ||
* \brief constructor from constant value | ||
* \param value the constant value | ||
* \tparam T The constant type | ||
*/ | ||
template<typename T, | ||
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type > | ||
Expr(T value) { // NOLINT(*) | ||
*this = std::move(constant<T>(value)); | ||
} | ||
/*! | ||
* \brief constructor from node pointer | ||
* \param nptr Another node shared pointer | ||
*/ | ||
explicit Expr(std::shared_ptr<Node> nptr) : NodeRef(nptr) {} | ||
/*! \return the expression type of the expression */ | ||
inline DataType dtype() const; | ||
}; | ||
|
||
/*! \brief Variable class */ | ||
class Var : public Expr { | ||
public: | ||
Var(std::string name="", DataType dtype=kInt32); // NOLINT(*) | ||
}; | ||
|
||
/*! \brief */ | ||
Expr IntConstant(int64_t value); | ||
Expr FloatConstant(int64_t value); | ||
Expr operator+(Expr lhs, Expr rhs); | ||
|
||
/*! \brief base of expression node */ | ||
class ExprNode : public Node { | ||
public: | ||
/*! \brief type of data stored in expression */ | ||
DataType dtype_{kUnknown}; | ||
}; | ||
|
||
// inline implementations | ||
inline DataType Expr::dtype() const { | ||
return static_cast<const ExprNode*>(node_.get())->dtype_; | ||
} | ||
template<typename T, | ||
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type > | ||
inline Expr constant(T value) { | ||
if (std::is_integral<T>::value) { | ||
return IntConstant(static_cast<int64_t>(value)); | ||
} else { | ||
return FloatConstant(static_cast<double>(value)); | ||
} | ||
} | ||
|
||
} // namespace tvm | ||
#endif // TVM_EXPR_H_ |
Oops, something went wrong.