Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LANG] Enable json load/save and pickle #10

Merged
merged 1 commit into from
Jan 12, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions include/tvm/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,41 @@ using ::tvm::Node;
using ::tvm::NodeRef;
using ::tvm::AttrVisitor;

/*!
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
*
* \return the string representation of the node.
*/
std::string SaveJSON(const NodeRef& node);

/*!
* \brief Internal implementation of LoadJSON
* Load tvm Node object from json and return a shared_ptr of Node.
* \param json_str The json string to load from.
*
* \return The shared_ptr of the Node.
*/
std::shared_ptr<Node> LoadJSON_(std::string json_str);

/*!
* \brief Load the node from json string.
* This can be used to deserialize any TVM object.
*
* \param json_str The json string to load from.
*
* \tparam NodeType the nodetype
*
* \code
* Expr e = LoadJSON<Expr>(json_str);
* \endcode
*/
template<typename NodeType,
typename = typename std::enable_if<std::is_base_of<NodeRef, NodeType>::value>::type >
inline NodeType LoadJSON(const std::string& json_str) {
return NodeType(LoadJSON_(json_str));
}

/*! \brief typedef the factory function of data iterator */
using NodeFactory = std::function<std::shared_ptr<Node> ()>;
/*!
Expand All @@ -32,8 +67,9 @@ struct NodeFactoryReg
};

#define TVM_REGISTER_NODE_TYPE(TypeName) \
DMLC_REGISTRY_REGISTER(::tvm::NodeFactoryReg, NodeFactoryReg, TypeName) \
.set_body([]() { return std::make_shared<TypeName>(); })
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::dmlc::Registry<::tvm::NodeFactoryReg>::Get()->__REGISTER__(TypeName::_type_key) \
.set_body([]() { return std::make_shared<TypeName>(); })

} // namespace tvm
#endif // TVM_BASE_H_
9 changes: 5 additions & 4 deletions include/tvm/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
/*! \brief TVM_DLL prefix for windows */
#ifdef _WIN32
#ifdef TVM_EXPORTS
#define TVM_DLL TVM_EXTERN_C __declspec(dllexport)
#define TVM_DLL __declspec(dllexport)
#else
#define TVM_DLL TVM_EXTERN_C __declspec(dllimport)
#define TVM_DLL __declspec(dllimport)
#endif
#else
#define TVM_DLL TVM_EXTERN_C
#define TVM_DLL
#endif

TVM_EXTERN_C {
/*! \brief handle to functions */
typedef void* FunctionHandle;
/*! \brief handle to node */
Expand Down Expand Up @@ -147,5 +148,5 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
TVM_DLL int TVMNodeListAttrNames(NodeHandle handle,
int *out_size,
const char*** out_array);

} // TVM_EXTERN_C
#endif // TVM_C_API_H_
24 changes: 23 additions & 1 deletion python/tvm/_ctypes/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def __getattr__(self, name):
"'%s' object has no attribute '%s'" % (str(type(self)), name))
return value


def __hash__(self):
return _function_internal._raw_ptr(self)

Expand All @@ -111,6 +110,29 @@ def __dir__(self):
names.append(py_str(plist[i]))
return names

def __reduce__(self):
return (type(self), (None,), self.__getstate__())

def __getstate__(self):
handle = self.handle
if handle is not None:
return {'handle': _function_internal._save_json(self)}
else:
return {'handle': None}

def __setstate__(self, state):
# pylint: disable=assigning-non-slot
handle = state['handle']
if handle is not None:
json_str = handle
_push_arg(json_str)
other = _function_internal._load_json(json_str)
self.handle = other.handle
other.handle = None
else:
self.handle = None


def const(value, dtype=None):
"""construct a constant"""
if dtype is None:
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,38 @@ def const(value, dtype=None):
return _function_internal._const(value, dtype)


def load_json(json_str):
"""Load tvm object from json_str.

Parameters
----------
json_str : str
The json string

Returns
-------
node : Node
The loaded tvm node.
"""
return _function_internal._load_json(json_str)


def save_json(node):
"""Load tvm object as json string.

Parameters
----------
node : Node
A TVM Node object to be saved.

Returns
-------
json_str : str
Saved json string.
"""
return _function_internal._save_json(node)


def Var(name="tindex", dtype=int32):
"""Create a new variable with specified name and dtype

Expand Down
42 changes: 42 additions & 0 deletions src/base/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*!
* Copyright (c) 2016 by Contributors
* \file common.h
* \brief Common utilities
*/
#ifndef TVM_BASE_COMMON_H_
#define TVM_BASE_COMMON_H_

#include <tvm/base.h>
#include <string>

namespace tvm {

inline std::string Type2String(const Type& t) {
std::ostringstream os;
os << t;
return os.str();
}

inline Type String2Type(std::string s) {
std::istringstream is(s);
halide_type_code_t code = Type::Int;
if (s.substr(0, 3) == "int") {
code = Type::Int; s = s.substr(3);
} else if (s.substr(0, 4) == "uint") {
code = Type::UInt; s = s.substr(4);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else {
LOG(FATAL) << "unknown type " << s;
}
int bits = 32, lanes = 1;
if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) {
LOG(FATAL) << "unknown type " << s;
}
return Type(code, bits, lanes);
}

} // namespace tvm
#endif // TVM_BASE_COMMON_H_
Loading