diff --git a/apps/graph_executor/src/graph_executor.cc b/apps/graph_executor/src/graph_executor.cc index fa1325c1f858..e1230f4c7663 100644 --- a/apps/graph_executor/src/graph_executor.cc +++ b/apps/graph_executor/src/graph_executor.cc @@ -399,6 +399,55 @@ FOpExec GraphExecutor::CreateTVMOp(const nnvm::NodeAttrs& attrs, return fexec; } +struct TVMOpParam : public dmlc::Parameter { + std::string func_name; + uint32_t num_inputs; + uint32_t num_outputs; + bool flatten_data; + DMLC_DECLARE_PARAMETER(TVMOpParam) { + DMLC_DECLARE_FIELD(func_name); + DMLC_DECLARE_FIELD(num_inputs) + .set_default(1); + DMLC_DECLARE_FIELD(num_outputs) + .set_default(1); + DMLC_DECLARE_FIELD(flatten_data) + .set_default(false); + } +}; +DMLC_REGISTER_PARAMETER(TVMOpParam); + +/*! \brief Parse keyword arguments as PType arguments and save to parsed */ +template +inline void ParamParser(nnvm::NodeAttrs* attrs) { + PType param; + try { + param.Init(attrs->dict); + } catch (const dmlc::ParamError& e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs->op->name << "(" + << "name=\"" << attrs->name << "\""; + for (const auto& k : attrs->dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } + attrs->parsed = std::move(param); +} + +// ewise tvm op +NNVM_REGISTER_OP(tvm_op) +.set_attr_parser(ParamParser) +.set_num_inputs([](const NodeAttrs& attrs) { + const TVMOpParam& param = nnvm::get(attrs.parsed); + return param.num_inputs; + }) +.set_num_outputs([](const NodeAttrs& attrs) { + const TVMOpParam& param = nnvm::get(attrs.parsed); + return param.num_outputs; + }); + // Create executor tvm::runtime::Module CreateExecutor(nnvm::Graph g, TVMContext ctx) { std::shared_ptr exec = @@ -460,3 +509,27 @@ TVM_REGISTER_GLOBAL("tvm_graph._load_executor") }); } // namespace contrib } // namespace tvm + +namespace dmlc { +namespace json { + +template<> +struct Handler { + static void Write(JSONWriter *writer, const DLDataType& data) { + std::vector tmp({data.code, data.bits, data.lanes}); + writer->Write(tmp); + } + + static void Read(JSONReader *reader, DLDataType* data) { + std::vector tmp; + reader->Read(&tmp); + data->code = tmp[0]; + data->bits = tmp[1]; + data->lanes = tmp[2]; + } +}; + +DMLC_JSON_ENABLE_ANY(std::vector, list_dltype); + +} // namespace dmlc +} // namespace json diff --git a/apps/graph_executor/src/graph_pass.cc b/apps/graph_executor/src/graph_pass.cc index 544606df4ea1..38ef185469fd 100644 --- a/apps/graph_executor/src/graph_pass.cc +++ b/apps/graph_executor/src/graph_pass.cc @@ -340,54 +340,6 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { NNVM_REGISTER_PASS(GraphFuse) .set_body(GraphFuse); -struct TVMOpParam : public dmlc::Parameter { - std::string func_name; - uint32_t num_inputs; - uint32_t num_outputs; - bool flatten_data; - DMLC_DECLARE_PARAMETER(TVMOpParam) { - DMLC_DECLARE_FIELD(func_name); - DMLC_DECLARE_FIELD(num_inputs) - .set_default(1); - DMLC_DECLARE_FIELD(num_outputs) - .set_default(1); - DMLC_DECLARE_FIELD(flatten_data) - .set_default(false); - } -}; -DMLC_REGISTER_PARAMETER(TVMOpParam); - -/*! \brief Parse keyword arguments as PType arguments and save to parsed */ -template -inline void ParamParser(nnvm::NodeAttrs* attrs) { - PType param; - try { - param.Init(attrs->dict); - } catch (const dmlc::ParamError& e) { - std::ostringstream os; - os << e.what(); - os << ", in operator " << attrs->op->name << "(" - << "name=\"" << attrs->name << "\""; - for (const auto& k : attrs->dict) { - os << ", " << k.first << "=\"" << k.second << "\""; - } - os << ")"; - throw dmlc::ParamError(os.str()); - } - attrs->parsed = std::move(param); -} - -// ewise tvm op -NNVM_REGISTER_OP(tvm_op) -.set_attr_parser(ParamParser) -.set_num_inputs([](const NodeAttrs& attrs) { - const TVMOpParam& param = nnvm::get(attrs.parsed); - return param.num_inputs; - }) -.set_num_outputs([](const NodeAttrs& attrs) { - const TVMOpParam& param = nnvm::get(attrs.parsed); - return param.num_outputs; - }); inline bool IsIdentityLayout(const LayoutInfo& layout) { if (layout.src == "" && layout.dst == "") return true; @@ -515,27 +467,3 @@ NNVM_REGISTER_OP(layout_transform) .set_num_outputs(1); } // namespace contrib } // namespace tvm - -namespace dmlc { -namespace json { - -template<> -struct Handler { - static void Write(JSONWriter *writer, const DLDataType& data) { - std::vector tmp({data.code, data.bits, data.lanes}); - writer->Write(tmp); - } - - static void Read(JSONReader *reader, DLDataType* data) { - std::vector tmp; - reader->Read(&tmp); - data->code = tmp[0]; - data->bits = tmp[1]; - data->lanes = tmp[2]; - } -}; - -DMLC_JSON_ENABLE_ANY(std::vector, list_dltype); - -} // namespace dmlc -} // namespace json