Skip to content

Commit

Permalink
Temp checkin c++ code.
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Oct 15, 2016
1 parent 1a18f08 commit c41d9d2
Show file tree
Hide file tree
Showing 14 changed files with 641 additions and 1 deletion.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,6 @@ ENV/
# Rope project settings
.ropeproject
*~
*.pyc
*.pyc
*~
build
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "dmlc-core"]
path = dmlc-core
url = https://github.com/dmlc/dmlc-core
30 changes: 30 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Wno-unknown-pragmas -funroll-loops\
-Iinclude -Idmlc-core/include -fPIC

# specify tensor path
.PHONY: clean all

all: lib/libtvm.a
SRC = $(wildcard src/*.cc src/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ)

build/%.o: src/%.cc
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@


lib/libtvm.a: $(ALL_DEP)
@mkdir -p $(@D)
ar crv $@ $(filter %.o, $?)

lint:
python2 dmlc-core/scripts/lint.py tvm cpp include src

clean:
$(RM) -rf build lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o

-include build/*.d
-include build/*/*.d
1 change: 1 addition & 0 deletions dmlc-core
Submodule dmlc-core added at 39007a
173 changes: 173 additions & 0 deletions include/tvm/base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*!
* Copyright (c) 2016 by Contributors
* \file base.h
* \brief Defines the base data structure
*/
#ifndef TVM_BASE_H_
#define TVM_BASE_H_

#include <dmlc/logging.h>
#include <dmlc/registry.h>
#include <string>
#include <memory>
#include <functional>
#include <typeinfo>


namespace tvm {

// forward declaration
class Node;
class NodeRef;
class UnaryOp;
class BinaryOp;

/*! \brief list of all supported data types */
enum DataType {
kUnknown,
kInt32,
kFloat32
};

/*!
* \brief List of subset node types used for quick runtime switch.
*
* \note The value of NodeType is not used for serialization type_key is used instead.
* \note is_type and type_key can be used to do type checking for all types
* \note kOtherNodes could mean more than one node type.
*/
enum NodeType {
kVarNode,
kIntNode,
kFloatNode,
kUnaryOpNode,
kBinaryOpNode,
kReduceNode,
kTensorReadNode,
kOtherNodes
};

/*!
* \brief Visitor class to each node content.
* The content is going to be called for each field.
*/
class AttrVisitor {
public:
//! \cond Doxygen_Suppress
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, DataType* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, const UnaryOp** value) = 0;
virtual void Visit(const char* key, const BinaryOp** value) = 0;
//! \endcond
};

/*!
* \brief A function to be applied when visit each NodeRef Field.
* \param ref The child to be visited.
*/
using FNodeRefVisit = std::function<void(const char* key, NodeRef* ref)>;

/*!
* \brief base class of node container in DSL AST.
* All object's internal is stored as std::shared_ptr<Node>
*/
class Node {
public:
/*! \brief virtual destructor */
virtual ~Node();
/*! \return The unique type key of the node */
virtual const char* type_key() const = 0;
/*! \brief verify the correctness of node struct after it get mutated by visitor */
virtual void Verify() const {}
/*!
* \brief Apply visitor to each field of the Node
* Visitor could mutate the content of the node.
* override if Node contains attribute fields.
* \param visitor The visitor
*/
virtual void VisitAttrs(AttrVisitor* visitor) {}
/*!
* \brief Apply visitor to each field of the Node
* Visitor could mutate the content of the node.
* override if Node contains NodeRefFields.
* \param visitor The visitor
*/
virtual void VisitNodeRefFields(FNodeRefVisit visitor) {}
/*!
* \tparam NodeType the type to be checked.
* \return whether the stored type is node type
*/
template<typename TNode>
inline bool is_type() const;
/*! \return the node type */
inline NodeType node_type() const;

protected:
// node ref can see this
friend class NodeRef;
/*! \brief the node type enum */
NodeType node_type_{kOtherNodes};
};

/*! \brief base class of all node reference object */
class NodeRef {
public:
/*!
* \return typed pointer of the node
* \tparam TNode the type of the node.
*/
template<typename TNode>
inline const TNode* Get() const;
/*! \return wheyjer the expression is null */
inline bool is_null() const;

protected:
NodeRef() = default;
explicit NodeRef(std::shared_ptr<Node> node) : node_(node) {}
/*! \brief the internal node */
std::shared_ptr<Node> node_;
};

/*! \brief typedef the factory function of data iterator */
using NodeFactory = std::function<std::shared_ptr<Node> ()>;

/*!
* \brief Registry entry for DataIterator factory functions.
*/
struct NodeFactoryReg
: public dmlc::FunctionRegEntryBase<NodeFactoryReg,
NodeFactory> {
};

#define TVM_REGISTER_NODE_TYPE(TypeName) \
DMLC_REGISTRY_REGISTER(::tvm::NodeFactoryReg, NodeFactoryReg, TypeName) \
.set_body([]() { return std::make_shared<TypeName>(); })

// implementations of inline functions after this
inline NodeType Node::node_type() const {
return node_type_;
}

template<typename TNode>
inline bool Node::is_type() const {
const std::type_info& tinfo = typeid(*this);
if (&typeid(TNode) == &tinfo) return true;
return typeid(TNode) == tinfo;
}

template<typename TNode>
inline const TNode* NodeRef::Get() const {
CHECK(node_->is_type<TNode>())
<< " type inconsistent, expected " << typeid(TNode).name()
<< " given " << typeid(*this).name();
return static_cast<const TNode*>(node_.get());
}

inline bool NodeRef::is_null() const {
return node_.get() == nullptr;
}

} // namespace tvm
#endif // TVM_BASE_H_
49 changes: 49 additions & 0 deletions include/tvm/c_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*!
* Copyright (c) 2016 by Contributors
* \file c_api.h
* \brief C API of TVM DSL
*/
#ifndef TVM_C_API_H_
#define TVM_C_API_H_

#ifdef __cplusplus
#define TVM_EXTERN_C extern "C"
#else
#define TVM_EXTERN_C
#endif

/*! \brief TVM_DLL prefix for windows */
#ifdef _WIN32
#ifdef TVM_EXPORTS
#define TVM_DLL TVM_EXTERN_C __declspec(dllexport)
#else
#define TVM_DLL TVM_EXTERN_C __declspec(dllimport)
#endif
#else
#define TVM_DLL TVM_EXTERN_C
#endif

/*! \brief handle to node creator */
typedef void* NodeCreatorHandle;
/*! \brief handle to node */
typedef void* NodeHandle;

TVM_DLL int TVMNodeCreatorGet(const char* node_type,
NodeCreatorHandle *handle);

TVM_DLL int TVMNodeCreate(NodeCreatorHandle handle,
int num_child_ref,
const char* child_ref_keys,
NodeHandle* child_node_refs,
int num_attrs,
const char* attr_keys,
const char* attr_vals,
NodeHandle* handle);

TVM_DLL int TVMNodeGetAttr(const char* key,
const char** value);

TVM_DLL int TVMNodeGetChildNodeRef(const char* key,
NodeHandle* out);

#endif // TVM_C_API_H_
17 changes: 17 additions & 0 deletions include/tvm/domain.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*!
* Copyright (c) 2016 by Contributors
* \file domain.h
* \brief Defines the AST
*/
#ifndef TVM_DOMAIN_H_
#define TVM_DOMAIN_H_

#include <memory>

namespace tvm {
class RDom {
};

} // namespace tvm

#endif // TVM_DOMAIN_H_
107 changes: 107 additions & 0 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*!
* Copyright (c) 2016 by Contributors
* \file expr.h
* \brief Defines the expressions in AST.
*/
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_

#include <type_traits>
#include "./base.h"

namespace tvm {
// forward declare Expr
class Expr;

/*!
* \brief create a constant expression
* \tparam T the value type
* \param value The value to the constant.
* \return The created expression
*/
template<typename T,
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type >
inline Expr constant(T value);

/*!
* \brief a expression type, holds a ref to root of an AST
*/
class Expr : public NodeRef {
public:
/*! \brief default constructor */
Expr() = default;
/*!
* \brief copy constructor
* \param other the input
*/
Expr(const Expr& other) = default; // NOLINT(*)
/*!
* \brief move constructor
* \param other the input
*/
Expr(Expr&& other) = default; // NOLINT(*)
/*!
* \brief assign operator.
* \param other the input.
* \return reference to self
*/
Expr& operator=(const Expr& other) = default;
/*!
* \brief assign move operator.
* \param other the input.
* \return reference to self
*/
Expr& operator=(Expr&& other) = default;
/*!
* \brief constructor from constant value
* \param value the constant value
* \tparam T The constant type
*/
template<typename T,
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type >
Expr(T value) { // NOLINT(*)
*this = std::move(constant<T>(value));
}
/*!
* \brief constructor from node pointer
* \param nptr Another node shared pointer
*/
explicit Expr(std::shared_ptr<Node> nptr) : NodeRef(nptr) {}
/*! \return the expression type of the expression */
inline DataType dtype() const;
};

/*! \brief Variable class */
class Var : public Expr {
public:
Var(std::string name="", DataType dtype=kInt32); // NOLINT(*)
};

/*! \brief */
Expr IntConstant(int64_t value);
Expr FloatConstant(int64_t value);
Expr operator+(Expr lhs, Expr rhs);

/*! \brief base of expression node */
class ExprNode : public Node {
public:
/*! \brief type of data stored in expression */
DataType dtype_{kUnknown};
};

// inline implementations
inline DataType Expr::dtype() const {
return static_cast<const ExprNode*>(node_.get())->dtype_;
}
template<typename T,
typename = typename std::enable_if<std::is_arithmetic<T>::value>::type >
inline Expr constant(T value) {
if (std::is_integral<T>::value) {
return IntConstant(static_cast<int64_t>(value));
} else {
return FloatConstant(static_cast<double>(value));
}
}

} // namespace tvm
#endif // TVM_EXPR_H_
Loading

0 comments on commit c41d9d2

Please sign in to comment.