Skip to content

Commit

Permalink
[LANG] Enable json load/save and pickle (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and icemelon committed Jan 12, 2017
1 parent 7250005 commit 5fced92
Show file tree
Hide file tree
Showing 15 changed files with 521 additions and 37 deletions.
40 changes: 38 additions & 2 deletions include/tvm/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,41 @@ using ::tvm::Node;
using ::tvm::NodeRef;
using ::tvm::AttrVisitor;

/*!
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
*
* \return the string representation of the node.
*/
std::string SaveJSON(const NodeRef& node);

/*!
* \brief Internal implementation of LoadJSON
* Load tvm Node object from json and return a shared_ptr of Node.
* \param json_str The json string to load from.
*
* \return The shared_ptr of the Node.
*/
std::shared_ptr<Node> LoadJSON_(std::string json_str);

/*!
* \brief Load the node from json string.
* This can be used to deserialize any TVM object.
*
* \param json_str The json string to load from.
*
* \tparam NodeType the nodetype
*
* \code
* Expr e = LoadJSON<Expr>(json_str);
* \endcode
*/
template<typename NodeType,
typename = typename std::enable_if<std::is_base_of<NodeRef, NodeType>::value>::type >
inline NodeType LoadJSON(const std::string& json_str) {
return NodeType(LoadJSON_(json_str));
}

/*! \brief typedef the factory function of data iterator */
using NodeFactory = std::function<std::shared_ptr<Node> ()>;
/*!
Expand All @@ -32,8 +67,9 @@ struct NodeFactoryReg
};

#define TVM_REGISTER_NODE_TYPE(TypeName) \
DMLC_REGISTRY_REGISTER(::tvm::NodeFactoryReg, NodeFactoryReg, TypeName) \
.set_body([]() { return std::make_shared<TypeName>(); })
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::dmlc::Registry<::tvm::NodeFactoryReg>::Get()->__REGISTER__(TypeName::_type_key) \
.set_body([]() { return std::make_shared<TypeName>(); })

} // namespace tvm
#endif // TVM_BASE_H_
9 changes: 5 additions & 4 deletions include/tvm/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
/*! \brief TVM_DLL prefix for windows */
#ifdef _WIN32
#ifdef TVM_EXPORTS
#define TVM_DLL TVM_EXTERN_C __declspec(dllexport)
#define TVM_DLL __declspec(dllexport)
#else
#define TVM_DLL TVM_EXTERN_C __declspec(dllimport)
#define TVM_DLL __declspec(dllimport)
#endif
#else
#define TVM_DLL TVM_EXTERN_C
#define TVM_DLL
#endif

TVM_EXTERN_C {
/*! \brief handle to functions */
typedef void* FunctionHandle;
/*! \brief handle to node */
Expand Down Expand Up @@ -147,5 +148,5 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
TVM_DLL int TVMNodeListAttrNames(NodeHandle handle,
int *out_size,
const char*** out_array);

} // TVM_EXTERN_C
#endif // TVM_C_API_H_
24 changes: 23 additions & 1 deletion python/tvm/_ctypes/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def __getattr__(self, name):
"'%s' object has no attribute '%s'" % (str(type(self)), name))
return value


def __hash__(self):
return _function_internal._raw_ptr(self)

Expand All @@ -111,6 +110,29 @@ def __dir__(self):
names.append(py_str(plist[i]))
return names

def __reduce__(self):
return (type(self), (None,), self.__getstate__())

def __getstate__(self):
handle = self.handle
if handle is not None:
return {'handle': _function_internal._save_json(self)}
else:
return {'handle': None}

def __setstate__(self, state):
# pylint: disable=assigning-non-slot
handle = state['handle']
if handle is not None:
json_str = handle
_push_arg(json_str)
other = _function_internal._load_json(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None


def const(value, dtype=None):
"""construct a constant"""
if dtype is None:
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,38 @@ def const(value, dtype=None):
return _function_internal._const(value, dtype)


def load_json(json_str):
"""Load tvm object from json_str.
Parameters
----------
json_str : str
The json string
Returns
-------
node : Node
The loaded tvm node.
"""
return _function_internal._load_json(json_str)


def save_json(node):
"""Load tvm object as json string.
Parameters
----------
node : Node
A TVM Node object to be saved.
Returns
-------
json_str : str
Saved json string.
"""
return _function_internal._save_json(node)


def Var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype
Expand Down
42 changes: 42 additions & 0 deletions src/base/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*!
* Copyright (c) 2016 by Contributors
* \file common.h
* \brief Common utilities
*/
#ifndef TVM_BASE_COMMON_H_
#define TVM_BASE_COMMON_H_

#include <tvm/base.h>
#include <string>

namespace tvm {

inline std::string Type2String(const Type& t) {
std::ostringstream os;
os << t;
return os.str();
}

inline Type String2Type(std::string s) {
std::istringstream is(s);
halide_type_code_t code = Type::Int;
if (s.substr(0, 3) == "int") {
code = Type::Int; s = s.substr(3);
} else if (s.substr(0, 4) == "uint") {
code = Type::UInt; s = s.substr(4);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else {
LOG(FATAL) << "unknown type " << s;
}
int bits = 32, lanes = 1;
if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) {
LOG(FATAL) << "unknown type " << s;
}
return Type(code, bits, lanes);
}

} // namespace tvm
#endif // TVM_BASE_COMMON_H_
Loading

0 comments on commit 5fced92

Please sign in to comment.