From d75a7ca828b1944e94291a22fc35c9e9ccf11f6b Mon Sep 17 00:00:00 2001
From: Matthew Brookhart
Date: Tue, 27 Oct 2020 22:40:02 -0600
Subject: [PATCH] Scatter on Cuda (#6533)

* working cuda scatter

fix lint

fix pylint again

* cuda scatter with threading

* add dynamic shape tests

* remove unused variable
---
 python/tvm/relay/op/_transform.py       |   2 +-
 python/tvm/relay/op/strategy/cuda.py    |  13 +
 python/tvm/relay/op/strategy/generic.py |  23 +-
 python/tvm/topi/cuda/__init__.py        |   1 +
 python/tvm/topi/cuda/scatter.py         | 443 ++++++++++++++++++++++++
 tests/python/relay/test_op_level3.py    |  41 ++-
 6 files changed, 515 insertions(+), 8 deletions(-)
 create mode 100644 python/tvm/topi/cuda/scatter.py

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 56a3f2640e5d..4ee6f2ebb5c1 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -104,7 +104,7 @@ def compute_scatter(attrs, inputs, output_type):
     return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)]
 
 
-_reg.register_schedule("scatter", strategy.schedule_scatter)
+_reg.register_strategy("scatter", strategy.scatter_strategy)
 
 # scatter_add
 @_reg.register_compute("scatter_add")
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index ca44e49ce1dd..d77361d906fb 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -651,6 +651,19 @@ def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
+@scatter_strategy.register(["cuda", "gpu"])
+def scatter_cuda(attrs, inputs, out_type, target):
+    """scatter cuda strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter(topi.cuda.scatter),
+        wrap_topi_schedule(topi.generic.schedule_extern),
+        name="scatter.cuda",
+        plevel=10,
+    )
+    return strategy
+
+
 @argsort_strategy.register(["cuda", "gpu"])
 def argsort_strategy_cuda(attrs, inputs, out_type, target):
     """argsort cuda strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index f6030d471594..34d1999707e9 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1032,11 +1032,24 @@ def schedule_argwhere(attrs, outs, target):
 
 
 # scatter
-@generic_func
-def schedule_scatter(attrs, outs, target):
-    """schedule scatter"""
-    with target:
-        return topi.generic.schedule_scatter(outs)
+@override_native_generic_func("scatter_strategy")
+def scatter_strategy(attrs, outs, out_type, target):
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_scatter(topi.scatter),
+        wrap_topi_schedule(topi.generic.schedule_scatter),
+        name="scatter.generic",
+    )
+    return strategy
+
+
+def wrap_compute_scatter(topi_compute):
+    """Wrap scatter topi compute"""
+
+    def _compute_scatter(attrs, inputs, _):
+        return [topi_compute(inputs[0], inputs[1], inputs[2], axis=attrs.axis)]
+
+    return _compute_scatter
 
 
 # scatter_add
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index ed8037024635..3ff544f4bb3e 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -46,6 +46,7 @@
 from .ssd import *
 from .nms import get_valid_counts, non_max_suppression
 from .rcnn import *
+from .scatter import *
 from .sort import *
 from .conv2d_nhwc_tensorcore import *
 from .conv3d_ndhwc_tensorcore import *
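The hunks above replace the old `schedule_scatter` generic with a full `scatter_strategy`, so per-target kernels can take over from the generic TOPI implementation; `@scatter_strategy.register(["cuda", "gpu"])` overrides the generic function whenever compilation targets CUDA. A rough, self-contained Python sketch of that dispatch follows (a hypothetical simplification with stand-in functions; the real mechanism is `override_native_generic_func` plus `.register`, as shown above):

    def scatter_strategy_generic():
        # stand-in for the generic strategy registered above
        return "scatter.generic"

    def scatter_strategy_cuda():
        # stand-in for the CUDA strategy registered above
        return "scatter.cuda"

    # per-target overrides win; unregistered targets fall back to generic
    _strategies = {"cuda": scatter_strategy_cuda}

    def select_strategy(target_kind):
        return _strategies.get(target_kind, scatter_strategy_generic)

    assert select_strategy("cuda")() == "scatter.cuda"
    assert select_strategy("llvm")() == "scatter.generic"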
diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py
new file mode 100644
index 000000000000..6522d74d8bef
--- /dev/null
+++ b/python/tvm/topi/cuda/scatter.py
@@ -0,0 +1,443 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
+"""Scatter operator"""
+import tvm
+from tvm import te
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def gen_ir_1d(data, indices, updates, axis, out):
+    """Generate scatter IR for 1d inputs
+
+    Parameters
+    ----------
+    data : tir.Buffer
+        The input data to the operator.
+
+    indices : tir.Buffer
+        The index locations to update.
+
+    updates : tir.Buffer
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    out : tir.Buffer
+        The output buffer.
+
+    Returns
+    -------
+    ret : tir.Stmt
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", n)
+        out_ptr[bx] = data_ptr[bx]
+
+    indices_ptr = ib.buffer_ptr(indices)
+    updates_ptr = ib.buffer_ptr(updates)
+
+    ni = indices.shape[0]
+
+    with ib.new_scope():
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(bx, "thread_extent", 1)
+        with ib.for_range(0, ni, name="i") as i:
+            index = indices_ptr[i]
+            with ib.if_scope(index < 0):
+                out_ptr[index + n] = updates_ptr[i]
+            with ib.else_scope():
+                out_ptr[index] = updates_ptr[i]
+
+    return ib.get()
+
+
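For readers following the IR: `gen_ir_1d` emits two kernels, a block-parallel copy of `data` into `out`, then a single-threaded loop that applies the updates in order (so duplicate indices resolve deterministically) and wraps negative indices by `n`. A reference-only NumPy model of the same semantics, mirroring the `ref_scatter` helper in the tests below; this sketch is not part of the patch:

    import numpy as np

    def scatter_1d_ref(data, indices, updates):
        out = data.copy()                  # first kernel: copy data into out
        for i in range(indices.shape[0]):  # second kernel: one serial update loop
            out[indices[i]] = updates[i]   # NumPy wraps negative indices itself,
        return out                         # matching the explicit index + n branch

    print(scatter_1d_ref(np.zeros(4), np.array([1, -1]), np.array([7.0, 9.0])))
    # -> [0. 7. 0. 9.]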
+ """ + warp_size = tvm.target.Target.current(False).thread_warp_size + + n = data.shape[0] + c = data.shape[1] + + ib = tvm.tir.ir_builder.create() + + out_ptr = ib.buffer_ptr(out) + data_ptr = ib.buffer_ptr(data) + + with ib.new_scope(): + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", n) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + with ib.for_range(0, ceil_div(c, warp_size), name="j") as j_: + j = j_ * warp_size + tx + with ib.if_scope(j < c): + idx = bx * c + j + out_ptr[idx] = data_ptr[idx] + + indices_ptr = ib.buffer_ptr(indices) + updates_ptr = ib.buffer_ptr(updates) + + ni = indices.shape[0] + ci = indices.shape[1] + + if axis == 0: + with ib.new_scope(): + j = te.thread_axis("blockIdx.x") + ib.scope_attr(j, "thread_extent", ci) + with ib.for_range(0, ni, name="i") as i: + idx = i * ci + j + index = indices_ptr[idx] + with ib.if_scope(index < 0): + out_ptr[(index + n) * c + j] = updates_ptr[idx] + with ib.else_scope(): + out_ptr[index * c + j] = updates_ptr[idx] + else: + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", ni) + with ib.for_range(0, ci, name="j") as j: + idx = i * ci + j + index = indices_ptr[idx] + with ib.if_scope(index < 0): + out_ptr[i * c + (index + c)] = updates_ptr[idx] + with ib.else_scope(): + out_ptr[i * c + index] = updates_ptr[idx] + return ib.get() + + +def gen_ir_3d(data, indices, updates, axis, out): + """Generate scatter ir for 3d inputs + + Parameters + ---------- + data : tir.Tensor + The input data to the operator. + + indices : tir.Tensor + The index locations to update. + + updates : tir.Tensor + The values to update. + + axis : int + The axis to scatter on + + out : tir.Tensor + The output tensor. + + Returns + ------- + ret : tir + The computational ir. 
+ """ + warp_size = tvm.target.Target.current(False).thread_warp_size + + n = data.shape[0] + c = data.shape[1] + h = data.shape[2] + + ib = tvm.tir.ir_builder.create() + + out_ptr = ib.buffer_ptr(out) + data_ptr = ib.buffer_ptr(data) + + with ib.new_scope(): + bx = te.thread_axis("blockIdx.x") + ib.scope_attr(bx, "thread_extent", n) + by = te.thread_axis("blockIdx.y") + ib.scope_attr(by, "thread_extent", c) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + with ib.for_range(0, ceil_div(h, warp_size), name="k") as k_: + k = k_ * warp_size + tx + with ib.if_scope(k < h): + idx = (bx * c + by) * h + k + out_ptr[idx] = data_ptr[idx] + + indices_ptr = ib.buffer_ptr(indices) + updates_ptr = ib.buffer_ptr(updates) + ni = indices.shape[0] + ci = indices.shape[1] + hi = indices.shape[2] + + if axis == 0: + with ib.new_scope(): + j = te.thread_axis("blockIdx.x") + ib.scope_attr(j, "thread_extent", ci) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + with ib.for_range(0, ni, name="i") as i: + with ib.for_range(0, ceil_div(hi, warp_size), name="k") as k_: + k = k_ * warp_size + tx + with ib.if_scope(k < hi): + idx = (i * ci + j) * hi + k + index = indices_ptr[idx] + with ib.if_scope(index < 0): + out_ptr[((index + n) * c + j) * h + k] = updates_ptr[idx] + with ib.else_scope(): + out_ptr[(index * c + j) * h + k] = updates_ptr[idx] + elif axis == 1: + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", ni) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + with ib.for_range(0, ci, name="j") as j: + with ib.for_range(0, ceil_div(hi, warp_size), name="k") as k_: + k = k_ * warp_size + tx + with ib.if_scope(k < hi): + idx = (i * ci + j) * hi + k + index = indices_ptr[idx] + with ib.if_scope(index < 0): + out_ptr[(i * c + (index + c)) * h + k] = updates_ptr[idx] + with ib.else_scope(): + out_ptr[(i * c + index) * h + k] = updates_ptr[idx] + else: + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", ni) + j = te.thread_axis("blockIdx.y") + ib.scope_attr(j, "thread_extent", ci) + with ib.for_range(0, hi, name="k") as k: + idx = (i * ci + j) * hi + k + index = indices_ptr[idx] + with ib.if_scope(index < 0): + out_ptr[(i * c + j) * h + (index + h)] = updates_ptr[idx] + with ib.else_scope(): + out_ptr[(i * c + j) * h + index] = updates_ptr[idx] + return ib.get() + + +def gen_ir_4d(data, indices, updates, axis, out): + """Generate scatter ir for 4d inputs + + Parameters + ---------- + data : tir.Tensor + The input data to the operator. + + indices : tir.Tensor + The index locations to update. + + updates : tir.Tensor + The values to update. + + axis : int + The axis to scatter on + + out : tir.Tensor + The output tensor. + + Returns + ------- + ret : tir + The computational ir. 
+ """ + warp_size = tvm.target.Target.current(False).thread_warp_size + + n = data.shape[0] + c = data.shape[1] + h = data.shape[2] + w = data.shape[3] + + ib = tvm.tir.ir_builder.create() + + out_ptr = ib.buffer_ptr(out) + data_ptr = ib.buffer_ptr(data) + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", n) + j = te.thread_axis("blockIdx.y") + ib.scope_attr(j, "thread_extent", c) + k = te.thread_axis("blockIdx.z") + ib.scope_attr(k, "thread_extent", h) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + with ib.for_range(0, ceil_div(w, warp_size), name="l") as l_: + l = l_ * warp_size + tx + with ib.if_scope(l < w): + idx = ((i * c + j) * h + k) * w + l + out_ptr[idx] = data_ptr[idx] + + indices_ptr = ib.buffer_ptr(indices) + updates_ptr = ib.buffer_ptr(updates) + ni = indices.shape[0] + ci = indices.shape[1] + hi = indices.shape[2] + wi = indices.shape[3] + + if axis == 0: + with ib.new_scope(): + j = te.thread_axis("blockIdx.y") + ib.scope_attr(j, "thread_extent", ci) + k = te.thread_axis("blockIdx.z") + ib.scope_attr(k, "thread_extent", hi) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + with ib.for_range(0, ni, name="i") as i: + with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_: + l = l_ * warp_size + tx + with ib.if_scope(l < wi): + idx = ((i * ci + j) * hi + k) * wi + l + index = indices_ptr[idx] + with ib.if_scope(index < 0): + out_ptr[(((index + n) * c + j) * h + k) * w + l] = updates_ptr[idx] + with ib.else_scope(): + out_ptr[((index * c + j) * h + k) * w + l] = updates_ptr[idx] + elif axis == 1: + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", ni) + k = te.thread_axis("blockIdx.z") + ib.scope_attr(k, "thread_extent", hi) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + with ib.for_range(0, ci, name="j") as j: + with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_: + l = l_ * warp_size + tx + with ib.if_scope(l < wi): + idx = ((i * ci + j) * hi + k) * wi + l + index = indices_ptr[idx] + with ib.if_scope(index < 0): + out_ptr[((i * c + (index + c)) * h + k) * w + l] = updates_ptr[idx] + with ib.else_scope(): + out_ptr[((i * c + index) * h + k) * w + l] = updates_ptr[idx] + elif axis == 2: + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", ni) + j = te.thread_axis("blockIdx.y") + ib.scope_attr(j, "thread_extent", ci) + tx = te.thread_axis("threadIdx.x") + ib.scope_attr(tx, "thread_extent", warp_size) + with ib.for_range(0, hi, name="k") as k: + with ib.for_range(0, ceil_div(wi, warp_size), name="l") as l_: + l = l_ * warp_size + tx + with ib.if_scope(l < wi): + idx = ((i * ci + j) * hi + k) * wi + l + index = indices_ptr[idx] + with ib.if_scope(index < 0): + out_ptr[((i * c + j) * h + (index + h)) * w + l] = updates_ptr[idx] + with ib.else_scope(): + out_ptr[((i * c + j) * h + index) * w + l] = updates_ptr[idx] + else: + with ib.new_scope(): + i = te.thread_axis("blockIdx.x") + ib.scope_attr(i, "thread_extent", ni) + j = te.thread_axis("blockIdx.y") + ib.scope_attr(j, "thread_extent", ci) + k = te.thread_axis("blockIdx.z") + ib.scope_attr(k, "thread_extent", hi) + with ib.for_range(0, wi, name="l") as l: + idx = ((i * ci + j) * hi + k) * wi + l + index = indices_ptr[idx] + with ib.if_scope(index < 0): + out_ptr[((i * c + j) * h + k) * w + (index + w)] = updates_ptr[idx] + with ib.else_scope(): + out_ptr[((i * c + j) * h + 
+def scatter(data, indices, updates, axis=0):
+    """Update data at positions defined by indices with values in updates
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input data to the operator.
+
+    indices : tvm.te.Tensor
+        The index locations to update.
+
+    updates : tvm.te.Tensor
+        The values to update.
+
+    axis : int
+        The axis to scatter on
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The computed result.
+    """
+    if axis < 0:
+        axis += len(data.shape)
+    assert axis >= 0
+    assert axis < len(data.shape)
+
+    rank = len(data.shape)
+    assert 1 <= rank <= 4, "scatter only supports 1-4 dimensions"
+
+    ir_funcs = {
+        1: gen_ir_1d,
+        2: gen_ir_2d,
+        3: gen_ir_3d,
+        4: gen_ir_4d,
+    }
+
+    out_shape = data.shape
+    out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
+    out = te.extern(
+        [out_shape],
+        [data, indices, updates],
+        lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0]),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="scatter_gpu",
+        tag="scatter_gpu",
+    )
+
+    return out
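With the strategy registered, the kernel is reachable from Relay. A hypothetical end-to-end sketch, closely following the tests below; it assumes a CUDA-enabled TVM build with a visible GPU:

    import numpy as np
    import tvm
    from tvm import relay

    d = relay.var("d", relay.TensorType((10,), "float32"))
    i = relay.var("i", relay.TensorType((10,), "int64"))
    u = relay.var("u", relay.TensorType((10,), "float32"))
    func = relay.Function([d, i, u], relay.op.scatter(d, i, u, axis=0))

    data_np = np.random.uniform(size=(10,)).astype("float32")
    updates_np = np.random.uniform(size=(10,)).astype("float32")
    indices_np = np.random.randint(-10, 9, (10,)).astype("int64")

    intrp = relay.create_executor("graph", ctx=tvm.gpu(0), target="cuda")
    op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
    print(op_res.asnumpy())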
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index d18a12f20fa5..e636fe3f0037 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -910,6 +910,7 @@ def verify_reverse_sequence(x_data, seq_lengths, batch_axis, seq_axis, ref_res):
     )
 
 
+@tvm.testing.uses_gpu
 def test_scatter():
     def ref_scatter(data, indices, updates, axis=0):
         idx = np.indices(indices.shape).reshape(indices.ndim, -1)
@@ -935,13 +936,34 @@ def verify_scatter(dshape, ishape, axis=0):
         indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")
 
         ref_res = ref_scatter(data_np, indices_np, updates_np, axis)
-        # TODO(mbrookhart): expand testing when adding more backend schedules
-        for target, ctx in [("llvm", tvm.cpu())]:
+
+        for target, ctx in tvm.testing.enabled_targets():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
+    def verify_dynamic_scatter(dshape, ishape, axis=0):
+        d = relay.var("d", relay.TensorType([relay.Any() for i in range(len(dshape))], "float32"))
+        i = relay.var("i", relay.TensorType([relay.Any() for i in range(len(ishape))], "int64"))
+        u = relay.var("u", relay.TensorType([relay.Any() for i in range(len(ishape))], "float32"))
+        z = relay.op.scatter(d, i, u, axis)
+
+        func = relay.Function([d, i, u], z)
+
+        data_np = np.random.uniform(size=dshape).astype("float32")
+        updates_np = np.random.uniform(size=ishape).astype("float32")
+        indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64")
+
+        ref_res = ref_scatter(data_np, indices_np, updates_np, axis)
+
+        for target, ctx in tvm.testing.enabled_targets():
+            for kind in ["vm", "debug"]:
+                mod = tvm.ir.IRModule.from_expr(func)
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(data_np, indices_np, updates_np)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
     verify_scatter((10,), (10,), 0)
     verify_scatter((10, 5), (10, 5), -2)
     verify_scatter((10, 5), (10, 5), -1)
@@ -950,11 +972,26 @@ def verify_scatter(dshape, ishape, axis=0):
     verify_scatter((2, 3, 4), (1, 3, 4), 0)
     verify_scatter((2, 3, 4), (2, 1, 4), 1)
     verify_scatter((2, 3, 4), (2, 3, 1), 2)
+    verify_scatter((4, 2, 1), (1, 1, 1), 0)
     verify_scatter((2, 3, 4, 5), (1, 3, 4, 5), 0)
     verify_scatter((6, 3, 4, 5), (2, 3, 4, 5), 1)
     verify_scatter((2, 3, 8, 5), (2, 3, 1, 1), 2)
     verify_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)
 
+    verify_dynamic_scatter((10,), (10,), 0)
+    verify_dynamic_scatter((10, 5), (10, 5), -2)
+    verify_dynamic_scatter((10, 5), (10, 5), -1)
+    verify_dynamic_scatter((10, 5), (3, 5), 0)
+    verify_dynamic_scatter((12, 4), (7, 2), 1)
+    verify_dynamic_scatter((2, 3, 4), (1, 3, 4), 0)
+    verify_dynamic_scatter((2, 3, 4), (2, 1, 4), 1)
+    verify_dynamic_scatter((2, 3, 4), (2, 3, 1), 2)
+    verify_dynamic_scatter((4, 2, 1), (1, 1, 1), 0)
+    verify_dynamic_scatter((2, 3, 4, 5), (1, 3, 4, 5), 0)
+    verify_dynamic_scatter((6, 3, 4, 5), (2, 3, 4, 5), 1)
+    verify_dynamic_scatter((2, 3, 8, 5), (2, 3, 1, 1), 2)
+    verify_dynamic_scatter((16, 16, 4, 5), (16, 16, 4, 5), 3)
+
 
 def test_scatter_add():
     def ref_scatter_add(data, indices, updates, axis=0):