Skip to content

Commit efc87f1

Browse files
committed
Enable aux data (apache#24)
1 parent 486249e commit efc87f1

File tree

4 files changed

+20
-2
lines changed

4 files changed

+20
-2
lines changed

nnvm/include/nnvm/graph.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ class IndexedGraph {
147147
inline const std::vector<uint32_t>& input_nodes() const {
148148
return input_nodes_;
149149
}
150+
/*! \return list of mutable nodes */
151+
inline const std::unordered_set<uint32_t>& mutable_input_nodes() const {
152+
return mutable_input_nodes_;
153+
}
150154
/*! \return list of output entries */
151155
inline const std::vector<NodeEntry>& outputs() const {
152156
return outputs_;
@@ -161,8 +165,10 @@ class IndexedGraph {
161165
explicit IndexedGraph(const Graph& other);
162166
// node pointers in CSR structure.
163167
std::vector<Node> nodes_;
164-
// index to input nodes
168+
// index all to input nodes
165169
std::vector<uint32_t> input_nodes_;
170+
// index to mutable input nodes
171+
std::unordered_set<uint32_t> mutable_input_nodes_;
166172
// space to store the outputs entries
167173
std::vector<NodeEntry> outputs_;
168174
// mapping from node to index.

nnvm/include/nnvm/op.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,12 +368,14 @@ inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { //
368368
// member functions of OpMap
369369
template<typename ValueType>
370370
inline int OpMap<ValueType>::count(const Op* op) const {
371+
if (op == nullptr) return 0;
371372
const uint32_t idx = op->index_;
372373
return idx < data_.size() ? data_[idx].second : 0;
373374
}
374375

375376
template<typename ValueType>
376377
inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
378+
CHECK(op != nullptr);
377379
const uint32_t idx = op->index_;
378380
CHECK(idx < data_.size() && data_[idx].second)
379381
<< "Attribute " << attr_name_
@@ -383,6 +385,7 @@ inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
383385

384386
template<typename ValueType>
385387
inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const {
388+
if (op == nullptr) return def_value;
386389
const uint32_t idx = op->index_;
387390
if (idx < data_.size() && data_[idx].second) {
388391
return data_[idx].first;

nnvm/src/core/graph.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
* \brief Graph node data structure.
55
*/
66
#include <nnvm/graph.h>
7+
#include <nnvm/op_attr_types.h>
78
#include <limits>
89

910
namespace nnvm {
@@ -57,12 +58,20 @@ IndexedGraph::IndexedGraph(const Graph &g) {
5758
node2index_.at(e.node.get()), e.index, e.version});
5859
}
5960

61+
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
62+
std::unordered_set<uint32_t> mutable_inputs;
6063
// setup array view
6164
// input_entries_ and control_rptr must not change after this step.
6265
const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);
6366
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
6467
nodes_[nid].inputs = array_view<NodeEntry>(
6568
iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]);
69+
if (nodes_[nid].source->op != nullptr &&
70+
fmutate_inputs.count(nodes_[nid].source->op)) {
71+
for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](nodes_[nid].source->attrs)) {
72+
mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id);
73+
}
74+
}
6675
}
6776
const uint32_t* cptr = dmlc::BeginPtr(control_deps_);
6877
for (size_t nid = 0; nid < nodes_.size(); ++nid) {

nnvm/src/pass/infer_shape_type.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ NNVM_REGISTER_PASS(InferShape)
101101
.set_body([](Graph ret) {
102102
return InferAttr<TShape>(
103103
std::move(ret), TShape(),
104-
"FInferShape", "shape_args", "shape_attr_key",
104+
"FInferShape", "shape_inputs", "shape_attr_key",
105105
"shape", "shape_num_unknown_nodes",
106106
[](const TShape& s) { return s.ndim() == 0; });
107107
})

0 commit comments

Comments
 (0)