Skip to content

Commit 7c3d18c

Browse files
committed
Enable copy on write in graph attrs (apache#31)
* [INFER] Enhance backward op policy * [SYMBOL] add list inputs * relax graph attr to enable copy-on-write
1 parent de07699 commit 7c3d18c

File tree

8 files changed

+93
-33
lines changed

8 files changed

+93
-33
lines changed

nnvm/include/nnvm/graph.h

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,34 @@ class Graph {
3030
std::vector<NodeEntry> outputs;
3131
/*!
3232
* \brief attributes of a graph
33-
* Each attribute is immutable,
34-
* and can be shared across multiple Instance of graph
33+
* Note that attribute is shared pointer and can be shared across graphs.
34+
*
35+
* It is highly recommended to keep each attribute immutable.
36+
 *  It is also safe to implement copy-on-write semantics.
37+
*
38+
* Copy when shared_ptr.unique is not true, while reuse original space
39+
* when shared_ptr.unique is true.
3540
*/
36-
std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
41+
std::unordered_map<std::string, std::shared_ptr<any> > attrs;
3742
/*!
38-
* \brief Get the attribute from attrs.
43+
* \brief Get the immutable attribute from attrs.
3944
* \param attr_name the name of the attribute
4045
* \return the reference to corresponding attribute
4146
* \tparam T the type of the attribute.
4247
*/
4348
template<typename T>
4449
inline const T& GetAttr(const std::string& attr_name);
50+
/*!
51+
 * \brief Get a moved copy of the attribute, implementing copy-on-write semantics.
52+
* The content is moved if the reference counter of shared_ptr is 1.
53+
* The attribute is erased from attrs after the call.
54+
*
55+
* \param attr_name the name of the attribute
56+
* \return a new copy of the corresponding attribute.
57+
* \tparam T the type of the attribute.
58+
*/
59+
template<typename T>
60+
inline T MoveCopyAttr(const std::string& attr_name);
4561
/*!
4662
 * \brief get an indexed graph of the current graph; if it does not exist, create it on demand
4763
* \return The indexed graph.
@@ -200,6 +216,20 @@ inline const T& Graph::GetAttr(const std::string& attr_name) {
200216
return nnvm::get<T>(*it->second);
201217
}
202218

219+
template<typename T>
220+
inline T Graph::MoveCopyAttr(const std::string& attr_name) {
221+
auto it = attrs.find(attr_name);
222+
CHECK(it != attrs.end())
223+
<< "Cannot find attribute " << attr_name << " in the graph";
224+
std::shared_ptr<any> sptr = it->second;
225+
attrs.erase(it);
226+
if (sptr.unique()) {
227+
return std::move(nnvm::get<T>(*sptr));
228+
} else {
229+
return nnvm::get<T>(*sptr);
230+
}
231+
}
232+
203233
template <typename GNode, typename HashType,
204234
typename FVisit, typename HashFunc,
205235
typename InDegree, typename GetInput>

nnvm/include/nnvm/op_attr_types.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,18 @@ using FInferShape = FInferNodeEntryAttr<TShape>;
8282
using FInferType = FInferNodeEntryAttr<int>;
8383

8484
/*!
85-
* \brief Whether this op is an explicit backward operator
85+
* \brief Whether this op is an explicit backward operator,
86+
* and the correspondence of each output to input.
8687
*
87-
* If TIsBackwardOp is set to be true:
88+
* If FBackwardOutToInIndex exists:
8889
* - The first control_deps of the node points to the corresponding forward operator.
89-
* - The outputs operator corresponds to exactly inputs of forward op one by one.
90-
*
91-
* \note Register under "TIsBackwardOp", default to false.
90+
 *  - The k-th output corresponds to the FBackwardOutToInIndex()[k]-th input of forward op.
9291
*
92+
* \note Register under "FBackwardOutToInIndex"
9393
* This enables easier shape/type inference for backward operators for slice and reduction.
9494
*/
95-
using TIsBackwardOp = bool;
95+
using FBackwardOutToInIndex = std::function<
96+
std::vector<uint32_t> (const NodeAttrs& attrs)>;
9697

9798
/*!
9899
* \brief Get possible inplace options.

nnvm/include/nnvm/symbolic.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,15 @@ class Symbol {
6262
* \return the symbol corresponds to the indexed element.
6363
*/
6464
Symbol operator[] (size_t index) const;
65+
/*!
66+
* \brief List the input variable nodes
67+
* \param option The options to list the arguments.
68+
*
69+
* The position of the returned list also corresponds to calling position in operator()
70+
* \return the arguments list of this symbol, they can be either named or unnamed (empty string).
71+
* \sa ListInputOption
72+
*/
73+
std::vector<NodePtr> ListInputs(ListInputOption option) const;
6574
/*!
6675
* \brief List the input names.
6776
* \param option The options to list the arguments.

nnvm/include/nnvm/tuple.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ class Tuple {
233233
return is;
234234
}
235235
}
236-
index_t idx;
236+
ValueType idx;
237237
std::vector<ValueType> tmp;
238238
while (is >> idx) {
239239
tmp.push_back(idx);

nnvm/src/core/symbolic.cc

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -180,37 +180,46 @@ Symbol Symbol::operator[] (size_t index) const {
180180
}
181181
}
182182

183-
std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
184-
std::vector<std::string> ret;
183+
std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const {
184+
std::vector<NodePtr> ret;
185185
if (option == kAll) {
186186
DFSVisit(this->outputs, [&ret](const NodePtr &node) {
187187
if (node->is_variable()) {
188-
ret.push_back(node->attrs.name);
188+
ret.push_back(node);
189189
}
190190
});
191191
} else {
192192
std::unordered_set<Node*> mutable_set;
193-
std::vector<Node*> vlist;
193+
std::vector<NodePtr> vlist;
194194
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
195195
DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) {
196196
if (node->is_variable()) {
197-
vlist.push_back(node.get());
197+
vlist.push_back(node);
198198
} else if (fmutate_inputs.count(node->op())) {
199199
for (uint32_t i : fmutate_inputs[node->op()](node->attrs)){
200200
mutable_set.insert(node->inputs[i].node.get());
201201
}
202202
}
203203
});
204-
for (Node* node : vlist) {
205-
if ((option == kReadOnlyArgs && mutable_set.count(node) == 0) ||
206-
(option == kAuxiliaryStates && mutable_set.count(node) != 0)) {
207-
ret.push_back(node->attrs.name);
204+
for (const NodePtr& node : vlist) {
205+
if ((option == kReadOnlyArgs && mutable_set.count(node.get()) == 0) ||
206+
(option == kAuxiliaryStates && mutable_set.count(node.get()) != 0)) {
207+
ret.emplace_back(node);
208208
}
209209
}
210210
}
211211
return ret;
212212
}
213213

214+
std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
215+
std::vector<NodePtr> inputs = ListInputs(option);
216+
std::vector<std::string> ret(inputs.size());
217+
for (size_t i = 0; i < inputs.size(); ++i) {
218+
ret[i] = inputs[i]->attrs.name;
219+
}
220+
return ret;
221+
}
222+
214223
std::vector<std::string> Symbol::ListOutputNames() const {
215224
static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
216225
std::vector<std::string> ret;

nnvm/src/pass/infer_shape_type.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ Graph InferAttr(Graph &&ret,
2424
const IndexedGraph& idx = ret.indexed_graph();
2525
static auto& finfer_shape =
2626
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
27-
static auto& is_backward =
28-
Op::GetAttr<TIsBackwardOp>("TIsBackwardOp");
27+
static auto& backward_map =
28+
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
2929
// reshape shape vector
3030
AttrVector rshape(idx.num_node_entries(), def_value);
3131

@@ -82,16 +82,19 @@ Graph InferAttr(Graph &&ret,
8282
for (uint32_t i = 0; i < num_outputs; ++i) {
8383
rshape[idx.entry_id(nid, i)] = oshape[i];
8484
}
85-
} else if (is_backward.get(inode.source->op(), false)) {
85+
} else if (backward_map.count(inode.source->op())) {
8686
// backward operator inference.
8787
CHECK_GE(inode.control_deps.size(), 1)
8888
<< "BackwardOp need to have control_deps to its forward op";
8989
const auto& fnode = idx[inode.control_deps[0]];
90-
CHECK_EQ(fnode.inputs.size(), num_outputs)
91-
<< "BackwardOp need to correspond to the forward node";
90+
std::vector<uint32_t> out_map =
91+
backward_map[inode.source->op()](inode.source->attrs);
9292
bool known = true;
93-
for (size_t i = 0; i < fnode.inputs.size(); ++i) {
94-
rshape[idx.entry_id(nid, i)] = rshape[idx.entry_id(fnode.inputs[i])];
93+
for (size_t i = 0; i < out_map.size(); ++i) {
94+
uint32_t in_id = out_map[i];
95+
CHECK_LT(in_id, fnode.inputs.size());
96+
rshape[idx.entry_id(nid, i)] =
97+
rshape[idx.entry_id(fnode.inputs[in_id])];
9598
if (fis_none(rshape[idx.entry_id(nid, i)])) known = false;
9699
}
97100
num_unknown += !known;

nnvm/src/pass/place_device.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace nnvm {
1212
namespace pass {
1313
namespace {
1414

15+
1516
// simple logic to place device according to device_group hint
1617
// insert copy node when there is
1718
Graph PlaceDevice(Graph src) {
@@ -21,13 +22,20 @@ Graph PlaceDevice(Graph src) {
2122
<< "Need graph attribute \"device_assign_map\" in PlaceDevice";
2223
CHECK_NE(src.attrs.count("device_copy_op"), 0)
2324
<< "Need graph attribute \"device_copy_op\" in PlaceDevice";
24-
2525
std::string device_group_attr_key = src.GetAttr<std::string>("device_group_attr_key");
2626
const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
2727
auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
2828
const IndexedGraph& idx = src.indexed_graph();
2929

30-
DeviceVector device(idx.num_nodes(), -1);
30+
DeviceVector device;
31+
// copy-on-write semantics
32+
if (src.attrs.count("device") != 0) {
33+
device = src.MoveCopyAttr<DeviceVector>("device");
34+
CHECK_EQ(device.size(), idx.num_nodes());
35+
} else {
36+
device.resize(idx.num_nodes(), -1);
37+
}
38+
3139
// forward pass
3240
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
3341
const auto& inode = idx[nid];

nnvm/src/pass/saveload_json.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ namespace dmlc {
1212
namespace json {
1313
// overload handler for shared ptr
1414
template<>
15-
struct Handler<std::shared_ptr<const any> > {
16-
inline static void Write(JSONWriter *writer, const std::shared_ptr<const any> &data) {
15+
struct Handler<std::shared_ptr<any> > {
16+
inline static void Write(JSONWriter *writer, const std::shared_ptr<any> &data) {
1717
writer->Write(*data);
1818
}
19-
inline static void Read(JSONReader *reader, std::shared_ptr<const any> *data) {
19+
inline static void Read(JSONReader *reader, std::shared_ptr<any> *data) {
2020
any v;
2121
reader->Read(&v);
2222
*data = std::make_shared<any>(std::move(v));
@@ -131,7 +131,7 @@ struct JSONGraph {
131131
std::vector<uint32_t> arg_nodes;
132132
std::vector<uint32_t> node_row_ptr;
133133
std::vector<JSONNode::Entry> heads;
134-
std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
134+
std::unordered_map<std::string, std::shared_ptr<any> > attrs;
135135

136136
void Save(dmlc::JSONWriter *writer) const {
137137
writer->BeginObject();

0 commit comments

Comments
 (0)