Skip to content

Commit 486249e

Browse files
committed
Update mutate function (apache#23)
1 parent 16a6db3 commit 486249e

File tree

9 files changed

+65
-58
lines changed

9 files changed

+65
-58
lines changed

nnvm/docs/Doxyfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ DOXYFILE_ENCODING = UTF-8
3232
# title of most generated pages and in a few other places.
3333
# The default value is: My Project.
3434

35-
PROJECT_NAME = "mxnngraph"
35+
PROJECT_NAME = "nnvm"
3636

3737
# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
3838
# could be handy for archiving the generated documentation or if some version
@@ -753,7 +753,7 @@ WARN_LOGFILE =
753753
# spaces.
754754
# Note: If this tag is empty the current directory is searched.
755755

756-
INPUT = include
756+
INPUT = include/nnvm
757757

758758
# This tag can be used to specify the character encoding of the source files
759759
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses

nnvm/example/src/operator.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
namespace myproject {
1212

1313
using nnvm::FListInputNames;
14-
using nnvm::FMutateInput;
14+
using nnvm::FMutateInputs;
1515
using nnvm::FInferShape;
1616
using nnvm::FInferType;
1717
using nnvm::FInplaceOption;
@@ -119,8 +119,8 @@ NNVM_REGISTER_OP(add)
119119
NNVM_REGISTER_OP(assign)
120120
.set_num_inputs(2)
121121
.set_num_outputs(1)
122-
.attr<FMutateInput>("FMutateInput", [](const NodeAttrs& attrs, uint32_t index) {
123-
return index == 0;
122+
.attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
123+
return std::vector<uint32_t>{0};
124124
});
125125

126126
} // namespace myproject

nnvm/include/nnvm/graph.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,8 @@ class IndexedGraph {
144144
return nodes_[node_id(node)];
145145
}
146146
/*! \return list of argument nodes */
147-
inline const std::vector<uint32_t>& arg_nodes() const {
148-
return arg_nodes_;
147+
inline const std::vector<uint32_t>& input_nodes() const {
148+
return input_nodes_;
149149
}
150150
/*! \return list of output entries */
151151
inline const std::vector<NodeEntry>& outputs() const {
@@ -161,8 +161,8 @@ class IndexedGraph {
161161
explicit IndexedGraph(const Graph& other);
162162
// node pointers in CSR structure.
163163
std::vector<Node> nodes_;
164-
// index to argument nodes
165-
std::vector<uint32_t> arg_nodes_;
164+
// index to input nodes
165+
std::vector<uint32_t> input_nodes_;
166166
// space to store the outputs entries
167167
std::vector<NodeEntry> outputs_;
168168
// mapping from node to index.

nnvm/include/nnvm/op_attr_types.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,12 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
4343
/*!
4444
* \brief Check whether operator will mutate k-th input.
4545
* \param attrs The attributes of the node.
46-
* \param index The input index
47-
* \return Whether this operator will mutate index-th input.
46+
* \return list of input indices it mutates.
4847
*
49-
* \note Register under "FMutateInput", default return false
48+
* \note Register under "FMutateInputs", default return false
5049
* FMutateInputs enables mutation order handling correctly.
5150
*/
52-
using FMutateInput = std::function<bool (const NodeAttrs& attrs, uint32_t index)>;
51+
using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attrs)>;
5352

5453
/*!
5554
* \brief Inference function of certain type.

nnvm/include/nnvm/pass_functions.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,16 @@ inline Graph OrderMutation(Graph src) {
5454
/*!
5555
* \brief Infer shapes in the graph given the information.
5656
* \param graph source graph
57-
* \param shape_args The shapes of aruguments to the graph.
57+
* \param shape_inputs The shapes of aruguments to the graph.
5858
* \param shape_attr_key The key to the node attribute that can indicate shape.
5959
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
6060
* The index of ShapeVector is given by graph.indexed_graph().entry_id
6161
*/
6262
inline Graph InferShape(Graph graph,
63-
ShapeVector shape_args = {},
63+
ShapeVector shape_inputs = {},
6464
std::string shape_attr_key = "") {
65-
if (shape_args.size() != 0) {
66-
graph.attrs["shape_args"] = std::make_shared<any>(std::move(shape_args));
65+
if (shape_inputs.size() != 0) {
66+
graph.attrs["shape_inputs"] = std::make_shared<any>(std::move(shape_inputs));
6767
}
6868
if (shape_attr_key.length() != 0) {
6969
graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
@@ -74,19 +74,19 @@ inline Graph InferShape(Graph graph,
7474
/*!
7575
* \brief Infer types in the graph given the information.
7676
* \param graph source graph
77-
* \param shape_args The shapes of aruguments to the graph.
78-
* \param shape_attr_key The key to the node attribute that can indicate shape.
77+
* \param dtype_inputs The shapes of inputs to the graph.
78+
* \param dtype_attr_key The key to the node attribute that can indicate shape.
7979
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
8080
* The index of ShapeVector is given by graph.indexed_graph().entry_id
8181
*/
8282
inline Graph InferType(Graph graph,
83-
DTypeVector type_args = {},
84-
std::string type_attr_key = "") {
85-
if (type_args.size() != 0) {
86-
graph.attrs["dtype_args"] = std::make_shared<any>(std::move(type_args));
83+
DTypeVector dtype_inputs = {},
84+
std::string dtype_attr_key = "") {
85+
if (dtype_inputs.size() != 0) {
86+
graph.attrs["dtype_inputs"] = std::make_shared<any>(std::move(dtype_inputs));
8787
}
88-
if (type_attr_key.length() != 0) {
89-
graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(type_attr_key));
88+
if (dtype_attr_key.length() != 0) {
89+
graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(dtype_attr_key));
9090
}
9191
return ApplyPass(std::move(graph), {"InferType"});
9292
}

nnvm/src/core/graph.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
3030
nodes_.emplace_back(std::move(new_node));
3131
// arg_nodes_
3232
if (n->is_variable()) {
33-
arg_nodes_.push_back(nid);
33+
input_nodes_.push_back(nid);
3434
}
3535
// node2index_
3636
node2index_[n.get()] = nid;

nnvm/src/core/symbolic.cc

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,19 @@ NodePtr CreateVariableNode(const std::string& name) {
3131
// The version of that varaible will increase
3232
// version is used to implicitly order the mutation sequences
3333
inline void UpdateNodeVersion(Node *n) {
34-
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
34+
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
3535
for (NodeEntry& e : n->inputs) {
3636
if (e.node->is_variable()) {
3737
e.version = nnvm::get<VariableParam>(e.node->attrs.parsed).version;
3838
}
3939
}
4040
if (fmutate_inputs.count(n->op) != 0) {
41-
FMutateInput fmutate = fmutate_inputs[n->op];
42-
for (uint32_t i = 0; i < n->inputs.size(); ++i) {
43-
if (fmutate(n->attrs, i)) {
44-
NodeEntry& e = n->inputs[i];
45-
CHECK(e.node->is_variable())
46-
<< "Mutation target can only be Variable";
47-
// increase the version of the variable.
48-
e.version = ++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
49-
}
41+
for (uint32_t i : fmutate_inputs[n->op](n->attrs)) {
42+
NodeEntry& e = n->inputs[i];
43+
CHECK(e.node->is_variable())
44+
<< "Mutation target can only be Variable";
45+
// increase the version of the variable.
46+
e.version = ++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
5047
}
5148
}
5249
}
@@ -192,16 +189,13 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
192189
} else {
193190
std::unordered_set<Node*> mutable_set;
194191
std::vector<Node*> vlist;
195-
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
192+
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
196193
DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
197194
if (node->is_variable()) {
198195
vlist.push_back(node.get());
199196
} else if (fmutate_inputs.count(node->op)) {
200-
FMutateInput fmutate = fmutate_inputs[node->op];
201-
for (uint32_t i = 0; i < node->inputs.size(); ++i) {
202-
if (fmutate(node->attrs, i)) {
203-
mutable_set.insert(node->inputs[i].node.get());
204-
}
197+
for (uint32_t i : fmutate_inputs[node->op](node->attrs)){
198+
mutable_set.insert(node->inputs[i].node.get());
205199
}
206200
}
207201
});

nnvm/src/pass/infer_shape_type.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ template<typename AttrType, typename IsNone>
1515
Graph InferAttr(Graph &&ret,
1616
const AttrType def_value,
1717
const char* infer_name,
18-
const char* arg_name,
18+
const char* input_name,
1919
const char* attr_key_name,
2020
const char* attr_name,
2121
const char* unknown_name,
@@ -29,15 +29,15 @@ Graph InferAttr(Graph &&ret,
2929
// reshape shape vector
3030
AttrVector rshape(idx.num_node_entries(), def_value);
3131

32-
if (ret.attrs.count(arg_name) != 0) {
33-
const AttrVector& shape_args = ret.GetAttr<AttrVector>(arg_name);
34-
CHECK_LE(shape_args.size(), idx.arg_nodes().size())
32+
if (ret.attrs.count(input_name) != 0) {
33+
const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
34+
CHECK_LE(shape_args.size(), idx.input_nodes().size())
3535
<< "shape args is more than number of arguments";
3636
for (size_t i = 0; i < shape_args.size(); ++i) {
37-
rshape[idx.entry_id(idx.arg_nodes()[i], 0)] = shape_args[i];
37+
rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
3838
}
3939
// erase the provided arguments
40-
ret.attrs.erase(arg_name);
40+
ret.attrs.erase(input_name);
4141
}
4242
std::string shape_attr_key;
4343
if (ret.attrs.count(attr_key_name) != 0) {
@@ -113,7 +113,7 @@ NNVM_REGISTER_PASS(InferType)
113113
.set_body([](Graph ret) {
114114
return InferAttr<int>(
115115
std::move(ret), 0,
116-
"FInferType", "dtype_args", "dtype_attr_key",
116+
"FInferType", "dtype_inputs", "dtype_attr_key",
117117
"dtype", "dtype_num_unknown_nodes",
118118
[](const int t) { return t == -1; });
119119
})

nnvm/src/pass/order_mutation.cc

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map,
2121
return def;
2222
}
2323

24+
inline bool IsMutate(const std::vector<uint32_t>& mutate_inputs, uint32_t i) {
25+
if (mutate_inputs.size() == 0) return false;
26+
auto it = std::lower_bound(
27+
mutate_inputs.begin(), mutate_inputs.end(), i);
28+
return (it != mutate_inputs.end()) && (*it == i);
29+
}
30+
2431
Graph OrderMutation(const Graph& src) {
2532
std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
2633
DFSVisit(src.outputs, [&version_hist](const NodePtr& n) {
@@ -37,7 +44,13 @@ Graph OrderMutation(const Graph& src) {
3744
// start preparing for remapping the nodes.
3845
std::unordered_map<Node*, NodePtr> old_new;
3946
auto prepare = [&version_hist, &old_new] (const NodePtr& n) {
40-
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
47+
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
48+
std::vector<uint32_t> mutate_inputs;
49+
if (!n->is_variable() && fmutate_inputs.count(n->op)) {
50+
mutate_inputs = fmutate_inputs[n->op](n->attrs);
51+
}
52+
std::sort(mutate_inputs.begin(), mutate_inputs.end());
53+
4154
bool need_repl = false;
4255
for (size_t i = 0; i < n->inputs.size(); ++i) {
4356
const NodeEntry& e = n->inputs[i];
@@ -46,9 +59,7 @@ Graph OrderMutation(const Graph& src) {
4659
auto it = version_hist.find(e.node.get());
4760
if (it != version_hist.end()) {
4861
std::vector<NodeEntry>& vec = it->second;
49-
uint32_t is_mutate =
50-
fmutate_inputs.count(n->op) ? fmutate_inputs[n->op](n->attrs, i) : 0;
51-
vec.emplace_back(NodeEntry{n, is_mutate, e.version});
62+
vec.emplace_back(NodeEntry{n, IsMutate(mutate_inputs, i), e.version});
5263
}
5364
} else {
5465
if (old_new.count(e.node.get()) != 0) need_repl = true;
@@ -91,18 +102,21 @@ Graph OrderMutation(const Graph& src) {
91102
get_with_default(old_new, p.get(), p));
92103
}
93104
// add control deps
94-
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
105+
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
106+
std::vector<uint32_t> mutate_inputs;
107+
if (fmutate_inputs.count(kv.first->op)) {
108+
mutate_inputs = fmutate_inputs[kv.first->op](kv.first->attrs);
109+
}
110+
std::sort(mutate_inputs.begin(), mutate_inputs.end());
111+
95112
for (size_t i = 0; i < kv.first->inputs.size(); ++i) {
96113
const NodeEntry& e = kv.first->inputs[i];
97114
if (e.node->is_variable() && version_hist.count(e.node.get()) != 0) {
98-
FMutateInput fmutate = fmutate_inputs.get(kv.first->op, nullptr);
99-
uint32_t is_mutate = (fmutate == nullptr) ? 0 : fmutate(kv.first->attrs, i);
100115
std::vector<NodeEntry>& vec = version_hist.at(e.node.get());
101-
102116
auto it = std::lower_bound(vec.begin(), vec.end(),
103117
NodeEntry{nullptr, 1, e.version},
104118
comparator);
105-
if (is_mutate != 0) {
119+
if (IsMutate(mutate_inputs, i)) {
106120
int read_dep = 0;
107121
while (it != vec.begin()) {
108122
--it;

0 commit comments

Comments
 (0)