Skip to content

Commit

Permalink
Change op function pointer to std::function, enable mutation (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and sergei-mironov committed Aug 8, 2018
1 parent a52bbad commit c965e24
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 30 deletions.
7 changes: 7 additions & 0 deletions nnvm/include/nnvm/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ struct NodeEntry {
std::shared_ptr<Node> node;
/*! \brief index of output from the source. */
uint32_t index;
/*!
* \brief version of input Variable.
* This field can only be nonzero when this->node is a Variable node.
* version is increased by one each time a Variable get composed to a mutation Op.
* This information can be helpful to decide order of operations when sequence of mutation happens.
*/
uint32_t version;
};

/*!
Expand Down
19 changes: 9 additions & 10 deletions nnvm/include/nnvm/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,13 @@ class Op {
* \param attrs The attribute of the node
* \return number of outputs.
*/
uint32_t (*get_num_outputs)(const NodeAttrs& attrs) = nullptr;
std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr;
/*!
* \brief get number of inputs given information about the node.
* \param attrs The attribute of the node
* \return number of inputs
*/
uint32_t (*get_num_inputs)(const NodeAttrs& attrs) = nullptr;
std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr;
/*!
* \brief Attribute parser to parse the NodeAttrs information.
*
Expand Down Expand Up @@ -140,8 +140,7 @@ class Op {
* }
* \endcode
*/
void (*attr_parser)(NodeAttrs* attrs) = nullptr;

std::function<void(NodeAttrs* attrs)> attr_parser = nullptr;
// function fields.
/*!
* \brief setter function during registration
Expand All @@ -161,7 +160,7 @@ class Op {
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_inputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*)
inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
/*!
* \brief Set the num_outputs
* \param n The number of outputs to be set.
Expand All @@ -173,13 +172,13 @@ class Op {
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_outputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*)
inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
/*!
* \brief Set the attr_parser function.
* \param fn The number of outputs to be set.
* \return reference to self.
*/
inline Op& set_attr_parser(void (*fn)(NodeAttrs* attrs)); // NOLINT(*)
inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
Expand Down Expand Up @@ -342,7 +341,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
return *this;
}

inline Op& Op::set_num_inputs(uint32_t (*fn)(const NodeAttrs& attr)) { // NOLINT(*)
inline Op& Op::set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
this->get_num_inputs = fn;
return *this;
}
Expand All @@ -352,12 +351,12 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
return *this;
}

inline Op& Op::set_num_outputs(uint32_t (*fn)(const NodeAttrs& attr)) { // NOLINT(*)
inline Op& Op::set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
this->get_num_outputs = fn;
return *this;
}

inline Op& Op::set_attr_parser(void (*fn)(NodeAttrs* attrs)) { // NOLINT(*)
inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { // NOLINT(*)
this->attr_parser = fn;
return *this;
}
Expand Down
14 changes: 12 additions & 2 deletions nnvm/include/nnvm/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

namespace nnvm {

// These types are optional attributes in each op
// Some of them are needed for certain pass.
// These types are optional attributes in each operator.
// Each attribute can be required by some passes.

/*!
* \brief Return list of input arguments names of each operator.
Expand All @@ -37,6 +37,16 @@ using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs&
*/
using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;

/*!
* \brief Check whether operator will mutate k-th input.
* \param index The input index
* \return Whether this operator will mutate index-th input.
*
* \note Register under "FMutateInput", default return false
* FMutateInputs enables mutation order handling correctly.
*/
using FMutateInput = std::function<bool (const NodeAttrs& attrs, uint32_t index)>;

} // namespace nnvm

#endif // NNVM_OP_ATTR_TYPES_H_
91 changes: 73 additions & 18 deletions nnvm/src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,43 @@ namespace symbol_constants {
const char *kNamespaceSeparator = "_";
} // namespace symbol_constants

// auxililary version attribute in variable.
struct VariableParam {
uint32_t version{0};
};

std::shared_ptr<Node> CreateVariableNode(const std::string& name) {
std::shared_ptr<Node> n = Node::Create();
n->op = nullptr;
n->attrs.name = name;
n->attrs.parsed = VariableParam();
return n;
}

// scan over a node's input, update the version to latest
// If the node's op mutates a certain input variable,
// The version of that varaible will increase
// version is used to implicitly order the mutation sequences
inline void UpdateNodeVersion(Node *n) {
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
for (NodeEntry& e : n->inputs) {
if (e.node->is_variable()) {
e.version = nnvm::get<VariableParam>(e.node->attrs.parsed).version;
}
}
if (fmutate_inputs.count(n->op) != 0) {
FMutateInput fmutate = fmutate_inputs[n->op];
for (uint32_t i = 0; i < n->inputs.size(); ++i) {
if (fmutate(n->attrs, i)) {
NodeEntry& e = n->inputs[i];
CHECK(e.node->is_variable())
<< "Mutation target can only be Variable";
// increase the version of the variable.
++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
}
}
}
}

inline std::string DefaultVarName(const std::string &op_name,
const std::string &arg_name) {
Expand Down Expand Up @@ -67,13 +104,13 @@ Symbol Symbol::Copy() const {
for (const auto &kv : old_new) {
for (const NodeEntry& e : kv.first->inputs) {
Node *ptr = e.node.get();
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index});
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version});
}
}
// set the head
Symbol ret;
for (const NodeEntry &e : outputs) {
ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index});
ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index, e.version});
}
return ret;
}
Expand All @@ -95,8 +132,14 @@ void Symbol::Print(std::ostream &os) const {
os << "Name: " << node->attrs.name << " Op:" << node->op->name << '\n'
<< "Inputs:\n";
for (size_t i = 0; i < node->inputs.size(); ++i) {
os << "\targ[" << i << "]=" << node->inputs[i].node->attrs.name
<< '(' << node->inputs[i].index << ")\n";
const NodeEntry& e = node->inputs[i];
os << "\targ[" << i << "]=" << e.node->attrs.name
<< '(' << e.index << ")";
if (e.node->is_variable()) {
os << " version=" << e.version << '\n';
} else {
os << '\n';
}
}
os << "Attrs:\n";
for (auto &kv : node->attrs.dict) {
Expand Down Expand Up @@ -163,6 +206,8 @@ std::vector<std::string> Symbol::ListOutputs() const {
void Symbol::Compose(const std::vector<Symbol>& args,
const std::unordered_map<std::string, Symbol>& kwargs,
const std::string& name) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");

CHECK_EQ(outputs.size(), 1)
<< "Only composition of value function is supported currently";
CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
Expand Down Expand Up @@ -193,7 +238,6 @@ void Symbol::Compose(const std::vector<Symbol>& args,
}
// switch to keyword argument matching
if (args.size() != n_req) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
FListInputNames fn = flist_inputs.get(n->op, nullptr);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
if (arg_names.size() != n_req) {
Expand All @@ -206,8 +250,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
n->inputs[i] = it->second.outputs[0];
++nmatched;
} else {
n->inputs[i] = NodeEntry{Node::Create(), 0};
n->inputs[i].node->attrs.name = DefaultVarName(name, arg_names[i]);
n->inputs[i] = NodeEntry{
CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0};
}
}

Expand All @@ -226,6 +270,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
n->inputs.push_back(s.outputs[0]);
}
}
UpdateNodeVersion(n);
} else {
// general composition
CHECK_EQ(args.size(), 0)
Expand Down Expand Up @@ -253,25 +298,32 @@ void Symbol::Compose(const std::vector<Symbol>& args,
DFSVisit(this->outputs, find_replace_map);

if (nmatched == kwargs.size() && arg_counter < args.size()) {
std::vector<Node*> update_nodes;
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
auto find_replace_plan = [&replace_map, &replace_plan]
auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes]
(const std::shared_ptr<Node> &node) {
// visit all the childs, find possible replacement
bool repl = false;
for (size_t i = 0; i < node->inputs.size(); ++i) {
NodeEntry *e = &(node->inputs[i]);
if (e->node->is_variable()) {
auto iter = replace_map.find(e->node.get());
if (iter != replace_map.end()) {
replace_plan.push_back(std::make_pair(e, iter->second));
repl = true;
}
}
}
if (repl) update_nodes.push_back(node.get());
};
DFSVisit(this->outputs, find_replace_plan);

for (const auto& kv : replace_plan) {
*(kv.first) = *(kv.second);
}
for (Node* n : update_nodes) {
UpdateNodeVersion(n);
}
} else {
std::vector<std::string> keys = GetKeys(kwargs);
std::vector<std::string> arg_names = ListArguments();
Expand Down Expand Up @@ -303,9 +355,15 @@ Symbol Symbol::GetInternals() const {
Symbol ret;
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node>& node) {
Node* n = node.get();
uint32_t nout = n->num_outputs();
for (uint32_t i = 0; i < nout; ++i) {
ret.outputs.emplace_back(NodeEntry{node, i});
if (n->is_variable()) {
// grab version from variable.
VariableParam& param = nnvm::get<VariableParam>(n->attrs.parsed);
ret.outputs.emplace_back(NodeEntry{node, 0, param.version});
} else {
uint32_t nout = n->num_outputs();
for (uint32_t i = 0; i < nout; ++i) {
ret.outputs.emplace_back(NodeEntry{node, i, 0});
}
}
});
return ret;
Expand All @@ -325,7 +383,7 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
}
}
if (node->op != nullptr && node->op->attr_parser != nullptr) {
(*node->op->attr_parser)(&(node->attrs));
node->op->attr_parser(&(node->attrs));
}
}

Expand Down Expand Up @@ -366,9 +424,9 @@ Symbol Symbol::CreateFunctor(const Op* op,
n->op = op;
n->attrs.dict = std::move(attrs);
if (n->op->attr_parser != nullptr) {
(*n->op->attr_parser)(&(n->attrs));
n->op->attr_parser(&(n->attrs));
}
s.outputs.emplace_back(NodeEntry{std::move(n), 0});
s.outputs.emplace_back(NodeEntry{std::move(n), 0, 0});
return s;
}

Expand All @@ -382,10 +440,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {

Symbol Symbol::CreateVariable(const std::string& name) {
Symbol s;
std::shared_ptr<Node> n = Node::Create();
n->op = nullptr;
n->attrs.name = name;
s.outputs.emplace_back(NodeEntry{std::move(n), 0});
s.outputs.emplace_back(NodeEntry{CreateVariableNode(name), 0, 0});
return s;
}

Expand Down
8 changes: 8 additions & 0 deletions nnvm/src/example/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <utility>

using nnvm::FListInputNames;
using nnvm::FMutateInput;
using nnvm::NodeAttrs;

NNVM_REGISTER_OP(add)
Expand All @@ -29,3 +30,10 @@ NNVM_REGISTER_OP(conv2d)

NNVM_REGISTER_OP(add)
.attr<std::string>("nick_name", "plus");

NNVM_REGISTER_OP(assign)
.set_num_inputs(2)
.set_num_outputs(1)
.attr<FMutateInput>("FMutateInput", [](const NodeAttrs& attrs, uint32_t index) {
return index == 0;
});
14 changes: 14 additions & 0 deletions nnvm/tests/python/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ def test_default_input():
except NNVMError:
pass

def test_mutate_input():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv')
z = sym.assign(x, y)
t = sym.add(z, x)

try:
z = sym.assign(z, z)
assert False
except NNVMError:
pass


if __name__ == "__main__":
test_default_input()
test_compose()
test_mutate_input()

0 comments on commit c965e24

Please sign in to comment.