Skip to content

Commit 9cf0245

Browse files
authored
[Relay, TOPI] Add searchsorted op (#9184)
* Add relay definition * 1D cpu test working * multi dim working * gpu version working * check shape in type rel * support side * use target specific max threads * add relay boilerplate * relay test working * cleanup topi test * fix test * add torch converter * handle other cases * more topi test * support torch bucketize * update doc * fix tests * fix lint * rebase fix * make the test case smaller * add tests for edge cases * replace "side" attribute with boolean "right" * add more description to binary_search IR gen params * return index from binary_search rather than update inplace * remove unused argument * format fix
1 parent 3f064b6 commit 9cf0245

File tree

19 files changed

+619
-2
lines changed

19 files changed

+619
-2
lines changed

include/tvm/relay/attrs/algorithm.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,22 @@ struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
7676
}
7777
};
7878

79+
// Attributes of the searchsorted operator, declared under the key "relay.attrs.SearchSortedAttrs".
struct SearchSortedAttrs : public tvm::AttrsNode<SearchSortedAttrs> {
80+
// Tie-breaking flag for values that land exactly on a sorted element; defaults to false.
bool right;
81+
// Data type of the output indices; defaults to DataType::Int(32).
DataType dtype;
82+
83+
TVM_DECLARE_ATTRS(SearchSortedAttrs, "relay.attrs.SearchSortedAttrs") {
84+
// `right`: false selects the first suitable index, true selects the last one.
TVM_ATTR_FIELD(right).set_default(false).describe(
85+
"Controls which index is returned if a value lands exactly on one of sorted values. If "
86+
" false, the index of the first suitable location found is given. If true, return the "
87+
"last such index. If there is no suitable index, return either 0 or N (where N is the "
88+
"size of the innermost dimension).");
89+
// `dtype`: element type used for the returned indices.
TVM_ATTR_FIELD(dtype)
90+
.set_default(DataType::Int(32))
91+
.describe("Data type of the output indices.");
92+
}
93+
};
94+
7995
} // namespace relay
8096
} // namespace tvm
8197
#endif // TVM_RELAY_ATTRS_ALGORITHM_H_

python/tvm/relay/frontend/pytorch.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2774,6 +2774,26 @@ def all_any_common(self, op, inputs, input_types):
27742774
inp = inputs[0]
27752775
return op(inp, axis=dim, keepdims=keepdim)
27762776

2777+
def searchsorted_common(self, sorted_sequence, values, out_int32, right):
    """Emit ``_op.searchsorted`` for a PyTorch searchsorted-style call.

    Parameters
    ----------
    sorted_sequence
        Expression holding the sorted sequence to search in.

    values
        Expression holding the values to look up. When its inferred shape
        is empty (a scalar), it is expanded to one dimension for the op and
        the result is squeezed back afterwards.

    out_int32
        Truthy selects "int32" output indices, otherwise "int64".

    right
        Forwarded unchanged as the ``right`` argument of ``_op.searchsorted``.

    Returns
    -------
    The expression produced by ``_op.searchsorted`` (squeezed for scalar
    ``values``).
    """
    index_dtype = "int32" if out_int32 else "int64"

    # Scalar queries are temporarily given a leading axis for the op.
    scalar_query = len(_infer_shape(values)) == 0
    query = _op.expand_dims(values, 0) if scalar_query else values

    result = _op.searchsorted(sorted_sequence, query, right=right, dtype=index_dtype)

    # Undo the temporary axis so a scalar input yields a scalar output.
    return _op.squeeze(result) if scalar_query else result
2790+
2791+
def searchsorted(self, inputs, input_types):
    """Convert ``aten::searchsorted``.

    Parameters
    ----------
    inputs : list
        Positional inputs of the node. The first four are
        ``(sorted_sequence, values, out_int32, right)``. Newer PyTorch
        schemas may append ``side`` and ``sorter``; without trimming, those
        extra entries would make the call to ``searchsorted_common`` fail
        with a TypeError.

    input_types : list
        Unused here.

    Returns
    -------
    The result of ``searchsorted_common``.

    Raises
    ------
    NotImplementedError
        If a non-None ``sorter`` input is supplied.
    """
    # Only the first four inputs map onto searchsorted_common's parameters.
    args = list(inputs[:4])

    # `side` ("left"/"right") is the string form of `right`; honour it when given.
    if len(inputs) > 4 and isinstance(inputs[4], str):
        args[3] = inputs[4] == "right"

    # A sorter (indices that sort an unsorted sequence) cannot be expressed
    # with the plain searchsorted op, so reject it instead of ignoring it.
    if len(inputs) > 5 and inputs[5] is not None:
        raise NotImplementedError("aten::searchsorted with a sorter argument is not supported")

    return self.searchsorted_common(*args)
2793+
2794+
def bucketize(self, inputs, input_types):
    """Convert ``aten::bucketize`` by delegating to ``searchsorted_common``.

    The second input is passed as the sorted sequence and the first input as
    the search values; the third and fourth inputs are forwarded as
    ``out_int32`` and ``right``. ``input_types`` is unused.
    """
    boundaries = inputs[1]
    values = inputs[0]
    out_int32 = inputs[2]
    right = inputs[3]
    return self.searchsorted_common(boundaries, values, out_int32, right)
2796+
27772797
# Operator mappings
27782798
def create_convert_map(self):
27792799
self.convert_map = {
@@ -2999,6 +3019,8 @@ def create_convert_map(self):
29993019
"aten::lstm": self.lstm,
30003020
"aten::all": functools.partial(self.all_any_common, _op.all),
30013021
"aten::any": functools.partial(self.all_any_common, _op.any),
3022+
"aten::searchsorted": self.searchsorted,
3023+
"aten::bucketize": self.bucketize,
30023024
}
30033025

30043026
def update_convert_map(self, custom_map):

python/tvm/relay/op/_algorithm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@
4141
register_strategy("topk", strategy.topk_strategy)
4242
register_pattern("topk", OpPattern.OPAQUE)
4343

44+
# searchsorted: bind the op to its strategy function and register it with the OPAQUE op pattern
45+
register_strategy("searchsorted", strategy.searchsorted_strategy)
46+
register_pattern("searchsorted", OpPattern.OPAQUE)
47+
4448

4549
@script
4650
def _topk_shape_func_input_shape(data_shape, k, axis):

python/tvm/relay/op/algorithm.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,37 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
115115
if ret_type == "both":
116116
return TupleWrapper(out, 2)
117117
return out
118+
119+
120+
def searchsorted(sorted_sequence, values, right=False, dtype="int32"):
    """Find the indices at which elements should be inserted to keep order.

    If `sorted_sequence` is N-dimensional, the innermost dimension of
    `values` is searched in the corresponding dimension of `sorted_sequence`.

    Parameters
    ----------
    sorted_sequence : relay.Expr
        N-D or 1-D Tensor whose innermost dimension contains a
        monotonically increasing sequence.

    values : relay.Expr
        N-D Tensor with the values to search for. If `sorted_sequence` is
        1-D, `values` may have any shape. Otherwise `sorted_sequence` and
        `values` must have the same rank, and their outer N-1 axes must
        have the same size.

    right : bool, optional
        Selects the index returned when a value lands exactly on one of the
        sorted values. If False, the index of the first suitable location
        is given; if True, the last such index. When there is no suitable
        index, either 0 or N is returned (N being the size of the innermost
        dimension).

    dtype : string, optional
        The data type of the output indices.

    Returns
    -------
    indices : relay.Expr
        Tensor with the same shape as `values`, holding the positions the
        elements of `values` would take if inserted in `sorted_sequence`.
    """
    indices = _make.searchsorted(sorted_sequence, values, right, dtype)
    return indices

python/tvm/relay/op/op_attrs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,11 @@ class TopkAttrs(Attrs):
564564
"""Attributes used in topk operators"""
565565

566566

567+
@tvm._ffi.register_object("relay.attrs.SearchSortedAttrs")
568+
class SearchSortedAttrs(Attrs):
569+
"""Attributes used in searchsorted operators.

Registered under the object key "relay.attrs.SearchSortedAttrs".
"""
570+
571+
567572
@tvm._ffi.register_object("relay.attrs.TupleGetItemAttrs")
568573
class TupleGetItemAttrs(Attrs):
569574
"""Attributes used in tuple item access operators"""

python/tvm/relay/op/strategy/cuda.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,18 @@ def topk_strategy_cuda(attrs, inputs, out_type, target):
10221022
return strategy
10231023

10241024

1025+
@searchsorted_strategy.register(["cuda", "gpu"])
def searchsorted_strategy_cuda(attrs, inputs, out_type, target):
    """searchsorted strategy registered for the "cuda" and "gpu" keys.

    Returns an OpStrategy with a single implementation named
    "searchsorted.cuda", pairing ``topi.cuda.searchsorted`` with
    ``topi.cuda.schedule_extern``.
    """
    cuda_strategy = _op.OpStrategy()
    compute = wrap_compute_searchsorted(topi.cuda.searchsorted)
    schedule = wrap_topi_schedule(topi.cuda.schedule_extern)
    cuda_strategy.add_implementation(compute, schedule, name="searchsorted.cuda")
    return cuda_strategy
1035+
1036+
10251037
@multibox_prior_strategy.register(["cuda", "gpu"])
10261038
def multibox_prior_strategy_cuda(attrs, inputs, out_type, target):
10271039
"""multibox_prior cuda strategy"""

python/tvm/relay/op/strategy/generic.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,6 +1002,31 @@ def topk_strategy(attrs, inputs, out_type, target):
10021002
return strategy
10031003

10041004

1005+
# searchsorted
1006+
def wrap_compute_searchsorted(topi_compute):
    """Adapt a topi searchsorted compute to the (attrs, inputs, out_type) form.

    Parameters
    ----------
    topi_compute : callable
        Called as ``topi_compute(inputs[0], inputs[1], attrs.right, attrs.dtype)``.

    Returns
    -------
    callable
        A function returning a one-element list with the compute result.
    """

    def _compute_searchsorted(attrs, inputs, out_type):
        # Read the op attributes first, then forward the two input tensors.
        side_is_right, index_dtype = attrs.right, attrs.dtype
        result = topi_compute(inputs[0], inputs[1], side_is_right, index_dtype)
        return [result]

    return _compute_searchsorted
1015+
1016+
1017+
# searchsorted_strategy
1018+
@override_native_generic_func("searchsorted_strategy")
def searchsorted_strategy(attrs, inputs, out_type, target):
    """Generic searchsorted strategy.

    Returns an OpStrategy with a single implementation named
    "searchsorted.generic", pairing ``topi.searchsorted`` with
    ``topi.generic.schedule_extern``.
    """
    generic_strategy = _op.OpStrategy()
    compute = wrap_compute_searchsorted(topi.searchsorted)
    schedule = wrap_topi_schedule(topi.generic.schedule_extern)
    generic_strategy.add_implementation(compute, schedule, name="searchsorted.generic")
    return generic_strategy
1028+
1029+
10051030
# multibox_prior
10061031
def wrap_compute_multibox_prior(topi_compute):
10071032
"""Wrap multibox_prior compute"""

python/tvm/topi/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from .scan import *
4646
from .einsum import *
4747
from .unique import *
48+
from .searchsorted import *
4849
from . import generic
4950
from . import nn
5051
from . import x86

python/tvm/topi/cuda/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,4 @@
5959
from .sparse_reshape import *
6060
from .transform import *
6161
from .unique import *
62+
from .searchsorted import *
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name
18+
"""searchsorted operator for GPU"""
19+
import tvm
20+
from tvm import te
21+
from .. import utils
22+
from ..searchsorted import binary_search
23+
24+
25+
def searchsorted(sorted_sequence, values, right=False, out_dtype="int64"):
    """Find indices where elements should be inserted to maintain order.
    If `sorted_sequence` is N-dimensional, the innermost dimension of
    `values` is searched in the corresponding dimension of `sorted_sequence`.

    Parameters
    ----------
    sorted_sequence : te.Tensor
        N-D or 1-D Tensor, containing monotonically increasing sequence
        on the innermost dimension.

    values : te.Tensor
        N-D Tensor containing the search values. When `sorted_sequence` is 1-D,
        the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence`
        and `values` must be the same, and outer N-1 axes must have the same size.

    right : bool, optional
        Controls which index is returned if a value lands exactly on one of sorted values. If
        False, the index of the first suitable location found is given. If true, return the
        last such index. If there is no suitable index, return either 0 or N (where N is the
        size of the innermost dimension).

    out_dtype : string, optional
        The data type of the output indices.

    Returns
    -------
    indices : te.Tensor
        Tensor with same shape as values, representing the indices of
        elements of `values` if they are inserted in `sorted_sequence`.
    """

    def ir(sorted_sequence, values, indices):
        ib = tvm.tir.ir_builder.create()
        sorted_sequence_shape = sorted_sequence.shape
        values_shape = values.shape
        # Total number of lookups: one per element of `values`.
        num_search = utils.prod(values_shape)
        # Length of each sorted sequence (innermost dimension).
        search_range = sorted_sequence_shape[-1]

        sorted_sequence = ib.buffer_ptr(sorted_sequence)
        values = ib.buffer_ptr(values)
        indices = ib.buffer_ptr(indices)

        # The current target must be set; its max_num_threads sizes each block.
        max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
        bx = te.thread_axis("blockIdx.x")
        tx = te.thread_axis("threadIdx.x")
        # Number of blocks is ceil(num_search / max_threads).
        ib.scope_attr(
            bx, "thread_extent", tvm.tir.indexdiv(num_search + max_threads - 1, max_threads)
        )
        ib.scope_attr(tx, "thread_extent", max_threads)
        tid = bx * max_threads + tx

        # Each thread handles at most one search value; surplus threads do nothing.
        with ib.if_scope(tid < num_search):
            if len(sorted_sequence_shape) == 1:
                # A 1-D sequence is shared by every search value.
                sequence_offset = 0
            else:
                # Pick the sequence row matching this value's outer position.
                sequence_id = tid // values_shape[-1]
                sequence_offset = sequence_id * search_range

            indices[tid] = binary_search(
                ib,
                sequence_offset,
                search_range,
                sorted_sequence,
                values[tid],
                right,
                out_dtype,
            )

        return ib.get()

    return te.extern(
        values.shape,
        [sorted_sequence, values],
        lambda ins, outs: ir(ins[0], ins[1], outs[0]),
        name="searchsorted",
        dtype=out_dtype,
    )

0 commit comments

Comments
 (0)