Skip to content

Commit 2f837ab

Browse files
committed
[PASS] add plan memory (apache#19)
1 parent a33e9ce commit 2f837ab

File tree

9 files changed

+411
-3
lines changed

9 files changed

+411
-3
lines changed

nnvm/include/nnvm/graph.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ class IndexedGraph {
147147
inline const std::vector<uint32_t>& arg_nodes() const {
148148
return arg_nodes_;
149149
}
150+
/*! \return list of output entries */
151+
inline const std::vector<NodeEntry>& outputs() const {
152+
return outputs_;
153+
}
150154

151155
private:
152156
friend class Graph;
@@ -159,6 +163,8 @@ class IndexedGraph {
159163
std::vector<Node> nodes_;
160164
// index to argument nodes
161165
std::vector<uint32_t> arg_nodes_;
166+
// space to store the outputs entries
167+
std::vector<NodeEntry> outputs_;
162168
// mapping from node to index.
163169
std::unordered_map<const nnvm::Node*, uint32_t> node2index_;
164170
// CSR pointer of node entries

nnvm/include/nnvm/graph_attr_types.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ using DTypeVector = std::vector<int>;
6060
*
6161
* \code
6262
* Graph g = ApplyPass(src_graph, {"PlaceDevice"});
63-
* const &device = g.GetAttr<DeviceVector>("dtype");
63+
* const &device = g.GetAttr<DeviceVector>("device");
6464
* // get device by node_id
6565
* int device_type = device[g.indexed_graph().node_id(my_node)];
6666
* \endcode
@@ -75,6 +75,21 @@ using DeviceVector = std::vector<int>;
7575
*/
7676
using DeviceAssignMap = std::unordered_map<std::string, int>;
7777

78+
/*!
79+
* \brief The result holder of storage id of each NodeEntry in the graph.
80+
*
81+
* \note Stored under graph.attrs["storage"], provided by Pass "PlanMemory"
82+
* Storage id is a continuous integer.
83+
* If the storage id is -1 then the storage is not assigned.
84+
*
85+
* \code
86+
* Graph g = ApplyPass(src_graph, {"PlanMemory"});
87+
* const &storage = g.GetAttr<StorageVector>("storage");
88+
* // get storage id by entry
89+
* int storage_id = storage[g.indexed_graph().entry_id(my_entry)];
90+
* \endcode
91+
*/
92+
using StorageVector = std::vector<int>;
7893

7994
} // namespace nnvm
8095

nnvm/include/nnvm/op_attr_types.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <vector>
1010
#include <string>
11+
#include <utility>
1112
#include <functional>
1213
#include "./base.h"
1314
#include "./tuple.h"
@@ -93,6 +94,20 @@ using FInferType = FInferNodeEntryAttr<int>;
9394
*/
9495
using TIsBackwardOp = bool;
9596

97+
/*!
98+
* \brief Get possible inplace options.
99+
* This function enables optimization to reuse memory of inputs in output.
100+
* \param attrs The attributes of the node
101+
* \param in_data The input data.
102+
* \param out_data The output data.
103+
* \return list of pair of that maps input->output,
104+
* indicating possible in place operations.
105+
*
106+
* \note Register under "FInplaceOption", by default no inplace can happen.
107+
*/
108+
using FInplaceOption = std::function<
109+
std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;
110+
96111
} // namespace nnvm
97112

98113
#endif // NNVM_OP_ATTR_TYPES_H_

nnvm/src/core/graph.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ IndexedGraph::IndexedGraph(const Graph &g) {
5252
control_rptr.push_back(control_deps_.size());
5353
});
5454

55+
for (const auto& e : g.outputs) {
56+
outputs_.emplace_back(NodeEntry{
57+
node2index_.at(e.node.get()), e.index, e.version});
58+
}
59+
5560
// setup array view
5661
// input_entries_ and control_rptr must not change after this step.
5762
const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);

nnvm/src/example/operator.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ using nnvm::FListInputNames;
1414
using nnvm::FMutateInput;
1515
using nnvm::FInferShape;
1616
using nnvm::FInferType;
17+
using nnvm::FInplaceOption;
1718
using nnvm::NodeAttrs;
1819
using nnvm::TShape;
1920
using nnvm::array_view;
@@ -32,6 +33,10 @@ inline bool SameShape(const NodeAttrs& attrs,
3233
return true;
3334
}
3435

36+
inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs) {
37+
return {{0, 0}};
38+
}
39+
3540
// simple demonstration of reshape.
3641
NNVM_REGISTER_OP(reshape)
3742
.describe("reshape source to target shape")
@@ -55,7 +60,8 @@ NNVM_REGISTER_OP(reshape)
5560
CHECK_EQ(ishape[0]->Size(), target.Size())
5661
<< "Reshape op: source target shape mismatch";
5762
return true;
58-
});
63+
})
64+
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0);
5965

6066

6167
NNVM_REGISTER_OP(cast)
@@ -82,7 +88,8 @@ NNVM_REGISTER_OP(cast)
8288
NNVM_REGISTER_OP(add)
8389
.describe("add two data together")
8490
.set_num_inputs(2)
85-
.attr<FInferShape>("FInferShape", SameShape);
91+
.attr<FInferShape>("FInferShape", SameShape)
92+
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0);
8693

8794
NNVM_REGISTER_OP(__add_symbol__)
8895
.describe("Alias of add")

nnvm/src/pass/graph_algorithm.h

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*!
2+
* Copyright (c) 2016 by Contributors
3+
* \file graph_algorithm.h
4+
* \brief This header contains graph algorithms on StaticGraph.
5+
* It is used compute informations such as whether two
6+
* operations can run in parallel, and helps allocation.
7+
*/
8+
#ifndef NNVM_PASS_GRAPH_ALGORITHM_H_
9+
#define NNVM_PASS_GRAPH_ALGORITHM_H_
10+
11+
#include <nnvm/graph.h>
12+
#include <vector>
13+
14+
namespace nnvm {
15+
namespace pass {
16+
17+
/*!
18+
* \brief Find best path in the DAG, with reward defined
19+
* by sum of reward of each node along the path.
20+
* \param graph the original static graph.
21+
* \param topo_order topo order of the nodes in the graph.
22+
* \param node_reward the reward of each node.
23+
* \param path the output path of nodes.
24+
* \return the total reward of best path.
25+
*/
26+
inline uint32_t FindBestPath(
27+
const IndexedGraph& graph,
28+
const std::vector<uint32_t>& node_reward,
29+
std::vector<uint32_t>* path) {
30+
const uint32_t num_nodes = static_cast<uint32_t>(graph.num_nodes());
31+
CHECK_EQ(num_nodes, node_reward.size());
32+
33+
std::vector<uint32_t> best_reward(node_reward.size(), 0);
34+
std::vector<uint32_t> next_node(node_reward.size(), num_nodes);
35+
uint32_t best_solution = 0, best_start_node = 0;
36+
37+
// traverse in reverse topo order
38+
for (uint32_t i = static_cast<uint32_t>(graph.num_nodes()); i != 0; --i) {
39+
const uint32_t nid = i - 1;
40+
best_reward[nid] += node_reward[nid];
41+
if (best_reward[nid] > best_solution) {
42+
best_solution = best_reward[nid];
43+
best_start_node = nid;
44+
}
45+
for (const auto& e : graph[nid].inputs) {
46+
const uint32_t prev = e.node_id;
47+
if (best_reward[nid] > best_reward[prev]) {
48+
best_reward[prev] = best_reward[nid];
49+
next_node[prev] = nid;
50+
}
51+
}
52+
}
53+
path->clear();
54+
uint32_t reward = 0;
55+
for (uint32_t nid = best_start_node; nid < num_nodes; nid = next_node[nid]) {
56+
path->push_back(nid); reward += node_reward[nid];
57+
}
58+
CHECK_EQ(reward, best_solution);
59+
return best_solution;
60+
}
61+
62+
/*!
63+
* \brief Color the nodes in the graph into index.
64+
* The coloring algorithm tries to assign node group
65+
* such that node in the same group cannot run in parallel.
66+
*
67+
* \param graph the original indexed graph.
68+
* \param node_importance The importance of the node
69+
* \param max_ncolor maximum number of colors allowed.
70+
* \param color the color index of each of the node.
71+
* \return the total number of colors.
72+
*/
73+
inline uint32_t ColorNodeGroup(
74+
const IndexedGraph &graph,
75+
std::vector<uint32_t> node_importance,
76+
uint32_t max_ncolor,
77+
std::vector<uint32_t> *color) {
78+
CHECK_NE(max_ncolor, 0);
79+
CHECK_EQ(graph.num_nodes(), node_importance.size());
80+
81+
color->clear();
82+
color->resize(graph.num_nodes(), max_ncolor);
83+
uint32_t cindex;
84+
// greedy algorithm, every time
85+
// find a path with best reward and assign a new color
86+
// All the nodes in the path cannot run in parallel.
87+
for (cindex = 0; cindex < max_ncolor - 1; ++cindex) {
88+
std::vector<uint32_t> path;
89+
uint32_t reward = FindBestPath(graph, node_importance, &path);
90+
if (reward == 0) break;
91+
for (uint32_t nid : path) {
92+
if (node_importance[nid] != 0) {
93+
CHECK_EQ(color->at(nid), max_ncolor);
94+
color->at(nid) = cindex;
95+
// make the importance 0 after color is decided.
96+
node_importance[nid] = 0;
97+
}
98+
}
99+
}
100+
// assign i for rest of the node
101+
for (uint32_t i = 0; i < graph.num_nodes(); ++i) {
102+
if (color->at(i) == max_ncolor) {
103+
color->at(i) = cindex;
104+
}
105+
}
106+
return cindex + 1;
107+
}
108+
109+
} // namespace pass
110+
} // namespace nnvm
111+
112+
#endif // NNVM_PASS_GRAPH_ALGORITHM_H_

nnvm/src/pass/infer_shape_type.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ NNVM_REGISTER_PASS(InferType)
121121

122122
DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);
123123
DMLC_JSON_ENABLE_ANY(DTypeVector, list_int);
124+
DMLC_JSON_ENABLE_ANY(size_t, size_t);
124125

125126
} // namespace pass
126127
} // namespace nnvm

0 commit comments

Comments
 (0)