From c965e241f17e56ccbf44eedbcdc14d9796a8db5d Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 12 Jul 2016 22:03:43 -0700 Subject: [PATCH] Change op function pointer to std::function, enable mutation (#6) --- nnvm/include/nnvm/node.h | 7 +++ nnvm/include/nnvm/op.h | 19 +++---- nnvm/include/nnvm/op_attr_types.h | 14 ++++- nnvm/src/core/symbolic.cc | 91 +++++++++++++++++++++++++------ nnvm/src/example/operator.cc | 8 +++ nnvm/tests/python/test_symbol.py | 14 +++++ 6 files changed, 123 insertions(+), 30 deletions(-) diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index dd24c3bd6fa38..7ca18760d173c 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -24,6 +24,13 @@ struct NodeEntry { std::shared_ptr node; /*! \brief index of output from the source. */ uint32_t index; + /*! + * \brief version of input Variable. + * This field can only be nonzero when this->node is a Variable node. + * version is increased by one each time a Variable get composed to a mutation Op. + * This information can be helpful to decide order of operations when sequence of mutation happens. + */ + uint32_t version; }; /*! diff --git a/nnvm/include/nnvm/op.h b/nnvm/include/nnvm/op.h index a7d2fd5e1e4d0..5f499b7377e1c 100644 --- a/nnvm/include/nnvm/op.h +++ b/nnvm/include/nnvm/op.h @@ -101,13 +101,13 @@ class Op { * \param attrs The attribute of the node * \return number of outputs. */ - uint32_t (*get_num_outputs)(const NodeAttrs& attrs) = nullptr; + std::function get_num_outputs = nullptr; /*! * \brief get number of inputs given information about the node. * \param attrs The attribute of the node * \return number of inputs */ - uint32_t (*get_num_inputs)(const NodeAttrs& attrs) = nullptr; + std::function get_num_inputs = nullptr; /*! * \brief Attribute parser to parse the NodeAttrs information. * @@ -140,8 +140,7 @@ class Op { * } * \endcode */ - void (*attr_parser)(NodeAttrs* attrs) = nullptr; - + std::function attr_parser = nullptr; // function fields. /*! * \brief setter function during registration @@ -161,7 +160,7 @@ class Op { * \param fn The function to be set. * \return reference to self. */ - inline Op& set_num_inputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*) + inline Op& set_num_inputs(std::function fn); // NOLINT(*) /*! * \brief Set the num_outputs * \param n The number of outputs to be set. @@ -173,13 +172,13 @@ class Op { * \param fn The function to be set. * \return reference to self. */ - inline Op& set_num_outputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*) + inline Op& set_num_outputs(std::function fn); // NOLINT(*) /*! * \brief Set the attr_parser function. * \param fn The number of outputs to be set. * \return reference to self. */ - inline Op& set_attr_parser(void (*fn)(NodeAttrs* attrs)); // NOLINT(*) + inline Op& set_attr_parser(std::function fn); // NOLINT(*) /*! * \brief Register additional attributes to operator. * \param attr_name The name of the attribute. @@ -342,7 +341,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*) return *this; } -inline Op& Op::set_num_inputs(uint32_t (*fn)(const NodeAttrs& attr)) { // NOLINT(*) +inline Op& Op::set_num_inputs(std::function fn) { // NOLINT(*) this->get_num_inputs = fn; return *this; } @@ -352,12 +351,12 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*) return *this; } -inline Op& Op::set_num_outputs(uint32_t (*fn)(const NodeAttrs& attr)) { // NOLINT(*) +inline Op& Op::set_num_outputs(std::function fn) { // NOLINT(*) this->get_num_outputs = fn; return *this; } -inline Op& Op::set_attr_parser(void (*fn)(NodeAttrs* attrs)) { // NOLINT(*) +inline Op& Op::set_attr_parser(std::function fn) { // NOLINT(*) this->attr_parser = fn; return *this; } diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index bfcf5f6b8eaa1..615fca24e07cb 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -12,8 +12,8 @@ namespace nnvm { -// These types are optional attributes in each op -// Some of them are needed for certain pass. +// These types are optional attributes in each operator. +// Each attribute can be required by some passes. /*! * \brief Return list of input arguments names of each operator. @@ -37,6 +37,16 @@ using FListInputNames = std::function (const NodeAttrs& */ using FListOutputNames = std::function (const NodeAttrs& attrs)>; +/*! + * \brief Check whether operator will mutate k-th input. + * \param index The input index + * \return Whether this operator will mutate index-th input. + * + * \note Register under "FMutateInput", default return false + * FMutateInputs enables mutation order handling correctly. + */ +using FMutateInput = std::function; + } // namespace nnvm #endif // NNVM_OP_ATTR_TYPES_H_ diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc index a1a1fdcbbe1b0..0d6f2a6c786e8 100644 --- a/nnvm/src/core/symbolic.cc +++ b/nnvm/src/core/symbolic.cc @@ -13,6 +13,43 @@ namespace symbol_constants { const char *kNamespaceSeparator = "_"; } // namespace symbol_constants +// auxililary version attribute in variable. +struct VariableParam { + uint32_t version{0}; +}; + +std::shared_ptr CreateVariableNode(const std::string& name) { + std::shared_ptr n = Node::Create(); + n->op = nullptr; + n->attrs.name = name; + n->attrs.parsed = VariableParam(); + return n; +} + +// scan over a node's input, update the version to latest +// If the node's op mutates a certain input variable, +// The version of that varaible will increase +// version is used to implicitly order the mutation sequences +inline void UpdateNodeVersion(Node *n) { + static auto& fmutate_inputs = Op::GetAttr("FMutateInput"); + for (NodeEntry& e : n->inputs) { + if (e.node->is_variable()) { + e.version = nnvm::get(e.node->attrs.parsed).version; + } + } + if (fmutate_inputs.count(n->op) != 0) { + FMutateInput fmutate = fmutate_inputs[n->op]; + for (uint32_t i = 0; i < n->inputs.size(); ++i) { + if (fmutate(n->attrs, i)) { + NodeEntry& e = n->inputs[i]; + CHECK(e.node->is_variable()) + << "Mutation target can only be Variable"; + // increase the version of the variable. + ++nnvm::get(e.node->attrs.parsed).version; + } + } + } +} inline std::string DefaultVarName(const std::string &op_name, const std::string &arg_name) { @@ -67,13 +104,13 @@ Symbol Symbol::Copy() const { for (const auto &kv : old_new) { for (const NodeEntry& e : kv.first->inputs) { Node *ptr = e.node.get(); - kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index}); + kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version}); } } // set the head Symbol ret; for (const NodeEntry &e : outputs) { - ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index}); + ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index, e.version}); } return ret; } @@ -95,8 +132,14 @@ void Symbol::Print(std::ostream &os) const { os << "Name: " << node->attrs.name << " Op:" << node->op->name << '\n' << "Inputs:\n"; for (size_t i = 0; i < node->inputs.size(); ++i) { - os << "\targ[" << i << "]=" << node->inputs[i].node->attrs.name - << '(' << node->inputs[i].index << ")\n"; + const NodeEntry& e = node->inputs[i]; + os << "\targ[" << i << "]=" << e.node->attrs.name + << '(' << e.index << ")"; + if (e.node->is_variable()) { + os << " version=" << e.version << '\n'; + } else { + os << '\n'; + } } os << "Attrs:\n"; for (auto &kv : node->attrs.dict) { @@ -163,6 +206,8 @@ std::vector Symbol::ListOutputs() const { void Symbol::Compose(const std::vector& args, const std::unordered_map& kwargs, const std::string& name) { + static auto& flist_inputs = Op::GetAttr("FListInputNames"); + CHECK_EQ(outputs.size(), 1) << "Only composition of value function is supported currently"; CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed"; @@ -193,7 +238,6 @@ void Symbol::Compose(const std::vector& args, } // switch to keyword argument matching if (args.size() != n_req) { - static auto& flist_inputs = Op::GetAttr("FListInputNames"); FListInputNames fn = flist_inputs.get(n->op, nullptr); auto arg_names = (fn == nullptr) ? std::vector{"data"} : fn(n->attrs); if (arg_names.size() != n_req) { @@ -206,8 +250,8 @@ void Symbol::Compose(const std::vector& args, n->inputs[i] = it->second.outputs[0]; ++nmatched; } else { - n->inputs[i] = NodeEntry{Node::Create(), 0}; - n->inputs[i].node->attrs.name = DefaultVarName(name, arg_names[i]); + n->inputs[i] = NodeEntry{ + CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0}; } } @@ -226,6 +270,7 @@ void Symbol::Compose(const std::vector& args, n->inputs.push_back(s.outputs[0]); } } + UpdateNodeVersion(n); } else { // general composition CHECK_EQ(args.size(), 0) @@ -253,25 +298,32 @@ void Symbol::Compose(const std::vector& args, DFSVisit(this->outputs, find_replace_map); if (nmatched == kwargs.size() && arg_counter < args.size()) { + std::vector update_nodes; std::vector > replace_plan; - auto find_replace_plan = [&replace_map, &replace_plan] + auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes] (const std::shared_ptr &node) { // visit all the childs, find possible replacement + bool repl = false; for (size_t i = 0; i < node->inputs.size(); ++i) { NodeEntry *e = &(node->inputs[i]); if (e->node->is_variable()) { auto iter = replace_map.find(e->node.get()); if (iter != replace_map.end()) { replace_plan.push_back(std::make_pair(e, iter->second)); + repl = true; } } } + if (repl) update_nodes.push_back(node.get()); }; DFSVisit(this->outputs, find_replace_plan); for (const auto& kv : replace_plan) { *(kv.first) = *(kv.second); } + for (Node* n : update_nodes) { + UpdateNodeVersion(n); + } } else { std::vector keys = GetKeys(kwargs); std::vector arg_names = ListArguments(); @@ -303,9 +355,15 @@ Symbol Symbol::GetInternals() const { Symbol ret; DFSVisit(this->outputs, [&ret](const std::shared_ptr& node) { Node* n = node.get(); - uint32_t nout = n->num_outputs(); - for (uint32_t i = 0; i < nout; ++i) { - ret.outputs.emplace_back(NodeEntry{node, i}); + if (n->is_variable()) { + // grab version from variable. + VariableParam& param = nnvm::get(n->attrs.parsed); + ret.outputs.emplace_back(NodeEntry{node, 0, param.version}); + } else { + uint32_t nout = n->num_outputs(); + for (uint32_t i = 0; i < nout; ++i) { + ret.outputs.emplace_back(NodeEntry{node, i, 0}); + } } }); return ret; @@ -325,7 +383,7 @@ void Symbol::SetAttrs(const std::vector >& a } } if (node->op != nullptr && node->op->attr_parser != nullptr) { - (*node->op->attr_parser)(&(node->attrs)); + node->op->attr_parser(&(node->attrs)); } } @@ -366,9 +424,9 @@ Symbol Symbol::CreateFunctor(const Op* op, n->op = op; n->attrs.dict = std::move(attrs); if (n->op->attr_parser != nullptr) { - (*n->op->attr_parser)(&(n->attrs)); + n->op->attr_parser(&(n->attrs)); } - s.outputs.emplace_back(NodeEntry{std::move(n), 0}); + s.outputs.emplace_back(NodeEntry{std::move(n), 0, 0}); return s; } @@ -382,10 +440,7 @@ Symbol Symbol::CreateGroup(const std::vector &symbols) { Symbol Symbol::CreateVariable(const std::string& name) { Symbol s; - std::shared_ptr n = Node::Create(); - n->op = nullptr; - n->attrs.name = name; - s.outputs.emplace_back(NodeEntry{std::move(n), 0}); + s.outputs.emplace_back(NodeEntry{CreateVariableNode(name), 0, 0}); return s; } diff --git a/nnvm/src/example/operator.cc b/nnvm/src/example/operator.cc index f1332d1af6221..2bd4a22ed2dd9 100644 --- a/nnvm/src/example/operator.cc +++ b/nnvm/src/example/operator.cc @@ -6,6 +6,7 @@ #include using nnvm::FListInputNames; +using nnvm::FMutateInput; using nnvm::NodeAttrs; NNVM_REGISTER_OP(add) @@ -29,3 +30,10 @@ NNVM_REGISTER_OP(conv2d) NNVM_REGISTER_OP(add) .attr("nick_name", "plus"); + +NNVM_REGISTER_OP(assign) +.set_num_inputs(2) +.set_num_outputs(1) +.attr("FMutateInput", [](const NodeAttrs& attrs, uint32_t index) { + return index == 0; + }); diff --git a/nnvm/tests/python/test_symbol.py b/nnvm/tests/python/test_symbol.py index 861b5372d5da9..08d24536084dd 100644 --- a/nnvm/tests/python/test_symbol.py +++ b/nnvm/tests/python/test_symbol.py @@ -24,6 +24,20 @@ def test_default_input(): except NNVMError: pass +def test_mutate_input(): + x = sym.Variable('x') + y = sym.conv2d(data=x, name='conv') + z = sym.assign(x, y) + t = sym.add(z, x) + + try: + z = sym.assign(z, z) + assert False + except NNVMError: + pass + + if __name__ == "__main__": test_default_input() test_compose() + test_mutate_input()