@@ -30,18 +30,34 @@ class Graph {
30
30
std::vector<NodeEntry> outputs;
31
31
/* !
32
32
* \brief attributes of a graph
33
- * Each attribute is immutable,
34
- * and can be shared across multiple Instance of graph
33
+ * Note that attribute is shared pointer and can be shared across graphs.
34
+ *
35
+ * It is highly recommended to keep each attribute immutable.
36
+ * It is also safe to implement an copy-on-write semnatics.
37
+ *
38
+ * Copy when shared_ptr.unique is not true, while reuse original space
39
+ * when shared_ptr.unique is true.
35
40
*/
36
- std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
41
+ std::unordered_map<std::string, std::shared_ptr<any> > attrs;
37
42
/* !
38
- * \brief Get the attribute from attrs.
43
+ * \brief Get the immutable attribute from attrs.
39
44
* \param attr_name the name of the attribute
40
45
* \return the reference to corresponding attribute
41
46
* \tparam T the type of the attribute.
42
47
*/
43
48
template <typename T>
44
49
inline const T& GetAttr (const std::string& attr_name);
50
+ /* !
51
+ * \brief Get a move copy of the attribute, implement copy on write semantics.
52
+ * The content is moved if the reference counter of shared_ptr is 1.
53
+ * The attribute is erased from attrs after the call.
54
+ *
55
+ * \param attr_name the name of the attribute
56
+ * \return a new copy of the corresponding attribute.
57
+ * \tparam T the type of the attribute.
58
+ */
59
+ template <typename T>
60
+ inline T MoveCopyAttr (const std::string& attr_name);
45
61
/* !
46
62
* \brief get a indexed graph of current graph, if not exist, create it on demand
47
63
* \return The indexed graph.
@@ -200,6 +216,20 @@ inline const T& Graph::GetAttr(const std::string& attr_name) {
200
216
return nnvm::get<T>(*it->second );
201
217
}
202
218
219
+ template <typename T>
220
+ inline T Graph::MoveCopyAttr (const std::string& attr_name) {
221
+ auto it = attrs.find (attr_name);
222
+ CHECK (it != attrs.end ())
223
+ << " Cannot find attribute " << attr_name << " in the graph" ;
224
+ std::shared_ptr<any> sptr = it->second ;
225
+ attrs.erase (it);
226
+ if (sptr.unique ()) {
227
+ return std::move (nnvm::get<T>(*sptr));
228
+ } else {
229
+ return nnvm::get<T>(*sptr);
230
+ }
231
+ }
232
+
203
233
template <typename GNode, typename HashType,
204
234
typename FVisit, typename HashFunc,
205
235
typename InDegree, typename GetInput>
0 commit comments