Skip to content

Commit db9a9a7

Browse files
committed
[PASS] Add order mutation (apache#7)
* [PASS] Add order mutation * A few benchmarks on compose speed
1 parent badcdff commit db9a9a7

File tree

9 files changed

+296
-63
lines changed

9 files changed

+296
-63
lines changed

nnvm/include/nnvm/symbolic.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ class Symbol {
7373
* \param kwargs keyword arguments for the symbol
7474
* \param name name of returned symbol.
7575
*/
76-
void Compose(const std::vector<Symbol>& args,
77-
const std::unordered_map<std::string, Symbol>& kwargs,
76+
void Compose(const array_view<const Symbol*>& args,
77+
const std::unordered_map<std::string, const Symbol*>& kwargs,
7878
const std::string& name);
7979
/*!
8080
* \brief Apply the symbol as a function, compose with arguments
@@ -84,8 +84,8 @@ class Symbol {
8484
* \param name name of returned symbol.
8585
* \return a new Symbol which is the composition of current symbol with its arguments
8686
*/
87-
Symbol operator () (const std::vector<Symbol>& args,
88-
const std::unordered_map<std::string, Symbol>& kwargs,
87+
Symbol operator () (const array_view<const Symbol*>& args,
88+
const std::unordered_map<std::string, const Symbol*>& kwargs,
8989
const std::string& name) const;
9090
/*!
9191
* \brief Add control flow depenencies to operators involved in symbols.

nnvm/src/c_api/c_api_common.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <dmlc/logging.h>
1111
#include <dmlc/thread_local.h>
1212
#include <nnvm/c_api.h>
13+
#include <nnvm/symbolic.h>
1314
#include <vector>
1415
#include <string>
1516

@@ -36,6 +37,8 @@ struct NNAPIThreadLocalEntry {
3637
std::vector<const char *> ret_vec_charp;
3738
/*! \brief result holder for returning handles */
3839
std::vector<void *> ret_handles;
40+
/*! \brief argument holder to hold symbol */
41+
std::unordered_map<std::string, const nnvm::Symbol*> kwarg_symbol;
3942
};
4043

4144
/*! \brief Thread local store that can be used to hold return values. */

nnvm/src/c_api/c_api_symbolic.cc

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -217,22 +217,26 @@ int NNSymbolCompose(SymbolHandle sym,
217217
const char** keys,
218218
SymbolHandle* args) {
219219
API_BEGIN();
220-
std::string s_name;
221-
if (name != nullptr) s_name = name;
222-
220+
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
221+
std::string& s_name = ret->ret_str;
222+
std::unordered_map<std::string, const Symbol*>& kwargs
223+
= ret->kwarg_symbol;
224+
if (name != nullptr) {
225+
s_name = name;
226+
} else {
227+
s_name.clear();
228+
}
223229
Symbol* s = static_cast<Symbol*>(sym);
224230
if (keys == nullptr && num_args != 0) {
225-
std::vector<Symbol> pos_args;
226-
for (nn_uint i = 0; i < num_args; ++i) {
227-
pos_args.push_back(*((Symbol*)args[i])); // NOLINT(*)
228-
}
229-
s->Compose(pos_args, {}, s_name);
231+
kwargs.clear();
232+
array_view<const Symbol*> parg(
233+
(Symbol**)args, (Symbol**)args + num_args); // NOLINT(*)
234+
s->Compose(parg, kwargs, s_name);
230235
} else {
231-
std::unordered_map<std::string, Symbol> kwargs;
232236
for (nn_uint i = 0; i < num_args; ++i) {
233-
kwargs[keys[i]] = *((Symbol*)args[i]); // NOLINT(*)
237+
kwargs[keys[i]] = (Symbol*)args[i]; // NOLINT(*)
234238
}
235-
s->Compose({}, kwargs, s_name);
239+
s->Compose(array_view<const Symbol*>(), kwargs, s_name);
236240
}
237241
API_END();
238242
}

nnvm/src/core/symbolic.cc

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ inline void UpdateNodeVersion(Node *n) {
4545
CHECK(e.node->is_variable())
4646
<< "Mutation target can only be Variable";
4747
// increase the version of the variable.
48-
++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
48+
e.version = ++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
4949
}
5050
}
5151
}
@@ -98,14 +98,20 @@ Symbol Symbol::Copy() const {
9898
std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
9999
// use DFSVisit to copy all the nodes
100100
DFSVisit(this->outputs, [&old_new](const std::shared_ptr<Node>& node) {
101-
old_new[node.get()] = std::make_shared<Node>(*node);
101+
std::shared_ptr<Node> np = Node::Create();
102+
np->op = node->op;
103+
np->attrs = node->attrs;
104+
old_new[node.get()] = std::move(np);
102105
});
103106
// connect nodes of new graph
104107
for (const auto &kv : old_new) {
105108
for (const NodeEntry& e : kv.first->inputs) {
106109
Node *ptr = e.node.get();
107110
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version});
108111
}
112+
for (const std::shared_ptr<Node>& p : kv.first->control_deps) {
113+
kv.second->control_deps.emplace_back(old_new[p.get()]);
114+
}
109115
}
110116
// set the head
111117
Symbol ret;
@@ -120,7 +126,7 @@ void Symbol::Print(std::ostream &os) const {
120126
os << "AtomicFunctor "<< " Op:" << outputs[0].node->op->name << '\n';
121127
} else {
122128
// use DFSVisit to copy all the nodes
123-
os << "Outputs:\n";
129+
os << "Symbol Outputs:\n";
124130
for (size_t i = 0; i < outputs.size(); ++i) {
125131
os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name
126132
<< '(' << outputs[i].index << ")\n";
@@ -129,7 +135,8 @@ void Symbol::Print(std::ostream &os) const {
129135
if (node->is_variable()) {
130136
os << "Variable:" << node->attrs.name << '\n';
131137
} else {
132-
os << "Name: " << node->attrs.name << " Op:" << node->op->name << '\n'
138+
os << "--------------------\n";
139+
os << "Op:" << node->op->name << ", Name=" << node->attrs.name << '\n'
133140
<< "Inputs:\n";
134141
for (size_t i = 0; i < node->inputs.size(); ++i) {
135142
const NodeEntry& e = node->inputs[i];
@@ -141,9 +148,17 @@ void Symbol::Print(std::ostream &os) const {
141148
os << '\n';
142149
}
143150
}
144-
os << "Attrs:\n";
145-
for (auto &kv : node->attrs.dict) {
146-
os << '\t' << kv.first << '=' << kv.second << '\n';
151+
if (!node->attrs.dict.empty()) {
152+
os << "Attrs:\n";
153+
for (auto &kv : node->attrs.dict) {
154+
os << '\t' << kv.first << '=' << kv.second << '\n';
155+
}
156+
}
157+
if (node->control_deps.size() != 0) {
158+
os << "Control deps:\n";
159+
for (size_t i = 0; i < node->control_deps.size(); ++i) {
160+
os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n';
161+
}
147162
}
148163
}
149164
});
@@ -203,8 +218,8 @@ std::vector<std::string> Symbol::ListOutputs() const {
203218
}
204219

205220
// compositional logic
206-
void Symbol::Compose(const std::vector<Symbol>& args,
207-
const std::unordered_map<std::string, Symbol>& kwargs,
221+
void Symbol::Compose(const array_view<const Symbol*>& args,
222+
const std::unordered_map<std::string, const Symbol*>& kwargs,
208223
const std::string& name) {
209224
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
210225

@@ -213,11 +228,11 @@ void Symbol::Compose(const std::vector<Symbol>& args,
213228
CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
214229
// parameter check.
215230
for (size_t i = 0; i < args.size(); ++i) {
216-
CHECK_EQ(args[i].outputs.size(), 1)
231+
CHECK_EQ(args[i]->outputs.size(), 1)
217232
<< "Argument " << i << " is a tuple, single value is required";
218233
}
219234
for (const auto& kv : kwargs) {
220-
CHECK_EQ(kv.second.outputs.size(), 1)
235+
CHECK_EQ(kv.second->outputs.size(), 1)
221236
<< "Keyword Argument " << kv.first << " is a tuple, single value is required";
222237
}
223238
// assign new name
@@ -234,7 +249,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
234249
<< "Incorrect number of arguments, requires " << n_req
235250
<< ", provided " << args.size();
236251
for (size_t i = 0; i < args.size(); ++i) {
237-
n->inputs[i] = args[i].outputs[0];
252+
n->inputs[i] = args[i]->outputs[0];
238253
}
239254
// switch to keyword argument matching
240255
if (args.size() != n_req) {
@@ -247,7 +262,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
247262
for (size_t i = args.size(); i < n_req; ++i) {
248263
auto it = kwargs.find(arg_names[i]);
249264
if (it != kwargs.end() && it->first == arg_names[i]) {
250-
n->inputs[i] = it->second.outputs[0];
265+
n->inputs[i] = it->second->outputs[0];
251266
++nmatched;
252267
} else {
253268
n->inputs[i] = NodeEntry{
@@ -266,8 +281,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
266281
} else {
267282
CHECK_EQ(kwargs.size(), 0) << "Variable length function do not accept kwargs";
268283
n->inputs.reserve(args.size());
269-
for (const Symbol& s : args) {
270-
n->inputs.push_back(s.outputs[0]);
284+
for (const Symbol* s : args) {
285+
n->inputs.push_back(s->outputs[0]);
271286
}
272287
}
273288
UpdateNodeVersion(n);
@@ -283,13 +298,13 @@ void Symbol::Compose(const std::vector<Symbol>& args,
283298
(const std::shared_ptr<Node> &node) {
284299
if (node->is_variable()) {
285300
if (arg_counter < args.size()) {
286-
replace_map[node.get()] = &(args[arg_counter].outputs[0]);
301+
replace_map[node.get()] = &(args[arg_counter]->outputs[0]);
287302
++arg_counter;
288303
} else {
289304
// match kwargs
290305
auto kit = kwargs.find(node->attrs.name);
291306
if (kit != kwargs.end()) {
292-
replace_map[node.get()] = &(kit->second.outputs[0]);
307+
replace_map[node.get()] = &(kit->second->outputs[0]);
293308
++nmatched;
294309
}
295310
}
@@ -334,8 +349,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
334349
}
335350
}
336351

337-
Symbol Symbol::operator () (const std::vector<Symbol>& args,
338-
const std::unordered_map<std::string, Symbol>& kwargs,
352+
Symbol Symbol::operator () (const array_view<const Symbol*>& args,
353+
const std::unordered_map<std::string, const Symbol*>& kwargs,
339354
const std::string& name) const {
340355
Symbol s = this->Copy();
341356
s.Compose(args, kwargs, name);

nnvm/src/pass/order_mutation.cc

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
/*!
2+
* Copyright (c) 2016 by Contributors
3+
* \file saveload_json.cc
4+
* \brief Add control flow dependencies between nodes
5+
* To correctly order mutation and read to resolve
6+
* write after read problem and read after write problems.
7+
*/
8+
#include <nnvm/pass.h>
9+
#include <nnvm/op_attr_types.h>
10+
11+
namespace nnvm {
12+
13+
template<typename T>
14+
inline T get_with_default(const std::unordered_map<Node*, T> &map,
15+
Node* key,
16+
const T& def) {
17+
auto it = map.find(key);
18+
if (it != map.end()) return it->second;
19+
return def;
20+
}
21+
22+
Graph OrderMutation(const Graph& src) {
23+
std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
24+
DFSVisit(src.outputs, [&version_hist](const std::shared_ptr<Node>& n) {
25+
for (const NodeEntry& e : n->inputs) {
26+
if (e.node->is_variable()) {
27+
if (e.version != 0 && version_hist.count(e.node.get()) == 0) {
28+
version_hist[e.node.get()] = std::vector<NodeEntry>{};
29+
}
30+
}
31+
}
32+
});
33+
// no mutation happens, everything if fine.
34+
if (version_hist.size() == 0) return src;
35+
// start preparing for remapping the nodes.
36+
std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
37+
auto prepare = [&version_hist, &old_new] (const std::shared_ptr<Node>& n) {
38+
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
39+
bool need_repl = false;
40+
for (size_t i = 0; i < n->inputs.size(); ++i) {
41+
const NodeEntry& e = n->inputs[i];
42+
if (e.node->is_variable()) {
43+
if (e.version != 0) need_repl = true;
44+
auto it = version_hist.find(e.node.get());
45+
if (it != version_hist.end()) {
46+
std::vector<NodeEntry>& vec = it->second;
47+
uint32_t is_mutate =
48+
fmutate_inputs.count(n->op) ? fmutate_inputs[n->op](n->attrs, i) : 0;
49+
vec.emplace_back(NodeEntry{n, is_mutate, e.version});
50+
}
51+
} else {
52+
if (old_new.count(e.node.get()) != 0) need_repl = true;
53+
}
54+
}
55+
for (const std::shared_ptr<Node>& p : n->control_deps) {
56+
if (old_new.count(p.get()) != 0) need_repl = true;
57+
}
58+
if (need_repl) {
59+
std::shared_ptr<Node> np = Node::Create();
60+
np->op = n->op;
61+
np->attrs = n->attrs;
62+
old_new[n.get()] = std::move(np);
63+
}
64+
};
65+
DFSVisit(src.outputs, prepare);
66+
// comparator of history entry
67+
auto comparator = [](const NodeEntry& a, const NodeEntry &b) {
68+
if (a.version < b.version) return true;
69+
if (a.version > b.version) return false;
70+
return a.index > b.index;
71+
};
72+
73+
for (auto &kv : version_hist) {
74+
std::sort(kv.second.begin(), kv.second.end(), comparator);
75+
}
76+
// copy the nodes, as well as add control deps
77+
for (auto &kv : old_new) {
78+
// copy the nodes
79+
for (const NodeEntry& e : kv.first->inputs) {
80+
auto it = old_new.find(e.node.get());
81+
if (it != old_new.end()) {
82+
kv.second->inputs.emplace_back(NodeEntry{it->second, e.index, e.version});
83+
} else {
84+
kv.second->inputs.push_back(e);
85+
}
86+
}
87+
for (const std::shared_ptr<Node>& p : kv.first->control_deps) {
88+
kv.second->control_deps.emplace_back(
89+
get_with_default(old_new, p.get(), p));
90+
}
91+
// add control deps
92+
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
93+
for (size_t i = 0; i < kv.first->inputs.size(); ++i) {
94+
const NodeEntry& e = kv.first->inputs[i];
95+
if (e.node->is_variable() && version_hist.count(e.node.get()) != 0) {
96+
FMutateInput fmutate = fmutate_inputs.get(kv.first->op, nullptr);
97+
uint32_t is_mutate = (fmutate == nullptr) ? 0 : fmutate(kv.first->attrs, i);
98+
std::vector<NodeEntry>& vec = version_hist.at(e.node.get());
99+
100+
auto it = std::lower_bound(vec.begin(), vec.end(),
101+
NodeEntry{nullptr, 1, e.version},
102+
comparator);
103+
if (is_mutate != 0) {
104+
int read_dep = 0;
105+
while (it != vec.begin()) {
106+
--it;
107+
if (it->index != 0) break;
108+
++read_dep;
109+
// depend on previous read
110+
kv.second->control_deps.push_back(
111+
get_with_default(old_new, it->node.get(), it->node));
112+
}
113+
if (read_dep == 0 && it->index != 0) {
114+
// depend on last write
115+
kv.second->control_deps.push_back(
116+
get_with_default(old_new, it->node.get(), it->node));
117+
}
118+
} else {
119+
// depend on last write
120+
if (it->index != 0) {
121+
kv.second->control_deps.push_back(
122+
get_with_default(old_new, it->node.get(), it->node));
123+
}
124+
}
125+
}
126+
}
127+
}
128+
Graph ret;
129+
for (const NodeEntry &e : src.outputs) {
130+
ret.outputs.emplace_back(NodeEntry{
131+
get_with_default(old_new, e.node.get(), e.node), e.index, e.version});
132+
}
133+
return ret;
134+
}
135+
136+
NNVM_REGISTER_PASS(OrderMutation)
137+
.describe("Return a new graph that adds control dependencies, "\
138+
"to order the mutation and reads if mutation exists.")
139+
.set_body(OrderMutation)
140+
.set_change_graph(true);
141+
142+
} // namespace nnvm

nnvm/src/pass/saveload_json.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/*!
22
* Copyright (c) 2016 by Contributors
33
* \file saveload_json.cc
4-
* \brief Passes that defines save and load graph to/from JSON file.
4+
* \brief Save and load graph to/from JSON file.
55
*/
66
#include <nnvm/pass.h>
77
#include <dmlc/json.h>

0 commit comments

Comments
 (0)