File tree Expand file tree Collapse file tree 4 files changed +20
-2
lines changed Expand file tree Collapse file tree 4 files changed +20
-2
lines changed Original file line number Diff line number Diff line change @@ -147,6 +147,10 @@ class IndexedGraph {
147147 inline const std::vector<uint32_t >& input_nodes () const {
148148 return input_nodes_;
149149 }
150+ /* ! \return list of mutable nodes */
151+ inline const std::unordered_set<uint32_t >& mutable_input_nodes () const {
152+ return mutable_input_nodes_;
153+ }
150154 /* ! \return list of output entries */
151155 inline const std::vector<NodeEntry>& outputs () const {
152156 return outputs_;
@@ -161,8 +165,10 @@ class IndexedGraph {
161165 explicit IndexedGraph (const Graph& other);
162166 // node pointers in CSR structure.
163167 std::vector<Node> nodes_;
164- // index to input nodes
168+ // index all to input nodes
165169 std::vector<uint32_t > input_nodes_;
170+ // index to mutable input nodes
171+ std::unordered_set<uint32_t > mutable_input_nodes_;
166172 // space to store the outputs entries
167173 std::vector<NodeEntry> outputs_;
168174 // mapping from node to index.
Original file line number Diff line number Diff line change @@ -368,12 +368,14 @@ inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { //
368368// member functions of OpMap
369369template <typename ValueType>
370370inline int OpMap<ValueType>::count(const Op* op) const {
371+ if (op == nullptr ) return 0 ;
371372 const uint32_t idx = op->index_ ;
372373 return idx < data_.size () ? data_[idx].second : 0 ;
373374}
374375
375376template <typename ValueType>
376377inline const ValueType& OpMap<ValueType>::operator [](const Op* op) const {
378+ CHECK (op != nullptr );
377379 const uint32_t idx = op->index_ ;
378380 CHECK (idx < data_.size () && data_[idx].second )
379381 << " Attribute " << attr_name_
@@ -383,6 +385,7 @@ inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
383385
384386template <typename ValueType>
385387inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const {
388+ if (op == nullptr ) return def_value;
386389 const uint32_t idx = op->index_ ;
387390 if (idx < data_.size () && data_[idx].second ) {
388391 return data_[idx].first ;
Original file line number Diff line number Diff line change 44 * \brief Graph node data structure.
55 */
66#include < nnvm/graph.h>
7+ #include < nnvm/op_attr_types.h>
78#include < limits>
89
910namespace nnvm {
@@ -57,12 +58,20 @@ IndexedGraph::IndexedGraph(const Graph &g) {
5758 node2index_.at (e.node .get ()), e.index , e.version });
5859 }
5960
61+ static auto & fmutate_inputs = Op::GetAttr<FMutateInputs>(" FMutateInputs" );
62+ std::unordered_set<uint32_t > mutable_inputs;
6063 // setup array view
6164 // input_entries_ and control_rptr must not change after this step.
6265 const NodeEntry* iptr = dmlc::BeginPtr (input_entries_);
6366 for (size_t nid = 0 ; nid < nodes_.size (); ++nid) {
6467 nodes_[nid].inputs = array_view<NodeEntry>(
6568 iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1 ]);
69+ if (nodes_[nid].source ->op != nullptr &&
70+ fmutate_inputs.count (nodes_[nid].source ->op )) {
71+ for (uint32_t i : fmutate_inputs[nodes_[nid].source ->op ](nodes_[nid].source ->attrs )) {
72+ mutable_input_nodes_.insert (nodes_[nid].inputs [i].node_id );
73+ }
74+ }
6675 }
6776 const uint32_t * cptr = dmlc::BeginPtr (control_deps_);
6877 for (size_t nid = 0 ; nid < nodes_.size (); ++nid) {
Original file line number Diff line number Diff line change @@ -101,7 +101,7 @@ NNVM_REGISTER_PASS(InferShape)
101101.set_body([](Graph ret) {
102102 return InferAttr<TShape>(
103103 std::move (ret), TShape (),
104- " FInferShape" , " shape_args " , " shape_attr_key" ,
104+ " FInferShape" , " shape_inputs " , " shape_attr_key" ,
105105 " shape" , " shape_num_unknown_nodes" ,
106106 [](const TShape& s) { return s.ndim () == 0 ; });
107107 })
You can’t perform that action at this time.
0 commit comments