Skip to content

Commit 254d177

Browse files
committed
Rename shared_ptr<Node> to NodePtr (#8)
1 parent 9a956a8 commit 254d177

File tree

5 files changed

+36
-27
lines changed

5 files changed

+36
-27
lines changed

nnvm/include/nnvm/node.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,19 @@ namespace nnvm {
1818
// Forward declare node.
1919
class Node;
2020

21+
/*!
22+
* \brief we always used NodePtr for a reference pointer
23+
* to the node, so this alias can be changed in case we need
24+
* even faster graph composition than 3M ops/sec.
25+
*
26+
* By default, NodePtr is a std::shared_ptr of node
27+
*/
28+
using NodePtr = std::shared_ptr<Node>;
29+
2130
/*! \brief an entry that represents output data from a node */
2231
struct NodeEntry {
2332
/*! \brief the source node of this data */
24-
std::shared_ptr<Node> node;
33+
NodePtr node;
2534
/*! \brief index of output from the source. */
2635
uint32_t index;
2736
/*!
@@ -66,7 +75,7 @@ class Node {
6675
* \brief Optional control flow dependencies
6776
* Gives operation must be performed before this operation.
6877
*/
69-
std::vector<std::shared_ptr<Node> > control_deps;
78+
std::vector<NodePtr> control_deps;
7079
/*! \brief The attributes in the node. */
7180
NodeAttrs attrs;
7281
/*! \brief destructor of node */
@@ -85,7 +94,7 @@ class Node {
8594
* \brief create a new empty shared_ptr of Node.
8695
* \return a created empty node.
8796
*/
88-
static std::shared_ptr<Node> Create();
97+
static NodePtr Create();
8998
};
9099

91100
// implementation of functions.

nnvm/src/core/node.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Node::~Node() {
2424
e.node.reset();
2525
}
2626
}
27-
for (std::shared_ptr<Node>& sp : n->control_deps) {
27+
for (NodePtr& sp : n->control_deps) {
2828
if (sp.unique()) {
2929
stack.push_back(sp.get());
3030
} else {
@@ -36,7 +36,7 @@ Node::~Node() {
3636
}
3737
}
3838

39-
std::shared_ptr<Node> Node::Create() {
39+
NodePtr Node::Create() {
4040
// NOTE: possible change to thread local memory pool
4141
// via std::allocate_shared instead for faster allocation.
4242
return std::make_shared<Node>();

nnvm/src/core/symbolic.cc

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ struct VariableParam {
1818
uint32_t version{0};
1919
};
2020

21-
std::shared_ptr<Node> CreateVariableNode(const std::string& name) {
22-
std::shared_ptr<Node> n = Node::Create();
21+
NodePtr CreateVariableNode(const std::string& name) {
22+
NodePtr n = Node::Create();
2323
n->op = nullptr;
2424
n->attrs.name = name;
2525
n->attrs.parsed = VariableParam();
@@ -95,10 +95,10 @@ inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
9595

9696
// public functions
9797
Symbol Symbol::Copy() const {
98-
std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
98+
std::unordered_map<Node*, NodePtr> old_new;
9999
// use DFSVisit to copy all the nodes
100-
DFSVisit(this->outputs, [&old_new](const std::shared_ptr<Node>& node) {
101-
std::shared_ptr<Node> np = Node::Create();
100+
DFSVisit(this->outputs, [&old_new](const NodePtr& node) {
101+
NodePtr np = Node::Create();
102102
np->op = node->op;
103103
np->attrs = node->attrs;
104104
old_new[node.get()] = std::move(np);
@@ -109,7 +109,7 @@ Symbol Symbol::Copy() const {
109109
Node *ptr = e.node.get();
110110
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version});
111111
}
112-
for (const std::shared_ptr<Node>& p : kv.first->control_deps) {
112+
for (const NodePtr& p : kv.first->control_deps) {
113113
kv.second->control_deps.emplace_back(old_new[p.get()]);
114114
}
115115
}
@@ -131,7 +131,7 @@ void Symbol::Print(std::ostream &os) const {
131131
os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name
132132
<< '(' << outputs[i].index << ")\n";
133133
}
134-
DFSVisit(this->outputs, [&os](const std::shared_ptr<Node>& node) {
134+
DFSVisit(this->outputs, [&os](const NodePtr& node) {
135135
if (node->is_variable()) {
136136
os << "Variable:" << node->attrs.name << '\n';
137137
} else {
@@ -179,7 +179,7 @@ Symbol Symbol::operator[] (size_t index) const {
179179

180180
std::vector<std::string> Symbol::ListArguments() const {
181181
std::vector<std::string> ret;
182-
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node> &node) {
182+
DFSVisit(this->outputs, [&ret](const NodePtr &node) {
183183
if (node->is_variable()) {
184184
ret.push_back(node->attrs.name);
185185
}
@@ -295,7 +295,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
295295
std::unordered_map<Node *, const NodeEntry*> replace_map;
296296
// replace map stores the existing replacement plan for arguments node
297297
auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map]
298-
(const std::shared_ptr<Node> &node) {
298+
(const NodePtr &node) {
299299
if (node->is_variable()) {
300300
if (arg_counter < args.size()) {
301301
replace_map[node.get()] = &(args[arg_counter]->outputs[0]);
@@ -316,7 +316,7 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
316316
std::vector<Node*> update_nodes;
317317
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
318318
auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes]
319-
(const std::shared_ptr<Node> &node) {
319+
(const NodePtr &node) {
320320
// visit all the childs, find possible replacement
321321
bool repl = false;
322322
for (size_t i = 0; i < node->inputs.size(); ++i) {
@@ -368,7 +368,7 @@ void Symbol::AddControlDeps(const Symbol& src) {
368368

369369
Symbol Symbol::GetInternals() const {
370370
Symbol ret;
371-
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node>& node) {
371+
DFSVisit(this->outputs, [&ret](const NodePtr& node) {
372372
Node* n = node.get();
373373
if (n->is_variable()) {
374374
// grab version from variable.
@@ -421,7 +421,7 @@ bool Symbol::GetAttr(const std::string& key, std::string* out) const {
421421
std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption option) const {
422422
if (option == kRecursive) {
423423
std::unordered_map<std::string, std::string> ret;
424-
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node>& n) {
424+
DFSVisit(this->outputs, [&ret](const NodePtr& n) {
425425
for (const auto& it : n->attrs.dict) {
426426
ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second;
427427
}
@@ -435,7 +435,7 @@ std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption op
435435
Symbol Symbol::CreateFunctor(const Op* op,
436436
std::unordered_map<std::string, std::string>&& attrs) {
437437
Symbol s;
438-
std::shared_ptr<Node> n = Node::Create();
438+
NodePtr n = Node::Create();
439439
n->op = op;
440440
n->attrs.dict = std::move(attrs);
441441
if (n->op->attr_parser != nullptr) {

nnvm/src/pass/order_mutation.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map,
2121

2222
Graph OrderMutation(const Graph& src) {
2323
std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
24-
DFSVisit(src.outputs, [&version_hist](const std::shared_ptr<Node>& n) {
24+
DFSVisit(src.outputs, [&version_hist](const NodePtr& n) {
2525
for (const NodeEntry& e : n->inputs) {
2626
if (e.node->is_variable()) {
2727
if (e.version != 0 && version_hist.count(e.node.get()) == 0) {
@@ -33,8 +33,8 @@ Graph OrderMutation(const Graph& src) {
3333
// no mutation happens, everything if fine.
3434
if (version_hist.size() == 0) return src;
3535
// start preparing for remapping the nodes.
36-
std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
37-
auto prepare = [&version_hist, &old_new] (const std::shared_ptr<Node>& n) {
36+
std::unordered_map<Node*, NodePtr> old_new;
37+
auto prepare = [&version_hist, &old_new] (const NodePtr& n) {
3838
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
3939
bool need_repl = false;
4040
for (size_t i = 0; i < n->inputs.size(); ++i) {
@@ -52,11 +52,11 @@ Graph OrderMutation(const Graph& src) {
5252
if (old_new.count(e.node.get()) != 0) need_repl = true;
5353
}
5454
}
55-
for (const std::shared_ptr<Node>& p : n->control_deps) {
55+
for (const NodePtr& p : n->control_deps) {
5656
if (old_new.count(p.get()) != 0) need_repl = true;
5757
}
5858
if (need_repl) {
59-
std::shared_ptr<Node> np = Node::Create();
59+
NodePtr np = Node::Create();
6060
np->op = n->op;
6161
np->attrs = n->attrs;
6262
old_new[n.get()] = std::move(np);
@@ -84,7 +84,7 @@ Graph OrderMutation(const Graph& src) {
8484
kv.second->inputs.push_back(e);
8585
}
8686
}
87-
for (const std::shared_ptr<Node>& p : kv.first->control_deps) {
87+
for (const NodePtr& p : kv.first->control_deps) {
8888
kv.second->control_deps.emplace_back(
8989
get_with_default(old_new, p.get(), p));
9090
}

nnvm/src/pass/saveload_json.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ struct JSONNode {
3232
// the node entry structure in serialized format
3333
typedef std::pair<uint32_t, uint32_t> Entry;
3434
// pointer to the graph node
35-
std::shared_ptr<Node> node;
35+
NodePtr node;
3636
// inputs
3737
std::vector<Entry> inputs;
3838
// control flow dependencies
@@ -159,7 +159,7 @@ Graph LoadJSON(const Graph& src) {
159159
Graph SaveJSON(const Graph& src) {
160160
JSONGraph jgraph;
161161
std::unordered_map<Node*, uint32_t> node2index;
162-
DFSVisit(src.outputs, [&node2index, &jgraph](const std::shared_ptr<Node>& n) {
162+
DFSVisit(src.outputs, [&node2index, &jgraph](const NodePtr& n) {
163163
uint32_t nid = static_cast<uint32_t>(jgraph.nodes.size());
164164
node2index[n.get()] = nid;
165165
if (n->is_variable()) {
@@ -172,7 +172,7 @@ Graph SaveJSON(const Graph& src) {
172172
jnode.inputs.emplace_back(
173173
std::make_pair(node2index.at(e.node.get()), e.index));
174174
}
175-
for (const std::shared_ptr<Node>& c : n->control_deps) {
175+
for (const NodePtr& c : n->control_deps) {
176176
jnode.control_deps.push_back(node2index.at(c.get()));
177177
}
178178
jgraph.nodes.emplace_back(std::move(jnode));

0 commit comments

Comments
 (0)