Skip to content

Commit f995214

Browse files
tqchensergei-mironov
authored andcommitted
[Pass] Finish infershape testcase (apache#16)
1 parent 95e2400 commit f995214

File tree

9 files changed

+161
-14
lines changed

9 files changed

+161
-14
lines changed

nnvm/include/nnvm/pass.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ namespace nnvm {
2323
* \param src The graph to be transformed.
2424
* \return The generated graph.
2525
*/
26-
typedef std::function<Graph (const Graph& src)> PassFunction;
26+
typedef std::function<Graph (Graph src)> PassFunction;
2727

2828
/*!
2929
* \brief Apply a series of pass transformations on g.
3030
* \param src The graph to be transformed.
3131
* \param pass The name of pass to be applied.
3232
* \return The transformed graph
3333
*/
34-
Graph ApplyPass(const Graph& src,
34+
Graph ApplyPass(Graph src,
3535
const std::vector<std::string>& pass);
3636

3737
/*!

nnvm/include/nnvm/pass_functions.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*!
2+
* Copyright (c) 2016 by Contributors
3+
* \file pass_functions.h
4+
* \brief Pass functions that simply redirect the calls to ApplyPass
5+
*
6+
* This file serves as documentation on how to use functions implemented in "src/pass".
7+
* It is totally optional to add these functions when you add a new pass, since
8+
* ApplyPass can be directly called.
9+
*/
10+
#ifndef NNVM_PASS_FUNCTIONS_H_
11+
#define NNVM_PASS_FUNCTIONS_H_
12+
13+
#include <string>
14+
#include <memory>
15+
#include "./base.h"
16+
#include "./pass.h"
17+
#include "./graph_attr_types.h"
18+
19+
namespace nnvm {
20+
namespace pass {
21+
22+
/*!
23+
* \brief Load a graph from JSON string, redirects to "LoadJSON" pass.
24+
* \param json_str The json string.
25+
* \return Loaded graph.
26+
*/
27+
inline Graph LoadJSON(const std::string& json_str) {
28+
Graph ret;
29+
ret.attrs["json"] = std::make_shared<any>(json_str);
30+
return ApplyPass(ret, {"LoadJSON"});
31+
}
32+
33+
/*!
34+
* \brief Save a graph to json, redirects to "SaveJSON" pass.
35+
* \param graph The to be saved.
36+
* \return The json string.
37+
*/
38+
inline std::string SaveJSON(Graph graph) {
39+
Graph ret = ApplyPass(std::move(graph), {"SaveJSON"});
40+
return ret.GetAttr<std::string>("json");
41+
}
42+
43+
/*!
44+
* \brief Add control flow dependencies between nodes
45+
* To correctly order mutation and read to resolve
46+
* write after read problem and read after write problems.
47+
* \param src source graph
48+
* \return A graph that added control flow dependencies.
49+
*/
50+
inline Graph OrderMutation(Graph src) {
51+
return ApplyPass(std::move(src), {"OrderMutation"});
52+
}
53+
54+
/*!
55+
* \brief Infer shapes in the graph given the information.
56+
* \param graph source graph
57+
* \param shape_args The shapes of aruguments to the graph.
58+
* \param shape_attr_key The key to the node attribute that can indicate shape.
59+
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
60+
* The index of ShapeVector is given by graph.indexed_graph().entry_id
61+
*/
62+
inline Graph InferShape(Graph graph,
63+
ShapeVector shape_args = {},
64+
std::string shape_attr_key = "") {
65+
if (shape_args.size() != 0) {
66+
graph.attrs["shape_args"] = std::make_shared<any>(std::move(shape_args));
67+
}
68+
if (shape_attr_key.length() != 0) {
69+
graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
70+
}
71+
return ApplyPass(std::move(graph), {"InferShape"});
72+
}
73+
74+
} // namespace pass
75+
} // namespace nnvm
76+
#endif // NNVM_PASS_FUNCTIONS_H_

nnvm/src/core/pass.cc

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) {
2222
return nullptr;
2323
}
2424

25-
Graph ApplyPass(const Graph& src,
25+
Graph ApplyPass(Graph g,
2626
const std::vector<std::string>& pass) {
2727
std::vector<const PassFunctionReg*> fpass;
2828
for (auto& name : pass) {
@@ -32,11 +32,9 @@ Graph ApplyPass(const Graph& src,
3232
fpass.push_back(reg);
3333
}
3434

35-
Graph g;
36-
const Graph* s = &src;
3735
for (auto r : fpass) {
3836
for (auto& dep : r->graph_attr_dependency) {
39-
if (s->attrs.count(dep) == 0) {
37+
if (g.attrs.count(dep) == 0) {
4038
auto* pass_dep = FindPassDep(dep);
4139
std::string msg;
4240
if (pass_dep != nullptr) {
@@ -48,8 +46,7 @@ Graph ApplyPass(const Graph& src,
4846
<< msg;
4947
}
5048
}
51-
g = r->body(*s);
52-
s = &g;
49+
g = r->body(std::move(g));
5350
}
5451
return g;
5552
}

nnvm/src/example/operator.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <nnvm/base.h>
55
#include <nnvm/op.h>
66
#include <nnvm/op_attr_types.h>
7+
#include <nnvm/node.h>
78
#include <nnvm/graph_attr_types.h>
89
#include <utility>
910

@@ -30,6 +31,31 @@ inline bool SameShape(const NodeAttrs& attrs,
3031
return true;
3132
}
3233

34+
// simple demonstration of reshape.
35+
NNVM_REGISTER_OP(reshape)
36+
.describe("reshape source to target shape")
37+
.set_num_inputs(1)
38+
.set_attr_parser(
39+
[](NodeAttrs* attrs) {
40+
// parse attr parser to get target attribute
41+
TShape target;
42+
std::istringstream is(attrs->dict.at("target"));
43+
CHECK(is >> target);
44+
attrs->parsed = std::move(target);
45+
})
46+
.attr<FInferShape>(
47+
"FInferShape", [] (const NodeAttrs& attrs,
48+
array_view<TShape*> ishape,
49+
array_view<TShape*> oshape) {
50+
// get parsed attribute
51+
const TShape& target = nnvm::get<TShape>(attrs.parsed);
52+
*oshape[0] = target;
53+
if (ishape[0]->ndim() == 0) return false;
54+
CHECK_EQ(ishape[0]->Size(), target.Size())
55+
<< "Reshape op: source target shape mismatch";
56+
return true;
57+
});
58+
3359
NNVM_REGISTER_OP(add)
3460
.describe("add two data together")
3561
.set_num_inputs(2)

nnvm/src/pass/infer_shape.cc

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,42 @@
1010
namespace nnvm {
1111
namespace pass {
1212

13-
Graph InferShape(const Graph& src) {
14-
Graph ret = src;
13+
Graph InferShape(Graph ret) {
1514
const IndexedGraph& idx = ret.indexed_graph();
1615
static auto& finfer_shape = Op::GetAttr<FInferShape>("FInferShape");
1716
// reshape shape vector
1817
ShapeVector rshape(idx.num_node_entries());
18+
19+
if (ret.attrs.count("shape_args") != 0) {
20+
const ShapeVector& shape_args = ret.GetAttr<ShapeVector>("shape_args");
21+
CHECK_LE(shape_args.size(), idx.arg_nodes().size())
22+
<< "shape args is more than number of arguments";
23+
for (size_t i = 0; i < shape_args.size(); ++i) {
24+
rshape[idx.entry_id(idx.arg_nodes()[i], 0)] = shape_args[i];
25+
}
26+
}
27+
std::string shape_attr_key;
28+
if (ret.attrs.count("shape_attr_key") != 0) {
29+
shape_attr_key = ret.GetAttr<std::string>("shape_attr_key");
30+
}
31+
1932
// temp space for shape inference.
2033
std::vector<TShape*> ishape, oshape;
2134
// number of completed nodes
2235
size_t num_known = 0;
2336
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
2437
const auto& inode = idx[nid];
25-
if (inode.source->is_variable()) continue;
38+
if (inode.source->is_variable()) {
39+
if (shape_attr_key.length() != 0) {
40+
auto it = inode.source->attrs.dict.find(shape_attr_key);
41+
if (it != inode.source->attrs.dict.end()) {
42+
CHECK_EQ(inode.source->num_outputs(), 1);
43+
std::istringstream is(it->second);
44+
CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid shape attribute";
45+
}
46+
}
47+
continue;
48+
}
2649
ishape.resize(inode.inputs.size());
2750
for (uint32_t i = 0; i < ishape.size(); ++i) {
2851
ishape[i] = &rshape[idx.entry_id(inode.inputs[i])];
@@ -43,5 +66,13 @@ Graph InferShape(const Graph& src) {
4366
return ret;
4467
}
4568

69+
NNVM_REGISTER_PASS(InferShape)
70+
.describe("Infer the shape of each node entries.")
71+
.set_body(InferShape)
72+
.set_change_graph(false)
73+
.provide_graph_attr("shape");
74+
75+
DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);
76+
4677
} // namespace pass
4778
} // namespace nnvm

nnvm/src/pass/order_mutation.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*!
22
* Copyright (c) 2016 by Contributors
3-
* \file saveload_json.cc
3+
* \file order_mutation.cc
44
* \brief Add control flow dependencies between nodes
55
* To correctly order mutation and read to resolve
66
* write after read problem and read after write problems.

nnvm/src/pass/saveload_json.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ struct JSONGraph {
149149
};
150150

151151
// Load a graph from JSON file.
152-
Graph LoadJSON(const Graph& src) {
152+
Graph LoadJSON(Graph src) {
153153
CHECK_NE(src.attrs.count("json"), 0)
154154
<< "Load JSON require json to be presented.";
155155
const std::string &json_str =
@@ -188,7 +188,7 @@ Graph LoadJSON(const Graph& src) {
188188
}
189189

190190
// save a graph to json
191-
Graph SaveJSON(const Graph& src) {
191+
Graph SaveJSON(Graph src) {
192192
JSONGraph jgraph;
193193
std::unordered_map<Node*, uint32_t> node2index;
194194
jgraph.node_row_ptr.push_back(0);

nnvm/src/test_main.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <nnvm/tuple.h>
55
#include <nnvm/c_api.h>
66
#include <nnvm/graph_attr_types.h>
7+
#include <nnvm/pass_functions.h>
78
#include <dmlc/timer.h>
89
#include <string>
910

nnvm/tests/python/test_graph.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,23 @@ def test_order_mutation_pass():
3535
assert nindex['add1'] in jnodes[nindex['assign']]['control_deps']
3636
assert jnodes[nindex['assign']]['inputs'][0][2] == 1
3737

38+
def test_infer_shape():
39+
x = sym.Variable('x', shape=(4, 2))
40+
y = sym.add(x, x, name='add1')
41+
y = sym.reshape(y, target=(2, 4), name="reshape1")
42+
g = graph.create(y)
43+
g._set_json_attr("shape_attr_key", "shape")
44+
g = g.apply('InferShape')
45+
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
46+
jnodes = jgraph['nodes']
47+
jnode_row_ptr = jgraph['node_row_ptr']
48+
nindex = {n['name']: i for i, n in enumerate(jnodes)}
49+
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
50+
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]
51+
52+
3853
if __name__ == "__main__":
3954
test_order_mutation_pass()
4055
test_graph_json_attr()
4156
test_json_pass()
57+
test_infer_shape()

0 commit comments

Comments
 (0)