Skip to content

Commit

Permalink
Scatter on CUDA (apache#6533)
Browse files Browse the repository at this point in the history
* working cuda scatter

fix lint

fix pylint again

* cuda scatter with threading

* add dynamic shape tests

* remove unused variable
  • Loading branch information
Matthew Brookhart authored and Trevor Morris committed Dec 4, 2020
1 parent fe60133 commit d75a7ca
Show file tree
Hide file tree
Showing 6 changed files with 515 additions and 8 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def compute_scatter(attrs, inputs, output_type):
return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)]


_reg.register_schedule("scatter", strategy.schedule_scatter)
_reg.register_strategy("scatter", strategy.scatter_strategy)

# scatter_add
@_reg.register_compute("scatter_add")
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,19 @@ def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@scatter_strategy.register(["cuda", "gpu"])
def scatter_cuda(attrs, inputs, out_type, target):
    """scatter cuda strategy"""
    # NOTE(review): the original docstring said "sparse dense cuda strategy",
    # a copy-paste leftover from sparse_dense_padded_strategy_cuda above;
    # this function registers the CUDA implementation of relay `scatter`.
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_scatter(topi.cuda.scatter),
        # scatter is implemented as an extern op, so the generic extern
        # schedule is used rather than a scatter-specific one.
        wrap_topi_schedule(topi.generic.schedule_extern),
        name="scatter.cuda",
        # plevel=10 ranks this implementation above the generic fallback.
        plevel=10,
    )
    return strategy


@argsort_strategy.register(["cuda", "gpu"])
def argsort_strategy_cuda(attrs, inputs, out_type, target):
"""argsort cuda strategy"""
Expand Down
23 changes: 18 additions & 5 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,11 +1032,24 @@ def schedule_argwhere(attrs, outs, target):


# scatter
@generic_func
def schedule_scatter(attrs, outs, target):
"""schedule scatter"""
with target:
return topi.generic.schedule_scatter(outs)
@override_native_generic_func("scatter_strategy")
def scatter_strategy(attrs, outs, out_type, target):
    """Generic strategy for the relay `scatter` op."""
    generic_strategy = _op.OpStrategy()
    generic_strategy.add_implementation(
        wrap_compute_scatter(topi.scatter),
        wrap_topi_schedule(topi.generic.schedule_scatter),
        name="scatter.generic",
    )
    return generic_strategy


def wrap_compute_scatter(topi_compute):
    """Wrap scatter topi compute"""

    def _compute_scatter(attrs, inputs, _):
        # inputs are (data, indices, updates); axis comes from the op attrs.
        data, indices, updates = inputs[0], inputs[1], inputs[2]
        return [topi_compute(data, indices, updates, axis=attrs.axis)]

    return _compute_scatter


# scatter_add
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .ssd import *
from .nms import get_valid_counts, non_max_suppression
from .rcnn import *
from .scatter import *
from .sort import *
from .conv2d_nhwc_tensorcore import *
from .conv3d_ndhwc_tensorcore import *
Expand Down
Loading

0 comments on commit d75a7ca

Please sign in to comment.