Skip to content

Commit 40bc10f

Browse files
committed
[PASS] SimplifyBatchNorm->SimplifyInference, remove dropout (apache#24)
1 parent 215693d commit 40bc10f

File tree

4 files changed

+20
-11
lines changed

4 files changed

+20
-11
lines changed

nnvm/python/nnvm/compiler/build_module.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .. import runtime
99

1010
OPT_PASS_LEVEL = {
11-
"SimplifyBatchNormInference": 2,
11+
"SimplifyInference": 2,
1212
"PrecomputePrune": 2,
1313
"OpFusion": 1
1414
}
@@ -115,12 +115,9 @@ def optimize(graph, shape, dtype="float32"):
115115
"""
116116
# pylint: disable=unused-argument
117117
cfg = BuildConfig.current
118-
graph = graph_attr.set_shape_inputs(graph, shape)
119-
graph = graph.apply("InferShape")
120-
if graph.json_attr("shape_num_unknown_nodes"):
121-
raise ValueError("InferShape fails..")
122-
if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
123-
graph = graph.apply("SimplifyBatchNormInference")
118+
if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyInference"]:
119+
graph = graph_attr.set_shape_inputs(graph, shape)
120+
graph = graph.apply(["InferShape", "SimplifyInference"])
124121
return graph
125122

126123

nnvm/python/nnvm/top/tensor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ def _compute(attrs, x, _):
4444

4545
_fschedule_broadcast = tvm.convert(_schedule_broadcast)
4646

47+
# copy
48+
reg.register_compute("copy", _compute_unary(topi.identity))
49+
reg.register_pattern("copy", OpPattern.ELEM_WISE)
50+
reg.register_schedule("copy", _fschedule_broadcast)
51+
4752
# exp
4853
reg.register_compute("exp", _compute_unary(topi.exp))
4954
reg.register_pattern("exp", OpPattern.ELEM_WISE)

nnvm/src/compiler/simplify_batch_norm.cc renamed to nnvm/src/compiler/simplify_inference.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
2222
nnvm::NodeEntry moving_mean,
2323
nnvm::NodeEntry moving_var,
2424
TShape dshape) {
25+
CHECK_NE(dshape.ndim(), 0);
2526
CHECK(attrs.op);
2627
static const Op* bn_op = Op::Get("batch_norm");
2728
CHECK(attrs.op == bn_op);
@@ -76,13 +77,14 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
7677
return {out, undef, undef};
7778
}
7879

79-
Graph SimplifyBatchNormInference(nnvm::Graph src) {
80+
Graph SimplifyInference(nnvm::Graph src) {
8081
// Get attributes from the graph
8182
const IndexedGraph& idx = src.indexed_graph();
8283
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
8384
auto transform = [&](uint32_t nid, const Node* n, std::vector<NodeEntry>* ret) {
8485
if (n->is_variable()) return false;
8586
static const Op* bn_op = Op::Get("batch_norm");
87+
static const Op* dropout_op = Op::Get("dropout");
8688
if (n->op() == bn_op) {
8789
*ret = BatchNormToInferUnpack(
8890
n->attrs,
@@ -93,15 +95,19 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) {
9395
n->inputs[4],
9496
shape_vec[idx.entry_id(nid, 0)]);
9597
return true;
98+
} else if (n->op() == dropout_op) {
99+
NodeEntry undef = MakeNode("__undef__", "undef", {});
100+
*ret = {n->inputs[0], undef};
101+
return true;
96102
} else {
97103
return false;
98104
}
99105
};
100106
return GraphTransform(src, transform);
101107
}
102108

103-
NNVM_REGISTER_PASS(SimplifyBatchNormInference)
104-
.set_body(SimplifyBatchNormInference);
109+
NNVM_REGISTER_PASS(SimplifyInference)
110+
.set_body(SimplifyInference);
105111

106112
} // namespace compiler
107113
} // namespace nnvm

nnvm/tests/python/compiler/test_simplify_batchnorm.py renamed to nnvm/tests/python/compiler/test_simplify_inference.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@ def check(dim, axis, nstep):
3030
for i in range(nstep):
3131
y1 = sym.batch_norm(
3232
y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
33+
y1 = sym.dropout(y1)
3334
y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var,
3435
epsilon=eps, axis=axis, shape=ishape["x"])
3536
g = nnvm.graph.create(y1)
3637
g2 = nnvm.graph.create(y2)
3738
graph_attr.set_shape_inputs(g, ishape)
38-
g1 = g.apply("InferShape").apply("SimplifyBatchNormInference")
39+
g1 = g.apply("InferShape").apply("SimplifyInference")
3940
# Some prints for debug
4041
# print(g1.ir())
4142
# assert graph equals as expected

0 commit comments

Comments
 (0)