Skip to content

Commit eb2e0dc

Browse files
Matthew Brookhart and Trevor Morris
authored and committed
Scatter on Cuda (apache#6533)
* Working CUDA scatter; fix lint; fix pylint again. * CUDA scatter with threading. * Add dynamic shape tests. * Remove unused variable.
1 parent 6f0d417 commit eb2e0dc

File tree

6 files changed

+515
-8
lines changed

6 files changed

+515
-8
lines changed

python/tvm/relay/op/_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def compute_scatter(attrs, inputs, output_type):
104104
return [topi.scatter(inputs[0], inputs[1], inputs[2], attrs.axis)]
105105

106106

107-
_reg.register_schedule("scatter", strategy.schedule_scatter)
107+
_reg.register_strategy("scatter", strategy.scatter_strategy)
108108

109109
# scatter_add
110110
@_reg.register_compute("scatter_add")

python/tvm/relay/op/strategy/cuda.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,19 @@ def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
651651
return strategy
652652

653653

654+
@scatter_strategy.register(["cuda", "gpu"])
def scatter_cuda(attrs, inputs, out_type, target):
    """scatter cuda strategy"""
    # NOTE: the original docstring read "sparse dense cuda strategy" — a
    # copy-paste error from sparse_dense_padded_strategy_cuda; this function
    # registers the CUDA/GPU implementation of scatter.
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_scatter(topi.cuda.scatter),
        # scatter on CUDA is implemented as an extern op, so the generic
        # extern schedule is used rather than a scatter-specific one.
        wrap_topi_schedule(topi.generic.schedule_extern),
        name="scatter.cuda",
        plevel=10,  # default priority; can be overridden by higher-plevel impls
    )
    return strategy
665+
666+
654667
@argsort_strategy.register(["cuda", "gpu"])
655668
def argsort_strategy_cuda(attrs, inputs, out_type, target):
656669
"""argsort cuda strategy"""

python/tvm/relay/op/strategy/generic.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,11 +1032,24 @@ def schedule_argwhere(attrs, outs, target):
10321032

10331033

10341034
# scatter
1035-
@generic_func
1036-
def schedule_scatter(attrs, outs, target):
1037-
"""schedule scatter"""
1038-
with target:
1039-
return topi.generic.schedule_scatter(outs)
1035+
@override_native_generic_func("scatter_strategy")
def scatter_strategy(attrs, outs, out_type, target):
    """scatter generic strategy"""
    # Docstring added for consistency with the sibling strategy functions in
    # this file (and to satisfy pylint's missing-docstring check).
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_scatter(topi.scatter),
        wrap_topi_schedule(topi.generic.schedule_scatter),
        name="scatter.generic",
    )
    return strategy
1044+
1045+
1046+
def wrap_compute_scatter(topi_compute):
    """Wrap scatter topi compute"""

    def _compute_scatter(attrs, inputs, _):
        # inputs is (data, indices, updates); axis comes from the op attrs.
        data, indices, updates = inputs[0], inputs[1], inputs[2]
        result = topi_compute(data, indices, updates, axis=attrs.axis)
        return [result]

    return _compute_scatter
10401053

10411054

10421055
# scatter_add

python/tvm/topi/cuda/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from .ssd import *
4747
from .nms import get_valid_counts, non_max_suppression
4848
from .rcnn import *
49+
from .scatter import *
4950
from .sort import *
5051
from .conv2d_nhwc_tensorcore import *
5152
from .conv3d_ndhwc_tensorcore import *

0 commit comments

Comments (0)