Skip to content

Commit 8b3d7da

Browse files
jermainewangtqchen
authored andcommitted
ApplyPass -> ApplyPasses; Refactored infer pass; (apache#43)
* ApplyPass -> ApplyPasses; Refactored infer pass; * lint fix
1 parent fb5b7b5 commit 8b3d7da

File tree

9 files changed

+67
-49
lines changed

9 files changed

+67
-49
lines changed

nnvm/include/nnvm/c_api.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -329,16 +329,16 @@ NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
329329
const char* key,
330330
SymbolHandle list);
331331
/*!
332-
* \brief Apply pass on the src graph.
332+
* \brief Apply passes on the src graph.
333333
* \param src The source graph handle.
334334
* \param num_pass The number of pass to be applied.
335335
* \param pass_names The names of the pass.
336336
* \param dst The result graph.
337337
* \return 0 when success, -1 when failure happens
338338
*/
339-
NNVM_DLL int NNGraphApplyPass(GraphHandle src,
340-
nn_uint num_pass,
341-
const char** pass_names,
342-
GraphHandle *dst);
339+
NNVM_DLL int NNGraphApplyPasses(GraphHandle src,
340+
nn_uint num_pass,
341+
const char** pass_names,
342+
GraphHandle *dst);
343343

344344
#endif // NNVM_C_API_H_

nnvm/include/nnvm/graph.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,11 @@ class IndexedGraph {
179179
* \param other The source graph.
180180
*/
181181
explicit IndexedGraph(const Graph& other);
182-
// node pointers in CSR structure.
182+
// Node pointers in CSR structure.
183183
std::vector<Node> nodes_;
184-
// index all to input nodes
184+
// Index to all input nodes.
185185
std::vector<uint32_t> input_nodes_;
186-
// index to mutable input nodes
186+
// Index to all mutable input nodes.
187187
std::unordered_set<uint32_t> mutable_input_nodes_;
188188
// space to store the outputs entries
189189
std::vector<NodeEntry> outputs_;

nnvm/include/nnvm/graph_attr_types.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace nnvm {
1818
* \note Stored under ret.attrs["json"], provided by Pass "SaveJSON"
1919
2020
* \code
21-
* Graph ret = ApplyPass(src_graph, {"SaveJSON"});
21+
* Graph ret = ApplyPass(src_graph, "SaveJSON");
2222
* const JSONString& json = ret.GetAttr<JSONString>("shape");
2323
* \endcode
2424
*/
@@ -29,7 +29,7 @@ using JSONString = std::string;
2929
* \note Stored under graph.attrs["shape"], provided by Pass "InferShape"
3030
*
3131
* \code
32-
* Graph g = ApplyPass(src_graph, {"InferShape"});
32+
* Graph g = ApplyPass(src_graph, "InferShape");
3333
* const ShapeVector& shapes = g.GetAttr<ShapeVector>("shape");
3434
* // get shape by entry id
3535
* TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)];
@@ -44,7 +44,7 @@ using ShapeVector = std::vector<TShape>;
4444
* \note Stored under graph.attrs["dtype"], provided by Pass "InferType"
4545
*
4646
* \code
47-
* Graph g = ApplyPass(src_graph, {"InferType"});
47+
* Graph g = ApplyPass(src_graph, "InferType");
4848
* const DTypeVector& types = g.GetAttr<DTypeVector>("dtype");
4949
* // get shape by entry id
5050
* int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)];
@@ -59,7 +59,7 @@ using DTypeVector = std::vector<int>;
5959
* \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice"
6060
*
6161
* \code
62-
* Graph g = ApplyPass(src_graph, {"PlaceDevice"});
62+
* Graph g = ApplyPass(src_graph, "PlaceDevice");
6363
* const &device = g.GetAttr<DeviceVector>("device");
6464
* // get device by node_id
6565
* int device_type = device[g.indexed_graph().node_id(my_node)];
@@ -83,7 +83,7 @@ using DeviceAssignMap = std::unordered_map<std::string, int>;
8383
* If the storage id is -1 then the storage is not assigned.
8484
*
8585
* \code
86-
* Graph g = ApplyPass(src_graph, {"PlanMemory"});
86+
* Graph g = ApplyPass(src_graph, "PlanMemory");
8787
* const &storage = g.GetAttr<StorageVector>("storage");
8888
* // get storage id by entry
8989
* int storage_id = storage[g.indexed_graph().entry_id(my_entry)];

nnvm/include/nnvm/pass.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,22 @@ typedef std::function<Graph (Graph src)> PassFunction;
2929
/*!
3030
* \brief Apply a series of pass transformations on the input graph.
3131
* \param src The graph to be transformed.
32+
* \param passes A list of pass names to be applied.
33+
* \return The transformed graph
34+
*/
35+
Graph ApplyPasses(Graph src,
36+
const std::vector<std::string>& passes);
37+
38+
/*!
39+
* \brief Apply one pass to the graph.
40+
* \param src The graph to be transformed.
3241
* \param pass The name of pass to be applied.
3342
* \return The transformed graph.
3443
*/
35-
Graph ApplyPass(Graph src,
36-
const std::vector<std::string>& pass);
44+
inline Graph ApplyPass(Graph src, const std::string& pass) {
45+
return ApplyPasses(src, {pass});
46+
}
47+
3748

3849
/*!
3950
* \brief Registry entry for DataIterator factory functions.

nnvm/include/nnvm/pass_functions.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace pass {
2828
inline Graph LoadJSON(const std::string& json_str) {
2929
Graph ret;
3030
ret.attrs["json"] = std::make_shared<any>(json_str);
31-
return ApplyPass(ret, {"LoadJSON"});
31+
return ApplyPass(ret, "LoadJSON");
3232
}
3333

3434
/*!
@@ -37,7 +37,7 @@ inline Graph LoadJSON(const std::string& json_str) {
3737
* \return The json string.
3838
*/
3939
inline std::string SaveJSON(Graph graph) {
40-
Graph ret = ApplyPass(std::move(graph), {"SaveJSON"});
40+
Graph ret = ApplyPass(std::move(graph), "SaveJSON");
4141
return ret.GetAttr<std::string>("json");
4242
}
4343

@@ -52,7 +52,7 @@ inline std::string SaveJSON(Graph graph) {
5252
* \return A graph with proper control flow dependencies added.
5353
*/
5454
inline Graph OrderMutation(Graph src) {
55-
return ApplyPass(std::move(src), {"OrderMutation"});
55+
return ApplyPass(std::move(src), "OrderMutation");
5656
}
5757

5858
/*!
@@ -73,7 +73,7 @@ inline Graph InferShape(Graph graph,
7373
if (shape_attr_key.length() != 0) {
7474
graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
7575
}
76-
return ApplyPass(std::move(graph), {"InferShape"});
76+
return ApplyPass(std::move(graph), "InferShape");
7777
}
7878

7979
/*!
@@ -94,7 +94,7 @@ inline Graph InferType(Graph graph,
9494
if (dtype_attr_key.length() != 0) {
9595
graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(dtype_attr_key));
9696
}
97-
return ApplyPass(std::move(graph), {"InferType"});
97+
return ApplyPass(std::move(graph), "InferType");
9898
}
9999

100100
/*!
@@ -118,7 +118,7 @@ inline Graph PlaceDevice(Graph graph,
118118
graph.attrs["device_group_attr_key"] = std::make_shared<any>(std::move(device_group_attr_key));
119119
graph.attrs["device_assign_map"] = std::make_shared<any>(std::move(device_assign_map));
120120
graph.attrs["device_copy_op"] = std::make_shared<any>(std::move(device_copy_op));
121-
return ApplyPass(std::move(graph), {"PlaceDevice"});
121+
return ApplyPass(std::move(graph), "PlaceDevice");
122122
}
123123

124124
/*!
@@ -149,7 +149,7 @@ inline Graph Gradient(
149149
graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
150150
}
151151

152-
return ApplyPass(std::move(graph), {"Gradient"});
152+
return ApplyPass(std::move(graph), "Gradient");
153153
}
154154

155155
} // namespace pass

nnvm/python/nnvm/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def apply(self, passes):
113113
cpass = c_array(ctypes.c_char_p, [c_str(key) for key in passes])
114114
ghandle = GraphHandle()
115115
npass = nn_uint(len(passes))
116-
check_call(_LIB.NNGraphApplyPass(self.handle, npass, cpass, ctypes.byref(ghandle)))
116+
check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
117117
return Graph(ghandle)
118118

119119

nnvm/src/c_api/c_api_graph.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,17 @@ int NNGraphGetJSONAttr(GraphHandle handle,
8282
API_END();
8383
}
8484

85-
int NNGraphApplyPass(GraphHandle src,
86-
nn_uint num_pass,
87-
const char** pass_names,
88-
GraphHandle *dst) {
85+
int NNGraphApplyPasses(GraphHandle src,
86+
nn_uint num_pass,
87+
const char** pass_names,
88+
GraphHandle *dst) {
8989
Graph* g = new Graph();
9090
API_BEGIN();
9191
std::vector<std::string> vpass;
9292
for (nn_uint i = 0; i < num_pass; ++i) {
9393
vpass.emplace_back(std::string(pass_names[i]));
9494
}
95-
*g = ApplyPass(*static_cast<Graph*>(src), vpass);
95+
*g = ApplyPasses(*static_cast<Graph*>(src), vpass);
9696
*dst = g;
9797
API_END_HANDLE_ERROR(delete g);
9898
}

nnvm/src/core/pass.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) {
2222
return nullptr;
2323
}
2424

25-
Graph ApplyPass(Graph g,
26-
const std::vector<std::string>& pass) {
25+
Graph ApplyPasses(Graph g,
26+
const std::vector<std::string>& pass) {
2727
std::vector<const PassFunctionReg*> fpass;
2828
for (auto& name : pass) {
2929
auto* reg = dmlc::Registry<PassFunctionReg>::Find(name);

nnvm/src/pass/infer_shape_type.cc

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace {
1313

1414
template<typename AttrType, typename IsNone>
1515
Graph InferAttr(Graph &&ret,
16-
const AttrType def_value,
16+
const AttrType default_val,
1717
const char* infer_name,
1818
const char* input_name,
1919
const char* attr_key_name,
@@ -23,16 +23,16 @@ Graph InferAttr(Graph &&ret,
2323
using AttrVector = std::vector<AttrType>;
2424
const IndexedGraph& idx = ret.indexed_graph();
2525
static auto& finfer_shape =
26-
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
26+
Op::GetAttr<FInferNodeEntryAttr<AttrType>>(infer_name);
2727
static auto& backward_map =
2828
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
2929
// reshape shape vector
30-
AttrVector rshape(idx.num_node_entries(), def_value);
30+
AttrVector rshape(idx.num_node_entries(), default_val);
3131

3232
if (ret.attrs.count(input_name) != 0) {
3333
const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
3434
CHECK_LE(shape_args.size(), idx.input_nodes().size())
35-
<< "shape args is more than number of arguments";
35+
<< "More provided shapes than number of arguments.";
3636
for (size_t i = 0; i < shape_args.size(); ++i) {
3737
rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
3838
}
@@ -46,47 +46,54 @@ Graph InferAttr(Graph &&ret,
4646
ret.attrs.erase(attr_key_name);
4747
}
4848

49-
// temp space for shape inference.
49+
// Temp space for shape inference.
5050
std::vector<AttrType> ishape, oshape;
5151
// number of completed nodes
5252
size_t num_unknown = 0;
5353
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
5454
const auto& inode = idx[nid];
55-
uint32_t num_inputs = inode.inputs.size();
56-
uint32_t num_outputs = inode.source->num_outputs();
55+
const uint32_t num_inputs = inode.inputs.size();
56+
const uint32_t num_outputs = inode.source->num_outputs();
5757
if (inode.source->is_variable()) {
58-
if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) {
58+
// Variable node. No operator. Only one output entry.
59+
CHECK(inode.source->op() == nullptr);
60+
CHECK_EQ(num_outputs, 1);
61+
const uint32_t out_ent_id = idx.entry_id(nid, 0);
62+
if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
5963
auto it = inode.source->attrs.dict.find(shape_attr_key);
6064
if (it != inode.source->attrs.dict.end()) {
61-
CHECK_EQ(num_outputs, 1);
6265
std::istringstream is(it->second);
63-
CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid attribute";
66+
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
6467
}
6568
}
66-
continue;
67-
}
68-
if (finfer_shape.count(inode.source->op())) {
69-
ishape.resize(num_inputs, def_value);
69+
} else if (finfer_shape.count(inode.source->op())) {
70+
// Forward operator inference.
71+
ishape.resize(num_inputs, default_val);
7072
for (uint32_t i = 0; i < ishape.size(); ++i) {
7173
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
7274
}
73-
oshape.resize(num_outputs, def_value);
75+
oshape.resize(num_outputs, default_val);
7476
for (uint32_t i = 0; i < oshape.size(); ++i) {
7577
oshape[i] = rshape[idx.entry_id(nid, i)];
7678
}
77-
num_unknown +=
78-
!(finfer_shape[inode.source->op()](inode.source->attrs, &ishape, &oshape));
79+
// Call inference function of the operator.
80+
bool forward_known = finfer_shape[inode.source->op()](
81+
inode.source->attrs, &ishape, &oshape);
82+
num_unknown += !forward_known;
83+
// Save to the result map.
7984
for (uint32_t i = 0; i < num_inputs; ++i) {
8085
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
8186
}
8287
for (uint32_t i = 0; i < num_outputs; ++i) {
8388
rshape[idx.entry_id(nid, i)] = oshape[i];
8489
}
8590
} else if (backward_map.count(inode.source->op())) {
86-
// backward operator inference.
91+
// Backward operator inference.
8792
CHECK_GE(inode.control_deps.size(), 1)
8893
<< "BackwardOp need to have control_deps to its forward op";
89-
const auto& fnode = idx[inode.control_deps[0]];
94+
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
95+
// Inference the outputs of backward operator (equal to the inputs
96+
// of its corresponding forward operator).
9097
std::vector<uint32_t> out_map =
9198
backward_map[inode.source->op()](inode.source->attrs);
9299
bool known = true;

0 commit comments

Comments
 (0)