Skip to content

Commit de07699

Browse files
committed
[NODE] Move op inside node attribute (apache#30)
1 parent ac070f8 commit de07699

File tree

11 files changed

+58
-54
lines changed

11 files changed

+58
-54
lines changed

nnvm/example/src/operator.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ inline NodeEntry MakeNode(const char* op_name,
4646
std::string node_name,
4747
std::vector<NodeEntry> inputs) {
4848
NodePtr p = Node::Create();
49-
p->op = nnvm::Op::Get(op_name);
49+
p->attrs.op = nnvm::Op::Get(op_name);
5050
p->attrs.name = std::move(node_name);
5151
p->inputs = std::move(inputs);
5252
return NodeEntry{p, 0, 0};

nnvm/include/nnvm/node.h

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ struct NodeEntry {
4646
* Usually are additional parameters like axis,
4747
*/
4848
struct NodeAttrs {
49+
/*!
50+
* \brief The operator this node uses.
51+
* For place holder variable, op == nullptr.
52+
*/
53+
const Op *op{nullptr};
4954
/*! \brief name of the node */
5055
std::string name;
5156
/*! \brief Vector representation of positional attributes */
@@ -65,22 +70,19 @@ struct NodeAttrs {
6570
*/
6671
class Node {
6772
public:
68-
/*!
69-
* \brief The operator this node uses.
70-
* For place holder variable, op == nullptr.
71-
*/
72-
const Op *op{nullptr};
73+
/*! \brief The attributes in the node. */
74+
NodeAttrs attrs;
7375
/*! \brief inputs to this node */
7476
std::vector<NodeEntry> inputs;
7577
/*!
7678
* \brief Optional control flow dependencies
7779
* Gives operation must be performed before this operation.
7880
*/
7981
std::vector<NodePtr> control_deps;
80-
/*! \brief The attributes in the node. */
81-
NodeAttrs attrs;
8282
/*! \brief destructor of node */
8383
~Node();
84+
/*! \return operator in this node */
85+
inline const Op* op() const;
8486
/*!
8587
* \brief return whether node is placeholder variable.
8688
* This is equivalent to op == nullptr
@@ -99,25 +101,28 @@ class Node {
99101
};
100102

101103
// implementation of functions.
104+
inline const Op* Node::op() const {
105+
return this->attrs.op;
106+
}
102107
inline bool Node::is_variable() const {
103-
return this->op == nullptr;
108+
return this->op() == nullptr;
104109
}
105110

106111
inline uint32_t Node::num_outputs() const {
107112
if (is_variable()) return 1;
108-
if (this->op->get_num_outputs == nullptr) {
109-
return this->op->num_outputs;
113+
if (this->op()->get_num_outputs == nullptr) {
114+
return this->op()->num_outputs;
110115
} else {
111-
return this->op->get_num_outputs(this->attrs);
116+
return this->op()->get_num_outputs(this->attrs);
112117
}
113118
}
114119

115120
inline uint32_t Node::num_inputs() const {
116121
if (is_variable()) return 1;
117-
if (this->op->get_num_inputs == nullptr) {
118-
return this->op->num_inputs;
122+
if (this->op()->get_num_inputs == nullptr) {
123+
return this->op()->num_inputs;
119124
} else {
120-
return this->op->get_num_inputs(this->attrs);
125+
return this->op()->get_num_inputs(this->attrs);
121126
}
122127
}
123128

nnvm/include/nnvm/pass_functions.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <string>
1414
#include <memory>
15+
#include <vector>
1516
#include "./base.h"
1617
#include "./pass.h"
1718
#include "./graph_attr_types.h"

nnvm/src/core/graph.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ IndexedGraph::IndexedGraph(const Graph &g) {
6666
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
6767
nodes_[nid].inputs = array_view<NodeEntry>(
6868
iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
69-
if (nodes_[nid].source->op != nullptr &&
70-
fmutate_inputs.count(nodes_[nid].source->op)) {
71-
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](nodes_[nid].source->attrs)) {
69+
if (nodes_[nid].source->op() != nullptr &&
70+
fmutate_inputs.count(nodes_[nid].source->op())) {
71+
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op()](nodes_[nid].source->attrs)) {
7272
mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id);
7373
}
7474
}

nnvm/src/core/symbolic.cc

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct VariableParam {
2020

2121
NodePtr CreateVariableNode(const std::string& name) {
2222
NodePtr n = Node::Create();
23-
n->op = nullptr;
23+
n->attrs.op = nullptr;
2424
n->attrs.name = name;
2525
n->attrs.parsed = VariableParam();
2626
return n;
@@ -37,8 +37,8 @@ inline void UpdateNodeVersion(Node *n) {
3737
e.version = nnvm::get<VariableParam>(e.node->attrs.parsed).version;
3838
}
3939
}
40-
if (fmutate_inputs.count(n->op) != 0) {
41-
for (uint32_t i : fmutate_inputs[n->op](n->attrs)) {
40+
if (fmutate_inputs.count(n->op()) != 0) {
41+
for (uint32_t i : fmutate_inputs[n->op()](n->attrs)) {
4242
NodeEntry& e = n->inputs[i];
4343
CHECK(e.node->is_variable())
4444
<< "Mutation target can only be Variable";
@@ -96,7 +96,6 @@ Symbol Symbol::Copy() const {
9696
// use DFSVisit to copy all the nodes
9797
DFSVisit(this->outputs, [&old_new](const NodePtr& node) {
9898
NodePtr np = Node::Create();
99-
np->op = node->op;
10099
np->attrs = node->attrs;
101100
old_new[node.get()] = std::move(np);
102101
});
@@ -123,7 +122,7 @@ void Symbol::Print(std::ostream &os) const {
123122
if (outputs[0].node->is_variable()) {
124123
os << "Variable:" << outputs[0].node->attrs.name << '\n';
125124
} else {
126-
os << "AtomicFunctor "<< " Op:" << outputs[0].node->op->name << '\n';
125+
os << "AtomicFunctor "<< " Op:" << outputs[0].node->op()->name << '\n';
127126
}
128127
} else {
129128
// use DFSVisit to copy all the nodes
@@ -137,7 +136,7 @@ void Symbol::Print(std::ostream &os) const {
137136
os << "Variable:" << node->attrs.name << '\n';
138137
} else {
139138
os << "--------------------\n";
140-
os << "Op:" << node->op->name << ", Name=" << node->attrs.name << '\n'
139+
os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n'
141140
<< "Inputs:\n";
142141
for (size_t i = 0; i < node->inputs.size(); ++i) {
143142
const NodeEntry& e = node->inputs[i];
@@ -196,8 +195,8 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
196195
DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
197196
if (node->is_variable()) {
198197
vlist.push_back(node.get());
199-
} else if (fmutate_inputs.count(node->op)) {
200-
for (uint32_t i : fmutate_inputs[node->op](node->attrs)){
198+
} else if (fmutate_inputs.count(node->op())) {
199+
for (uint32_t i : fmutate_inputs[node->op()](node->attrs)){
201200
mutable_set.insert(node->inputs[i].node.get());
202201
}
203202
}
@@ -221,7 +220,7 @@ std::vector<std::string> Symbol::ListOutputNames() const {
221220
} else {
222221
const std::string& hname = head.node->attrs.name;
223222
std::string rname;
224-
FListOutputNames fn = flist_ouputs.get(head.node->op, nullptr);
223+
FListOutputNames fn = flist_ouputs.get(head.node->op(), nullptr);
225224
if (fn != nullptr) {
226225
rname = fn(head.node->attrs)[head.index];
227226
} else {
@@ -278,10 +277,10 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
278277
}
279278
// switch to keyword argument matching
280279
if (args.size() != n_req) {
281-
FListInputNames fn = flist_inputs.get(n->op, nullptr);
280+
FListInputNames fn = flist_inputs.get(n->op(), nullptr);
282281
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
283282
if (arg_names.size() != n_req) {
284-
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op->name;
283+
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name;
285284
}
286285
size_t nmatched = 0;
287286
for (size_t i = args.size(); i < n_req; ++i) {
@@ -422,8 +421,8 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
422421
node->attrs.dict[kv.first] = kv.second;
423422
}
424423
}
425-
if (node->op != nullptr && node->op->attr_parser != nullptr) {
426-
node->op->attr_parser(&(node->attrs));
424+
if (node->op() != nullptr && node->op()->attr_parser != nullptr) {
425+
node->op()->attr_parser(&(node->attrs));
427426
}
428427
}
429428

@@ -461,10 +460,10 @@ Symbol Symbol::CreateFunctor(const Op* op,
461460
std::unordered_map<std::string, std::string> attrs) {
462461
Symbol s;
463462
NodePtr n = Node::Create();
464-
n->op = op;
463+
n->attrs.op = op;
465464
n->attrs.dict = std::move(attrs);
466-
if (n->op->attr_parser != nullptr) {
467-
n->op->attr_parser(&(n->attrs));
465+
if (n->op()->attr_parser != nullptr) {
466+
n->op()->attr_parser(&(n->attrs));
468467
}
469468
s.outputs.emplace_back(NodeEntry{std::move(n), 0, 0});
470469
return s;

nnvm/src/pass/gradient.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
2020
return std::move(v[0]);
2121
} else if (v.size() == 0) {
2222
NodePtr zero_node = Node::Create();
23-
zero_node->op = Op::Get("__zero__");
23+
zero_node->attrs.op = Op::Get("__zero__");
2424
return NodeEntry{zero_node, 0, 0};
2525
} else {
2626
NodePtr sum_node = Node::Create();
27-
sum_node->op = Op::Get("__ewise_sum__");
27+
sum_node->attrs.op = Op::Get("__ewise_sum__");
2828
sum_node->inputs = std::move(v);
2929
return NodeEntry{sum_node, 0, 0};
3030
}
@@ -109,7 +109,7 @@ Graph Gradient(Graph src) {
109109
e.sum = agg_fun(std::move(e.grads));
110110
out_agg_grads.push_back(e.sum);
111111
}
112-
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op]
112+
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()]
113113
(mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads);
114114
auto git = input_grads.begin();
115115
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {

nnvm/src/pass/infer_shape_type.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ Graph InferAttr(Graph &&ret,
6565
}
6666
continue;
6767
}
68-
if (finfer_shape.count(inode.source->op)) {
68+
if (finfer_shape.count(inode.source->op())) {
6969
ishape.resize(num_inputs, def_value);
7070
for (uint32_t i = 0; i < ishape.size(); ++i) {
7171
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
@@ -75,14 +75,14 @@ Graph InferAttr(Graph &&ret,
7575
oshape[i] = rshape[idx.entry_id(nid, i)];
7676
}
7777
num_unknown +=
78-
!(finfer_shape[inode.source->op](inode.source->attrs, &ishape, &oshape));
78+
!(finfer_shape[inode.source->op()](inode.source->attrs, &ishape, &oshape));
7979
for (uint32_t i = 0; i < num_inputs; ++i) {
8080
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
8181
}
8282
for (uint32_t i = 0; i < num_outputs; ++i) {
8383
rshape[idx.entry_id(nid, i)] = oshape[i];
8484
}
85-
} else if (is_backward.get(inode.source->op, false)) {
85+
} else if (is_backward.get(inode.source->op(), false)) {
8686
// backward operator inference.
8787
CHECK_GE(inode.control_deps.size(), 1)
8888
<< "BackwardOp need to have control_deps to its forward op";

nnvm/src/pass/order_mutation.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ Graph OrderMutation(const Graph& src) {
4343
auto prepare = [&version_hist, &old_new] (const NodePtr& n) {
4444
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
4545
std::vector<uint32_t> mutate_inputs;
46-
if (!n->is_variable() && fmutate_inputs.count(n->op)) {
47-
mutate_inputs = fmutate_inputs[n->op](n->attrs);
46+
if (!n->is_variable() && fmutate_inputs.count(n->op())) {
47+
mutate_inputs = fmutate_inputs[n->op()](n->attrs);
4848
}
4949
std::sort(mutate_inputs.begin(), mutate_inputs.end());
5050

@@ -67,7 +67,6 @@ Graph OrderMutation(const Graph& src) {
6767
}
6868
if (need_repl) {
6969
NodePtr np = Node::Create();
70-
np->op = n->op;
7170
np->attrs = n->attrs;
7271
old_new[n.get()] = std::move(np);
7372
}
@@ -101,8 +100,8 @@ Graph OrderMutation(const Graph& src) {
101100
// add control deps
102101
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
103102
std::vector<uint32_t> mutate_inputs;
104-
if (fmutate_inputs.count(kv.first->op)) {
105-
mutate_inputs = fmutate_inputs[kv.first->op](kv.first->attrs);
103+
if (fmutate_inputs.count(kv.first->op())) {
104+
mutate_inputs = fmutate_inputs[kv.first->op()](kv.first->attrs);
106105
}
107106
std::sort(mutate_inputs.begin(), mutate_inputs.end());
108107

nnvm/src/pass/place_device.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ Graph PlaceDevice(Graph src) {
109109
NodeEntry{it->second, 0, 0});
110110
} else {
111111
NodePtr copy_node = Node::Create();
112-
copy_node->op = copy_op;
113112
std::ostringstream os;
114113
os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy";
114+
copy_node->attrs.op = copy_op;
115115
copy_node->attrs.name = os.str();
116116
copy_node->inputs.push_back(inode.source->inputs[i]);
117117
copy_map[copy_key] = copy_node;

nnvm/src/pass/plan_memory.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ Graph PlanMemory(Graph ret) {
168168
const auto& inode = idx[nid];
169169
if (inode.source->is_variable()) continue;
170170
// check inplace option
171-
if (finplace_option.count(inode.source->op) != 0) {
172-
auto inplace_pairs = finplace_option[inode.source->op](inode.source->attrs);
171+
if (finplace_option.count(inode.source->op()) != 0) {
172+
auto inplace_pairs = finplace_option[inode.source->op()](inode.source->attrs);
173173
for (auto& kv : inplace_pairs) {
174174
uint32_t eid_out = idx.entry_id(nid, kv.second);
175175
uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);

0 commit comments

Comments
 (0)