Skip to content

Commit

Permalink
[TOPI] Custom schedule for standalone transpose in cuda (#8030)
Browse files Browse the repository at this point in the history
* [TOPI] Custom schedule for standalone transpose in cuda

* check if input is not Any

* fix vta test

* check input shape

* fix injective

* move transpose out of sparse.py

* update comments, use warp size

* missspelled transform

* formatting

* rename test

* comment

* fix tests
  • Loading branch information
tkonolige authored May 20, 2021
1 parent 71ff875 commit 28ea03c
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 39 deletions.
4 changes: 3 additions & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
_reg.register_injective_schedule("slice_like")
_reg.register_injective_schedule("split")
_reg.register_injective_schedule("take")
_reg.register_injective_schedule("transpose")
_reg.register_injective_schedule("stack")
_reg.register_injective_schedule("contrib_reverse_reshape")
_reg.register_injective_schedule("gather")
Expand Down Expand Up @@ -746,6 +745,9 @@ def transpose_shape_func(attrs, inputs, _):
return [_transpose_shape_func(inputs[0], convert(axes))]


_reg.register_schedule("transpose", strategy.schedule_transpose)


@script
def _squeeze_shape_func(data_shape, keep_axes, remove_axes):
out = output_tensor((len(keep_axes),), "int64")
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from tvm.te import SpecializedCondition

from .. import op as _op
from ....target import Target
from ....tir import IntImm
from .generic import *


Expand Down Expand Up @@ -1068,3 +1070,23 @@ def unique_strategy_cuda(attrs, inputs, out_type, target):
name="unique.cuda",
)
return strategy


@schedule_transpose.register(["cuda", "gpu", "rocm"])
def schedule_transpose_cuda(attrs, outs, target):
"""
Transpose cuda strategy
Dispatches to and optimized schedule if the transpose is standalone (not fused).
"""
warp_size = int(Target.current(allow_none=False).thread_warp_size)
if (
isinstance(outs[0].op.input_tensors[0].op, te.PlaceholderOp)
and len(outs[0].shape) == 2
and (attrs.axes is None or (len(attrs.axes) == 2 and attrs.axes == [1, 0]))
and isinstance(outs[0].shape[0], (int, IntImm))
and outs[0].shape[0] >= warp_size
and isinstance(outs[0].shape[1], (int, IntImm))
and outs[0].shape[1] >= warp_size
):
return topi.cuda.schedule_transpose(outs)
return schedule_injective(attrs, outs, target)
7 changes: 7 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,3 +1570,10 @@ def unique_strategy(attrs, inputs, out_type, target):
name="unique.generic",
)
return strategy


@generic_func
def schedule_transpose(attrs, outs, target):
"""schedule transpose"""
with target:
return schedule_injective(attrs, outs, target)
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,5 @@
from .argwhere import *
from .scan import *
from .sparse_reshape import *
from .transform import *
from .unique import *
39 changes: 2 additions & 37 deletions python/tvm/topi/cuda/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from .. import nn
from ..utils import traverse_inline, get_const_tuple, prod, get_const_int, ceil_div
from .transform import schedule_transpose_from_existing


def sparse_dense(data, weight_data, weight_indices, weight_indptr, sparse_lhs=False):
Expand Down Expand Up @@ -105,42 +106,6 @@ def _callback(op):
return s


def schedule_cuda_transpose(s, out):
"""Schedule for transpose on the gpu.
Roughly follows this:
https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/, but
without the padding for shared memory. For better performance, we could
rewrite it in tir to add the padding.
"""

def _callback(op):
# pylint: disable=invalid-name
m, n = s[op].op.axis
warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
no, ni = s[op].split(n, factor=warp_size)
mo, mi = s[op].split(m, factor=warp_size)
s[op].reorder(mo, no, mi, ni)
s[op].bind(mo, te.thread_axis("blockIdx.x"))
s[op].bind(no, te.thread_axis("blockIdx.y"))
c = s.cache_read(op.input_tensors[0], "shared", op)
s[c].compute_at(s[op], no)
thread_x = te.thread_axis("threadIdx.x")
thread_y = te.thread_axis("threadIdx.y")
s[op].bind(ni, thread_x)
# This is a hack to make the scheduling language realize that this axis
# can be scheduled.
a, _ = s[c].split(s[c].op.axis[1], factor=1)
s[c].bind(a, thread_x)
# Use 4 warps per block. Slightly faster than 1 warp per block
ao, _ = s[op].split(mi, nparts=4)
s[op].bind(ao, thread_y)
ao, _ = s[c].split(s[c].op.axis[0], nparts=4)
s[c].bind(ao, thread_y)

traverse_inline(s, out.op, _callback)


def sparse_dense_tir(data, w_data, w_indices, w_indptr):
"""Compute data * w^T.
Expand Down Expand Up @@ -388,7 +353,7 @@ def schedule_sparse_dense_padded(outs):
# necessary
data_t = outs[0].op.input_tensors[0]
s = te.create_schedule([outs[0].op, data_t.op])
schedule_cuda_transpose(s, outs[0].op.input_tensors[0])
schedule_transpose_from_existing(s, outs[0].op.input_tensors[0])
return s


Expand Down
67 changes: 67 additions & 0 deletions python/tvm/topi/cuda/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""CUDA implementations of transforms"""

from ... import te
from ...target import Target
from ..utils import traverse_inline


def schedule_transpose(outs):
"""Schedule a unfused transpose"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
schedule_transpose_from_existing(s, outs[0])
return s


def schedule_transpose_from_existing(s, out):
"""Schedule for transpose on the gpu.
Roughly follows this:
https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/, but
without the padding for shared memory. For better performance, we could
rewrite it in tir to add the padding. Also, rewriting in tir would allow
use to use warp shuffles instead of shared memory (see
https://github.com/bryancatanzaro/trove).
"""

def _callback(op):
# pylint: disable=invalid-name
m, n = s[op].op.axis
warp_size = int(Target.current(allow_none=False).thread_warp_size)
no, ni = s[op].split(n, factor=warp_size)
mo, mi = s[op].split(m, factor=warp_size)
s[op].reorder(mo, no, mi, ni)
s[op].bind(mo, te.thread_axis("blockIdx.x"))
s[op].bind(no, te.thread_axis("blockIdx.y"))
c = s.cache_read(op.input_tensors[0], "shared", op)
s[c].compute_at(s[op], no)
thread_x = te.thread_axis("threadIdx.x")
thread_y = te.thread_axis("threadIdx.y")
s[op].bind(ni, thread_x)
# This is a hack to make the scheduling language realize that this axis
# can be scheduled.
a, _ = s[c].split(s[c].op.axis[1], factor=1)
s[c].bind(a, thread_x)
# Use 4 warps per block. Slightly faster than 1 warp per block
ao, _ = s[op].split(mi, nparts=4)
s[op].bind(ao, thread_y)
ao, _ = s[c].split(s[c].op.axis[0], nparts=4)
s[c].bind(ao, thread_y)

traverse_inline(s, out.op, _callback)
26 changes: 26 additions & 0 deletions tests/python/topi/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tvm
from tvm import te
from tvm import topi
from tvm import relay
import tvm.topi.testing
from tvm.contrib.nvcc import have_fp16

Expand Down Expand Up @@ -870,6 +871,31 @@ def test_transpose():
verify_transpose((3, 10), None)


@tvm.testing.parametrize_targets("cuda", "rocm")
def test_transpose_unfused_schedule(target, dev):
shape = (100, tvm.target.Target(target).thread_warp_size + 3)
x = relay.var("x", relay.TensorType(shape, "float32"))
f = relay.transpose(x)
ex = relay.create_executor(
kind="graph", mod=tvm.IRModule.from_expr(relay.Function([x], f)), device=dev, target=target
)
r = np.random.rand(*shape)
tvm.testing.assert_allclose(ex.evaluate()(r).asnumpy(), np.transpose(r))

# We want to make sure schedule does not fire here, but there is no way of
# inspecting which schedules were used.
x = relay.var("x", relay.TensorType(shape, "float32"))
y = relay.var("y", relay.TensorType(shape, "float32"))
f = relay.transpose(x + y)
ex = relay.create_executor(
kind="graph",
mod=tvm.IRModule.from_expr(relay.Function([x, y], f)),
device=dev,
target=target,
)
tvm.testing.assert_allclose(ex.evaluate()(r, r).asnumpy(), np.transpose(r + r))


@tvm.testing.uses_gpu
def test_reshape():
verify_reshape((1, 2, 3, 4), (2, 3, 4))
Expand Down
2 changes: 1 addition & 1 deletion vta/tutorials/autotvm/tune_relay_vta.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def tune_and_evaluate(tuning_opt):
)

# filter out non-packed conv2d task
tasks = list(filter(lambda t: len(t.args[0][1]) > 4, tasks))
tasks = list(filter(lambda t: len(t.args[0][1]) > 4 and "conv" in t.name, tasks))

# We should have extracted 10 convolution tasks
assert len(tasks) == 10
Expand Down

0 comments on commit 28ea03c

Please sign in to comment.