From 28ea03c83b53bcc296ae4c7ab380c077fa326571 Mon Sep 17 00:00:00 2001
From: Tristan Konolige
Date: Thu, 20 May 2021 16:59:07 -0700
Subject: [PATCH] [TOPI] Custom schedule for standalone transpose in cuda (#8030)

* [TOPI] Custom schedule for standalone transpose in cuda
* check if input is not Any
* fix vta test
* check input shape
* fix injective
* move transpose out of sparse.py
* update comments, use warp size
* missspelled transform
* formatting
* rename test
* comment
* fix tests
---
 python/tvm/relay/op/_transform.py              |  4 +-
 python/tvm/relay/op/strategy/cuda.py           | 22 ++++++
 python/tvm/relay/op/strategy/generic.py        |  7 ++
 python/tvm/topi/cuda/__init__.py               |  1 +
 python/tvm/topi/cuda/sparse.py                 | 39 +---------
 python/tvm/topi/cuda/transform.py              | 67 +++++++++++++++++++
 .../python/topi/python/test_topi_transform.py  | 26 +++++++
 vta/tutorials/autotvm/tune_relay_vta.py        |  2 +-
 8 files changed, 129 insertions(+), 39 deletions(-)
 create mode 100644 python/tvm/topi/cuda/transform.py

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 76adee477a1a..412acb4cea17 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -53,7 +53,6 @@
 _reg.register_injective_schedule("slice_like")
 _reg.register_injective_schedule("split")
 _reg.register_injective_schedule("take")
-_reg.register_injective_schedule("transpose")
 _reg.register_injective_schedule("stack")
 _reg.register_injective_schedule("contrib_reverse_reshape")
 _reg.register_injective_schedule("gather")
@@ -746,6 +745,9 @@ def transpose_shape_func(attrs, inputs, _):
     return [_transpose_shape_func(inputs[0], convert(axes))]
 
 
+_reg.register_schedule("transpose", strategy.schedule_transpose)
+
+
 @script
 def _squeeze_shape_func(data_shape, keep_axes, remove_axes):
     out = output_tensor((len(keep_axes),), "int64")
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index a6775ae7bd20..6c5b1e0cdead 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -23,6 +23,8 @@
 from tvm.te import SpecializedCondition
 
 from .. import op as _op
+from ....target import Target
+from ....tir import IntImm
 from .generic import *
 
 
@@ -1068,3 +1070,23 @@ def unique_strategy_cuda(attrs, inputs, out_type, target):
         name="unique.cuda",
     )
     return strategy
+
+
+@schedule_transpose.register(["cuda", "gpu", "rocm"])
+def schedule_transpose_cuda(attrs, outs, target):
+    """
+    Transpose cuda strategy
+    Dispatches to an optimized schedule if the transpose is standalone (not fused).
+    """
+    warp_size = int(Target.current(allow_none=False).thread_warp_size)
+    if (
+        isinstance(outs[0].op.input_tensors[0].op, te.PlaceholderOp)
+        and len(outs[0].shape) == 2
+        and (attrs.axes is None or (len(attrs.axes) == 2 and attrs.axes == [1, 0]))
+        and isinstance(outs[0].shape[0], (int, IntImm))
+        and outs[0].shape[0] >= warp_size
+        and isinstance(outs[0].shape[1], (int, IntImm))
+        and outs[0].shape[1] >= warp_size
+    ):
+        return topi.cuda.schedule_transpose(outs)
+    return schedule_injective(attrs, outs, target)
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index db73a874005f..570c1e9983fc 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1570,3 +1570,10 @@ def unique_strategy(attrs, inputs, out_type, target):
         name="unique.generic",
     )
     return strategy
+
+
+@generic_func
+def schedule_transpose(attrs, outs, target):
+    """schedule transpose"""
+    with target:
+        return schedule_injective(attrs, outs, target)
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index 4d838db8bfba..21ddf57ca1d0 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -57,4 +57,5 @@
 from .argwhere import *
 from .scan import *
 from .sparse_reshape import *
+from .transform import *
 from .unique import *
diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index 1e846ebf5311..b6baa9cd67a5 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -24,6 +24,7 @@
 
 from .. import nn
 from ..utils import traverse_inline, get_const_tuple, prod, get_const_int, ceil_div
+from .transform import schedule_transpose_from_existing
 
 
 def sparse_dense(data, weight_data, weight_indices, weight_indptr, sparse_lhs=False):
@@ -105,42 +106,6 @@ def _callback(op):
     return s
 
 
-def schedule_cuda_transpose(s, out):
-    """Schedule for transpose on the gpu.
-
-    Roughly follows this:
-    https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/, but
-    without the padding for shared memory. For better performance, we could
-    rewrite it in tir to add the padding.
-    """
-
-    def _callback(op):
-        # pylint: disable=invalid-name
-        m, n = s[op].op.axis
-        warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
-        no, ni = s[op].split(n, factor=warp_size)
-        mo, mi = s[op].split(m, factor=warp_size)
-        s[op].reorder(mo, no, mi, ni)
-        s[op].bind(mo, te.thread_axis("blockIdx.x"))
-        s[op].bind(no, te.thread_axis("blockIdx.y"))
-        c = s.cache_read(op.input_tensors[0], "shared", op)
-        s[c].compute_at(s[op], no)
-        thread_x = te.thread_axis("threadIdx.x")
-        thread_y = te.thread_axis("threadIdx.y")
-        s[op].bind(ni, thread_x)
-        # This is a hack to make the scheduling language realize that this axis
-        # can be scheduled.
-        a, _ = s[c].split(s[c].op.axis[1], factor=1)
-        s[c].bind(a, thread_x)
-        # Use 4 warps per block. Slightly faster than 1 warp per block
-        ao, _ = s[op].split(mi, nparts=4)
-        s[op].bind(ao, thread_y)
-        ao, _ = s[c].split(s[c].op.axis[0], nparts=4)
-        s[c].bind(ao, thread_y)
-
-    traverse_inline(s, out.op, _callback)
-
-
 def sparse_dense_tir(data, w_data, w_indices, w_indptr):
     """Compute data * w^T.
@@ -388,7 +353,7 @@ def schedule_sparse_dense_padded(outs):
     # necessary
     data_t = outs[0].op.input_tensors[0]
     s = te.create_schedule([outs[0].op, data_t.op])
-    schedule_cuda_transpose(s, outs[0].op.input_tensors[0])
+    schedule_transpose_from_existing(s, outs[0].op.input_tensors[0])
     return s
 
 
diff --git a/python/tvm/topi/cuda/transform.py b/python/tvm/topi/cuda/transform.py
new file mode 100644
index 000000000000..89caf94bbbc1
--- /dev/null
+++ b/python/tvm/topi/cuda/transform.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""CUDA implementations of transforms"""
+
+from ... import te
+from ...target import Target
+from ..utils import traverse_inline
+
+
+def schedule_transpose(outs):
+    """Schedule an unfused transpose"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+    schedule_transpose_from_existing(s, outs[0])
+    return s
+
+
+def schedule_transpose_from_existing(s, out):
+    """Schedule for transpose on the gpu.
+
+    Roughly follows this:
+    https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/, but
+    without the padding for shared memory. For better performance, we could
+    rewrite it in tir to add the padding. Also, rewriting in tir would allow
+    us to use warp shuffles instead of shared memory (see
+    https://github.com/bryancatanzaro/trove).
+    """
+
+    def _callback(op):
+        # pylint: disable=invalid-name
+        m, n = s[op].op.axis
+        warp_size = int(Target.current(allow_none=False).thread_warp_size)
+        no, ni = s[op].split(n, factor=warp_size)
+        mo, mi = s[op].split(m, factor=warp_size)
+        s[op].reorder(mo, no, mi, ni)
+        s[op].bind(mo, te.thread_axis("blockIdx.x"))
+        s[op].bind(no, te.thread_axis("blockIdx.y"))
+        c = s.cache_read(op.input_tensors[0], "shared", op)
+        s[c].compute_at(s[op], no)
+        thread_x = te.thread_axis("threadIdx.x")
+        thread_y = te.thread_axis("threadIdx.y")
+        s[op].bind(ni, thread_x)
+        # This is a hack to make the scheduling language realize that this axis
+        # can be scheduled.
+        a, _ = s[c].split(s[c].op.axis[1], factor=1)
+        s[c].bind(a, thread_x)
+        # Use 4 warps per block. Slightly faster than 1 warp per block
+        ao, _ = s[op].split(mi, nparts=4)
+        s[op].bind(ao, thread_y)
+        ao, _ = s[c].split(s[c].op.axis[0], nparts=4)
+        s[c].bind(ao, thread_y)
+
+    traverse_inline(s, out.op, _callback)
diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py
index 16f9f13f05b0..94cdc613ce9c 100644
--- a/tests/python/topi/python/test_topi_transform.py
+++ b/tests/python/topi/python/test_topi_transform.py
@@ -20,6 +20,7 @@
 import tvm
 from tvm import te
 from tvm import topi
+from tvm import relay
 import tvm.topi.testing
 
 from tvm.contrib.nvcc import have_fp16
@@ -870,6 +871,31 @@ def test_transpose():
     verify_transpose((3, 10), None)
 
 
+@tvm.testing.parametrize_targets("cuda", "rocm")
+def test_transpose_unfused_schedule(target, dev):
+    shape = (100, tvm.target.Target(target).thread_warp_size + 3)
+    x = relay.var("x", relay.TensorType(shape, "float32"))
+    f = relay.transpose(x)
+    ex = relay.create_executor(
+        kind="graph", mod=tvm.IRModule.from_expr(relay.Function([x], f)), device=dev, target=target
+    )
+    r = np.random.rand(*shape)
+    tvm.testing.assert_allclose(ex.evaluate()(r).asnumpy(), np.transpose(r))
+
+    # We want to make sure the schedule does not fire here, but there is no way of
+    # inspecting which schedules were used.
+    x = relay.var("x", relay.TensorType(shape, "float32"))
+    y = relay.var("y", relay.TensorType(shape, "float32"))
+    f = relay.transpose(x + y)
+    ex = relay.create_executor(
+        kind="graph",
+        mod=tvm.IRModule.from_expr(relay.Function([x, y], f)),
+        device=dev,
+        target=target,
+    )
+    tvm.testing.assert_allclose(ex.evaluate()(r, r).asnumpy(), np.transpose(r + r))
+
+
 @tvm.testing.uses_gpu
 def test_reshape():
     verify_reshape((1, 2, 3, 4), (2, 3, 4))
diff --git a/vta/tutorials/autotvm/tune_relay_vta.py b/vta/tutorials/autotvm/tune_relay_vta.py
index 38633b01d976..2f505b2a86a6 100644
--- a/vta/tutorials/autotvm/tune_relay_vta.py
+++ b/vta/tutorials/autotvm/tune_relay_vta.py
@@ -357,7 +357,7 @@ def tune_and_evaluate(tuning_opt):
     )
 
     # filter out non-packed conv2d task
-    tasks = list(filter(lambda t: len(t.args[0][1]) > 4, tasks))
+    tasks = list(filter(lambda t: len(t.args[0][1]) > 4 and "conv" in t.name, tasks))
 
     # We should have extracted 10 convolution tasks
     assert len(tasks) == 10
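
As a quick illustration of the new code path, the standalone schedule added in python/tvm/topi/cuda/transform.py can also be driven directly through TOPI, outside of Relay. The following is a minimal sketch, not part of the patch: the shape (128, 64), the names A/B/a/b/f, and the use of np.testing are arbitrary choices, and it assumes a machine with a working CUDA GPU plus the standard te/topi/NDArray APIs. Both extents are kept at or above the 32-thread warp size, mirroring the conditions schedule_transpose_cuda checks before picking this schedule over the injective fallback:

    import numpy as np
    import tvm
    from tvm import te, topi

    # Illustrative 2-D shape; both extents are multiples of the warp size (32),
    # so the warp-size splits in the schedule divide evenly.
    A = te.placeholder((128, 64), name="A", dtype="float32")
    B = topi.transpose(A)

    target = tvm.target.Target("cuda")
    with target:  # the schedule reads thread_warp_size from Target.current()
        s = topi.cuda.schedule_transpose(B)

    f = tvm.build(s, [A, B], target)
    dev = tvm.cuda(0)
    a = tvm.nd.array(np.random.rand(128, 64).astype("float32"), dev)
    b = tvm.nd.array(np.zeros((64, 128), dtype="float32"), dev)
    f(a, b)
    np.testing.assert_allclose(b.asnumpy(), a.asnumpy().T)

When the transpose is fused with other operators (the x + y case in the test above), schedule_transpose_cuda instead falls back to schedule_injective, since the shared-memory tiling assumes the transpose reads directly from a placeholder.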