From 6e0b0d4f1f310b241d87e6f742da2395e834655f Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 22 Sep 2017 20:48:28 -0700 Subject: [PATCH] [PASS] Improve GraphFuse to include five patterns (#26) --- nnvm/docs/top.rst | 13 ++++- nnvm/include/nnvm/compiler/op_attr_types.h | 19 ++++-- nnvm/python/nnvm/compiler/registry.py | 22 +++++-- nnvm/python/nnvm/top/nn.py | 18 ++++-- nnvm/python/nnvm/top/tensor.py | 40 +++++++------ nnvm/python/nnvm/top/transform.py | 2 +- nnvm/src/compiler/graph_fuse.cc | 31 +++++++--- nnvm/src/compiler/layout_transform.cc | 2 +- nnvm/tests/python/compiler/test_op_fusion.py | 61 ++++++++++++++++++++ 9 files changed, 162 insertions(+), 46 deletions(-) create mode 100644 nnvm/tests/python/compiler/test_op_fusion.py diff --git a/nnvm/docs/top.rst b/nnvm/docs/top.rst index 28b4a07109339..945cef00ad444 100644 --- a/nnvm/docs/top.rst +++ b/nnvm/docs/top.rst @@ -1,7 +1,8 @@ -NNVM Core Primitives -==================== +NNVM Core Tensor Operators +========================== -**Level 1: Basic Ops** +**Level 1: Basic Operators** +This level enables fully connected multi-layer perceptron. .. autosummary:: :nosignatures: @@ -12,12 +13,14 @@ NNVM Core Primitives nnvm.symbol.sigmoid nnvm.symbol.exp nnvm.symbol.log + nnvm.symbol.sqrt nnvm.symbol.elemwise_add nnvm.symbol.elemwise_sub nnvm.symbol.elemwise_mul nnvm.symbol.elemwise_div nnvm.symbol.flatten nnvm.symbol.concatenate + nnvm.symbol.expand_dims nnvm.symbol.split nnvm.symbol.dropout nnvm.symbol.batch_norm @@ -27,6 +30,8 @@ NNVM Core Primitives **Level 2: Convolutions** +This level enables typical convnet models. + .. autosummary:: :nosignatures: @@ -78,12 +83,14 @@ NNVM Core Primitives .. autofunction:: nnvm.symbol.sigmoid .. autofunction:: nnvm.symbol.exp .. autofunction:: nnvm.symbol.log +.. autofunction:: nnvm.symbol.sqrt .. autofunction:: nnvm.symbol.elemwise_add .. autofunction:: nnvm.symbol.elemwise_sub .. autofunction:: nnvm.symbol.elemwise_mul .. autofunction:: nnvm.symbol.elemwise_div .. 
autofunction:: nnvm.symbol.flatten .. autofunction:: nnvm.symbol.concatenate +.. autofunction:: nnvm.symbol.expand_dims .. autofunction:: nnvm.symbol.split .. autofunction:: nnvm.symbol.dropout .. autofunction:: nnvm.symbol.batch_norm diff --git a/nnvm/include/nnvm/compiler/op_attr_types.h b/nnvm/include/nnvm/compiler/op_attr_types.h index 8381733c33a1a..9b9fbed4455f1 100644 --- a/nnvm/include/nnvm/compiler/op_attr_types.h +++ b/nnvm/include/nnvm/compiler/op_attr_types.h @@ -25,16 +25,23 @@ using ::tvm::Tensor; using ::tvm::Schedule; /*! \brief operator pattern used in graph fusion */ -enum OpPatternKind : int { +enum OpPatternKind { // Elementwise operation kElemWise = 0, - // Broadcast operation + // Broadcasting operator, can always map output axis to the input in order. + // for example :code:`out[i, ax1, j, ax2] = input[i, j]`. + // Note that the axes need to be in order so transpose is not a bcast operator. kBroadcast = 1, - // Complex operation, can fuse bcast in input/outputs + // Injective operator, can always injectively map output axis to a single input axis. + // All injective operators can still be safely fused to injective and reduction. + kInjective = 2, + // Commutative reduction operator. + kCommReduce = 3, + // Complex operation, can still fuse elemwise operations into its output. // but cannot chain another complex op - kComplex = 2, - // Extern operation, cannot fuse anything. - kExtern = 3 + kOutEWiseFusable = 4, + // Opaque operation, cannot fuse anything. + kOpaque = 8 }; /*! \brief the operator pattern */ diff --git a/nnvm/python/nnvm/compiler/registry.py b/nnvm/python/nnvm/compiler/registry.py index c8094b1d345fb..2861877e8db3e 100644 --- a/nnvm/python/nnvm/compiler/registry.py +++ b/nnvm/python/nnvm/compiler/registry.py @@ -3,12 +3,24 @@ import tvm class OpPattern(object): - ELEM_WISE = 0 + """Operator generic patterns + + See Also + -------- + top.tag : Contains explanation of the tag type.
+ """ + # Elementwise operator + ELEMWISE = 0 + # Broadcast operator BROADCAST = 1 - # Complex means we can fuse elemwise to it - COMPLEX = 2 - # Extern means the op is not fusable - EXTERN = 3 + # Injective mapping + INJECTIVE = 2 + # Commutative reduction + COMM_REDUCE = 3 + # Complex op, can still fuse ewise into it + OUT_ELEMWISE_FUSABLE = 4 + # Not fusable opaque op + OPAQUE = 8 _register_compute = tvm.get_global_func("nnvm._register_compute") _register_schedule = tvm.get_global_func("nnvm._register_schedule") diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index 71246bc0823e7..3c3f737f4daf9 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -16,8 +16,16 @@ def compute_relu(attrs, inputs, _): return topi.nn.relu(inputs[0]) reg.register_schedule("relu", _fschedule_broadcast) -reg.register_pattern("relu", OpPattern.ELEM_WISE) +reg.register_pattern("relu", OpPattern.ELEMWISE) +# leaky_relu +@reg.register_compute("leaky_relu") +def compute_relu(attrs, inputs, _): + """Compute definition of leaky_relu""" + return topi.nn.leaky_relu(inputs[0]) + +reg.register_schedule("leaky_relu", _fschedule_broadcast) +reg.register_pattern("leaky_relu", OpPattern.ELEMWISE) # flatten @reg.register_compute("flatten") @@ -26,7 +34,7 @@ def compute_flatten(attrs, inputs, _): return topi.nn.flatten(inputs[0]) reg.register_schedule("flatten", _fschedule_broadcast) -reg.register_pattern("flatten", OpPattern.COMPLEX) +reg.register_pattern("flatten", OpPattern.INJECTIVE) # softmax @@ -46,7 +54,7 @@ def schedule_softmax(_, outs, target): return tvm.create_schedule([x.op for x in outs]) # Mark softmax as extern as we do not fuse it in call cases -reg.register_pattern("softmax", OpPattern.EXTERN) +reg.register_pattern("softmax", OpPattern.OPAQUE) # dense @@ -67,7 +75,7 @@ def schedule_dense(_, outs, target): return tvm.create_schedule([x.op for x in outs]) # register extern for now, change me when fusion is enabled.
-reg.register_pattern("dense", OpPattern.EXTERN) +reg.register_pattern("dense", OpPattern.OPAQUE) # conv @@ -105,4 +113,4 @@ def schedule_conv2d(attrs, outs, target): # naive schedule return tvm.create_schedule([x.op for x in outs]) -reg.register_pattern("conv2d", OpPattern.COMPLEX) +reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py index c427f49ab8d28..cc35e6e8ba4a1 100644 --- a/nnvm/python/nnvm/top/tensor.py +++ b/nnvm/python/nnvm/top/tensor.py @@ -8,13 +8,15 @@ from ..compiler import registry as reg from ..compiler import OpPattern -def _schedule_broadcast(_, outs, target): +def _schedule_injective(_, outs, target): """Generic schedule for binary bcast""" if target == "cuda": - return topi.cuda.schedule_elemwise(outs) + return topi.cuda.schedule_injective(outs) assert target.startswith("llvm") s = tvm.create_schedule([x.op for x in outs]) + x = outs[0] tvm.schedule.AutoInlineInjective(s) + s[x].fuse(s[x].op.axis) return s def _compute_binary_scalar(f): @@ -42,89 +44,91 @@ def _compute(attrs, x, _): return _compute -_fschedule_broadcast = tvm.convert(_schedule_broadcast) +_fschedule_injective = tvm.convert(_schedule_injective) +_fschedule_broadcast = _fschedule_injective +_fschedule_elemwise = _fschedule_injective # copy reg.register_compute("copy", _compute_unary(topi.identity)) -reg.register_pattern("copy", OpPattern.ELEM_WISE) +reg.register_pattern("copy", OpPattern.ELEMWISE) reg.register_schedule("copy", _fschedule_broadcast) # exp reg.register_compute("exp", _compute_unary(topi.exp)) -reg.register_pattern("exp", OpPattern.ELEM_WISE) +reg.register_pattern("exp", OpPattern.ELEMWISE) reg.register_schedule("exp", _fschedule_broadcast) # sqrt reg.register_compute("sqrt", _compute_unary(topi.sqrt)) -reg.register_pattern("sqrt", OpPattern.ELEM_WISE) +reg.register_pattern("sqrt", OpPattern.ELEMWISE) reg.register_schedule("sqrt", _fschedule_broadcast) # log 
reg.register_compute("log", _compute_unary(topi.log)) -reg.register_pattern("log", OpPattern.ELEM_WISE) +reg.register_pattern("log", OpPattern.ELEMWISE) reg.register_schedule("log", _fschedule_broadcast) # tanh reg.register_compute("tanh", _compute_unary(topi.tanh)) -reg.register_pattern("tanh", OpPattern.ELEM_WISE) +reg.register_pattern("tanh", OpPattern.ELEMWISE) reg.register_schedule("tanh", _fschedule_broadcast) # negative reg.register_compute("negative", _compute_unary(topi.negative)) -reg.register_pattern("negative", OpPattern.ELEM_WISE) +reg.register_pattern("negative", OpPattern.ELEMWISE) reg.register_schedule("negative", _fschedule_broadcast) # sigmoid reg.register_compute("sigmoid", _compute_unary(topi.sigmoid)) -reg.register_pattern("sigmoid", OpPattern.ELEM_WISE) +reg.register_pattern("sigmoid", OpPattern.ELEMWISE) reg.register_schedule("sigmoid", _fschedule_broadcast) # add_scalar reg.register_compute("__add_scalar__", _compute_binary_scalar(lambda x, y: x + y)) -reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE) +reg.register_pattern("__add_scalar__", OpPattern.ELEMWISE) reg.register_schedule("__add_scalar__", _fschedule_broadcast) # sub_calar reg.register_compute("__sub_scalar__", _compute_binary_scalar(lambda x, y: x - y)) -reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE) +reg.register_pattern("__sub_scalar__", OpPattern.ELEMWISE) reg.register_schedule("__sub_scalar__", _fschedule_broadcast) # rsub_scalar reg.register_compute("__rsub_scalar__", _compute_binary_scalar(lambda x, y: y - x)) -reg.register_pattern("__rsub_scalar__", OpPattern.ELEM_WISE) +reg.register_pattern("__rsub_scalar__", OpPattern.ELEMWISE) reg.register_schedule("__rsub_scalar__", _fschedule_broadcast) # mul_scalar reg.register_compute("__mul_scalar__", _compute_binary_scalar(lambda x, y: x * y)) -reg.register_pattern("__mul_scalar__", OpPattern.ELEM_WISE) +reg.register_pattern("__mul_scalar__", OpPattern.ELEMWISE) reg.register_schedule("__mul_scalar__", 
_fschedule_broadcast) # div_scalar reg.register_compute("__div_scalar__", _compute_binary_scalar(lambda x, y: x / y)) -reg.register_pattern("__div_scalar__", OpPattern.ELEM_WISE) +reg.register_pattern("__div_scalar__", OpPattern.ELEMWISE) reg.register_schedule("__div_scalar__", _fschedule_broadcast) # rdiv_scalar reg.register_compute("__rdiv_scalar__", _compute_binary_scalar(lambda x, y: y / x)) -reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE) +reg.register_pattern("__rdiv_scalar__", OpPattern.ELEMWISE) reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast) # pow_scalar reg.register_compute("__pow_scalar__", _compute_binary_scalar(tvm.power)) -reg.register_pattern("__pow_scalar__", OpPattern.ELEM_WISE) +reg.register_pattern("__pow_scalar__", OpPattern.ELEMWISE) reg.register_schedule("__pow_scalar__", _fschedule_broadcast) # rpow_scalar reg.register_compute("__rpow_scalar__", _compute_binary_scalar(lambda x, y: tvm.power(y, x))) -reg.register_pattern("__rpow_scalar__", OpPattern.ELEM_WISE) +reg.register_pattern("__rpow_scalar__", OpPattern.ELEMWISE) reg.register_schedule("__rpow_scalar__", _fschedule_broadcast) # elemwise_add diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py index e7419c030df46..a69caf39bcd06 100644 --- a/nnvm/python/nnvm/top/transform.py +++ b/nnvm/python/nnvm/top/transform.py @@ -37,5 +37,5 @@ def compute_reshape(attrs, inputs, out_info): oshape = out_info[0].shape x = inputs[0] return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape))) -reg.register_pattern("reshape", OpPattern.COMPLEX) +reg.register_pattern("reshape", OpPattern.INJECTIVE) reg.register_schedule("reshape", _fschedule_broadcast) diff --git a/nnvm/src/compiler/graph_fuse.cc b/nnvm/src/compiler/graph_fuse.cc index 1daa5fd113942..e3935ed95ff19 100644 --- a/nnvm/src/compiler/graph_fuse.cc +++ b/nnvm/src/compiler/graph_fuse.cc @@ -71,7 +71,7 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) { ref_count[e.node_id] += 2; } // 
Pattern for the subgraph - std::vector pattern_vec(idx.num_nodes(), kExtern); + std::vector pattern_vec(idx.num_nodes(), kOpaque); // Whether node can be fused to parent. std::vector fuse_vec(idx.num_nodes(), FuseRule::kUknown); // Master node id of fusion segment. @@ -84,19 +84,21 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) { if (inode.source->is_variable()) { fuse_vec[nid] = FuseRule::kRealize; continue; } - TOpPattern pt = op_pattern.get(inode.source->op(), kExtern); + TOpPattern pt = op_pattern.get(inode.source->op(), kOpaque); if (pt <= kBroadcast) { + // Try to check if we can fuse to the master. int chosen_master = -1; bool ewise = inode.source->num_outputs() == 1; for (const auto& e : inode.inputs) { if (fuse_vec[e.node_id] == FuseRule::kUknown) { TOpPattern ipt = pattern_vec[e.node_id]; if (ipt != kElemWise) ewise = false; - if (ipt <= kBroadcast) { + if (ipt <= kInjective) { fuse_vec[e.node_id] = FuseRule::kFuseToMaster; - } else if (ipt == kComplex && chosen_master == -1 && - shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) { + } else if (ipt == kOutEWiseFusable && + chosen_master == -1 && + shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) { chosen_master = master_vec[e.node_id]; fuse_vec[e.node_id] = FuseRule::kFuseToMaster; } else { @@ -111,11 +113,27 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) { } master_vec[nid] = chosen_master; if (chosen_master != -1) { - pt = kComplex; + pt = kOutEWiseFusable; } else { pt = ewise ? 
kElemWise : kBroadcast; } + } else if (pt == kInjective || pt == kCommReduce) { + // fuse to the comm reduce or injective + for (const auto& e : inode.inputs) { + if (fuse_vec[e.node_id] == FuseRule::kUknown) { + TOpPattern ipt = pattern_vec[e.node_id]; + if (ipt <= kInjective) { + fuse_vec[e.node_id] = FuseRule::kFuseToMaster; + } else { + fuse_vec[e.node_id] = FuseRule::kRealize; + } + } + } + if (pt == kCommReduce) { + master_vec[nid] = nid; + } } else { + // realize master_vec[nid] = nid; for (const auto& e : inode.inputs) { if (fuse_vec[e.node_id] == FuseRule::kUknown) { @@ -136,7 +154,6 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) { } } - // point to the group root id of each node std::vector group_vec(idx.num_nodes(), -1); for (uint32_t i = idx.num_nodes(); i != 0; --i) { diff --git a/nnvm/src/compiler/layout_transform.cc b/nnvm/src/compiler/layout_transform.cc index 2bce0fad54c9e..5651838ffb85a 100644 --- a/nnvm/src/compiler/layout_transform.cc +++ b/nnvm/src/compiler/layout_transform.cc @@ -52,7 +52,7 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) { // use op pattern to decide whether an op is map auto is_map_op = [&](size_t nid) { - TOpPattern pt = op_pattern.get(idx[nid].source->op(), kExtern); + TOpPattern pt = op_pattern.get(idx[nid].source->op(), kOpaque); bool is_map = (pt <= kBroadcast); if (pt == kBroadcast) { for (const auto& e : idx[nid].inputs) { diff --git a/nnvm/tests/python/compiler/test_op_fusion.py b/nnvm/tests/python/compiler/test_op_fusion.py new file mode 100644 index 0000000000000..45031bf1802be --- /dev/null +++ b/nnvm/tests/python/compiler/test_op_fusion.py @@ -0,0 +1,61 @@ +import nnvm +import numpy as np +import tvm +import topi +from nnvm import symbol as sym +from nnvm.compiler import graph_util, graph_attr +from nnvm.testing.config import test_ctx_list + +def test_ewise_injective(): + x = sym.Variable("x") + y = x * 2 + y = sym.flatten(y) + 1 + dshape = (10, 2, 3) + shape_dict = {"x": dshape} + dtype = "float32" + target = 
"llvm" + for target, ctx in test_ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) + assert graph.index.num_nodes == 2 + m = nnvm.runtime.create(graph, lib, ctx) + x_np = np.random.uniform(size=dshape).astype(dtype) + m.run(x=x_np) + out = m.get_output(0, tvm.nd.empty((10, 6))) + np.testing.assert_allclose( + out.asnumpy(), x_np.reshape(out.shape) * 2 + 1, + atol=1e-5, rtol=1e-5) + + +def test_conv_ewise_injective(): + x = sym.Variable("x") + y = sym.conv2d(x, channels=32, kernel_size=(3, 3), groups=32, + name="y", padding=(1,1)) + y = sym.flatten(y + 1) + 1 + dtype = "float32" + dshape = (1, 32, 18, 18) + kshape = (32, 1, 3, 3) + oshape = (1, 32* 18 * 18) + shape_dict = {"x": dshape} + + for target, ctx in test_ctx_list(): + graph, lib, _ = nnvm.compiler.build(y, target, shape_dict) + m = nnvm.runtime.create(graph, lib, ctx) + # print(graph.ir(join_entry_attrs=["shape"])) + assert graph.index.num_nodes == 5 + # set input + data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype)) + kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype)) + bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype)) + m.run(x=data, y_weight=kernel, y_bias=bias) + # get output + out = m.get_output(0, tvm.nd.empty(oshape, dtype)) + c_np = topi.testing.depthwise_conv2d_python_nchw( + data.asnumpy(), kernel.asnumpy(), (1,1), 'SAME') + c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1) + 1 + c_np = c_np.reshape(c_np.shape[0], np.prod(c_np.shape[1:])) + 1 + np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5) + + +if __name__ == "__main__": + test_ewise_injective() + test_conv_ewise_injective()