
Commit bf80b02

masahi authored and masa committed
[TOPI] GPU scatter 1D via sorting based approach (apache#7056)
* add thrust stable sort
* rename
* scatter via sort working
* correctly handles negative indices
* clean up, add some comments
* add doc string
* remove scatter benchmark stuff
* add more doc
* fix typo
* lint fix
* silence lint
* fix py format
* check for thrust availability before test

Co-authored-by: masa <masa@pop-os.localdomain>
1 parent 82a8d31 commit bf80b02

File tree

5 files changed: +271 -2 lines changed


cmake/modules/CUDA.cmake

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ if(USE_CUDA)
     message(STATUS "Build with Thrust support")
     cmake_minimum_required(VERSION 3.13) # to compile CUDA code
     enable_language(CUDA)
+    set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --extended-lambda")
     file(GLOB CONTRIB_THRUST_SRC src/runtime/contrib/thrust/*.cu)
     list(APPEND RUNTIME_SRCS ${CONTRIB_THRUST_SRC})
   endif(USE_THRUST)

python/tvm/topi/cuda/scatter.py

Lines changed: 105 additions & 1 deletion
@@ -20,6 +20,7 @@
 from tvm import te
 from ..scatter import _verify_scatter_nd_inputs
 from .nms import atomic_add
+from .sort import stable_sort_by_key_thrust, is_thrust_available


 def ceil_div(a, b):
@@ -416,6 +417,97 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func):
     return ib.get()


+def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
+    """Generate scatter ir for 1d inputs, using a sorting based approach.
+    By sorting indices and comparing two neighboring indices, we can tell which
+    elements in the indices tensor can scatter their update values into the output.
+    Sorting of indices, and sorting of updates with respect to indices, can be done
+    at the same time by thrust's sort_by_key function. It is important that sorting
+    be done in a "stable" way via stable_sort, to guarantee deterministic output.
+
+    Parameters
+    ----------
+    data : tir.Tensor
+        The input data to the operator.
+
+    indices_sorted : tir.Tensor
+        The sorted index locations to update.
+
+    updates_sorted : tir.Tensor
+        The values to update, sorted by indices.
+
+    axis : int
+        The axis to scatter on. It must be 0 for this function.
+
+    out : tir.Tensor
+        The output tensor.
+
+    Returns
+    -------
+    ret : tir
+        The computational ir.
+    """
+    assert axis == 0
+    n = data.shape[0]
+
+    ib = tvm.tir.ir_builder.create()
+
+    out_ptr = ib.buffer_ptr(out)
+    data_ptr = ib.buffer_ptr(data)
+
+    max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
+    nthread_tx = max_threads
+
+    with ib.new_scope():
+        nthread_bx = ceil_div(n, nthread_tx)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * nthread_tx + tx
+        with ib.if_scope(tid < n):
+            out_ptr[tid] = data_ptr[tid]
+
+    indices_ptr = ib.buffer_ptr(indices_sorted)
+    updates_ptr = ib.buffer_ptr(updates_sorted)
+
+    ni = indices_sorted.shape[0]
+
+    def do_update(ib, index, update):
+        with ib.if_scope(index < 0):
+            out_ptr[index + n] = update
+        with ib.else_scope():
+            out_ptr[index] = update
+
+    with ib.new_scope():
+        nthread_bx = ceil_div(ni, nthread_tx)
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+        ib.scope_attr(tx, "thread_extent", nthread_tx)
+        ib.scope_attr(bx, "thread_extent", nthread_bx)
+        tid = bx * nthread_tx + tx
+
+        with ib.if_scope(tid == ni - 1):
+            # The last element can always update.
+            index = indices_ptr[tid]
+            update = updates_ptr[tid]
+            do_update(ib, index, update)
+
+        with ib.else_scope():
+            with ib.if_scope(tid < ni - 1):
+                index = indices_ptr[tid]
+                index_next = indices_ptr[tid + 1]
+
+                # If the next neighbor in the sorted list of indices has a different index,
+                # that means thread tid is the last one to have this index.
+                # This thread can update the output.
+                with ib.if_scope(index != index_next):
+                    update = updates_ptr[tid]
+                    do_update(ib, index, update)
+
+    return ib.get()
+
+
 def scatter(data, indices, updates, axis=0):
     """Update data at positions defined by indices with values in updates
@@ -458,9 +550,21 @@ def update_func(dst_ptr, dst_index, update):

     out_shape = data.shape
     out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
+
+    in_bufs = [data]
+
+    if rank == 1 and is_thrust_available():
+        ir_funcs[1] = gen_scatter_1d_thrust
+        indices_sorted, updates_sorted = stable_sort_by_key_thrust(
+            indices, updates, for_scatter=True
+        )
+        in_bufs += [indices_sorted, updates_sorted]
+    else:
+        in_bufs += [indices, updates]
+
     out = te.extern(
         [out_shape],
-        [data, indices, updates],
+        in_bufs,
         lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
         dtype=data.dtype,
         out_buffers=[out_buf],
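
To illustrate the semantics that gen_scatter_1d_thrust implements, here is a minimal NumPy sketch (illustration only, not part of the commit; the helper name scatter_1d_via_sort_reference is made up here). Indices and updates are stably sorted by index, only the last occurrence of each index performs its write, and negative indices wrap around by the output length:

import numpy as np

def scatter_1d_via_sort_reference(data, indices, updates):
    # NumPy model of the sort-based 1D scatter kernel above (hypothetical helper).
    n = data.shape[0]
    out = data.copy()
    pos = np.where(indices < 0, indices + n, indices)  # wrap negative indices
    order = np.argsort(pos, kind="stable")             # stable sort_by_key
    pos_sorted, upd_sorted = pos[order], updates[order]
    for i in range(len(pos_sorted)):
        # Only the last occurrence of each index writes, so duplicate indices
        # resolve deterministically to the update that appeared last in the
        # original order (this is what the stable sort guarantees).
        if i == len(pos_sorted) - 1 or pos_sorted[i] != pos_sorted[i + 1]:
            out[pos_sorted[i]] = upd_sorted[i]
    return out

data = np.zeros(5, dtype=np.float32)
indices = np.array([1, -1, 1], dtype=np.int32)
updates = np.array([10.0, 20.0, 30.0], dtype=np.float32)
print(scatter_1d_via_sort_reference(data, indices, updates))  # [0. 30. 0. 0. 20.]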

python/tvm/topi/cuda/sort.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
18-
"""Argsort operator """
18+
"""Sort related operators """
1919
import tvm
2020
from tvm import te
21+
from tvm._ffi import get_global_func
2122

2223
from .injective import schedule_injective_from_existing
2324
from ..math import identity
@@ -597,3 +598,59 @@ def schedule_topk(outs):
         The computation schedule for the op.
     """
     return _schedule_sort(outs)
+
+
+def stable_sort_by_key_thrust(keys, values, for_scatter=False):
+    """Sort values with respect to keys using thrust.
+    Both keys and values will be sorted and returned.
+    Sorting is done via stable sort, so relative ordering among
+    ties is preserved.
+
+    Parameters
+    ----------
+    keys: tvm.te.Tensor
+        The 1D input keys.
+
+    values : tvm.te.Tensor
+        The 1D input values.
+
+    for_scatter: bool, optional
+        If True, negative keys are interpreted as negative indices.
+        Before sorting, negative indices are converted to corresponding positive indices.
+        The output keys (indices) are all positive.
+        This option is introduced to optimize the scatter implementation.
+
+    Returns
+    -------
+    keys_sorted : tvm.te.Tensor
+        The sorted keys
+
+    values_sorted : tvm.te.Tensor
+        The values sorted with respect to the keys
+    """
+    keys_buf = tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8)
+    values_buf = tvm.tir.decl_buffer(values.shape, values.dtype, "values_buf", data_alignment=8)
+    out_bufs = [
+        tvm.tir.decl_buffer(keys.shape, keys.dtype, "keys_buf", data_alignment=8),
+        tvm.tir.decl_buffer(keys.shape, values.dtype, "values_buf", data_alignment=8),
+    ]
+    out = te.extern(
+        [keys.shape, values.shape],
+        [keys, values],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.thrust.stable_sort_by_key", ins[0], ins[1], outs[0], outs[1], for_scatter
+        ),
+        in_buffers=[keys_buf, values_buf],
+        out_buffers=out_bufs,
+        dtype=[keys.dtype, values.dtype],
+        name="stable_sort_by_key",
+        tag="stable_sort_by_key",
+    )
+    return out[0], out[1]
+
+
+def is_thrust_available():
+    """
+    Test if thrust based sorting ops are available.
+    """
+    return get_global_func("tvm.contrib.thrust.sort", allow_missing=True) is not None
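
A hedged usage sketch of stable_sort_by_key_thrust with for_scatter=True (illustration only; it assumes a CUDA-enabled build with USE_THRUST=ON and mirrors the test added in tests/python/contrib/test_sort.py further below). With for_scatter=True, negative keys are shifted by the key-array length before the stable sort, so the returned keys are all non-negative:

import numpy as np
import tvm
from tvm import te
from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available

if is_thrust_available():
    keys = te.placeholder((6,), name="keys", dtype="int32")
    values = te.placeholder((6,), name="values", dtype="int32")
    keys_out, values_out = stable_sort_by_key_thrust(keys, values, for_scatter=True)

    s = te.create_schedule([keys_out.op, values_out.op])
    f = tvm.build(s, [keys, values, keys_out, values_out], "cuda")

    ctx = tvm.gpu(0)
    keys_np = np.array([1, -1, 2, -2, 2, 7], dtype=np.int32)  # negative keys become key + 6
    values_np = np.arange(6, dtype=np.int32)
    keys_nd = tvm.nd.array(keys_np, ctx)
    values_nd = tvm.nd.array(values_np, ctx)
    keys_sorted = tvm.nd.array(np.zeros(6, dtype=np.int32), ctx)
    values_sorted = tvm.nd.array(np.zeros(6, dtype=np.int32), ctx)
    f(keys_nd, values_nd, keys_sorted, values_sorted)
    # Expected: keys_sorted -> [1 2 2 4 5 7], values_sorted -> [0 2 4 3 1 5]
    print(keys_sorted.asnumpy(), values_sorted.asnumpy())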

src/runtime/contrib/thrust/thrust.cu

Lines changed: 73 additions & 0 deletions
@@ -163,5 +163,78 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
   thrust_sort_common(input, values_out, indices_out, is_ascend, get_sort_len,
                      data_dtype, out_dtype);
 });
+
+template<typename KeyType, typename ValueType>
+void thrust_stable_sort_by_key(DLTensor* keys_in,
+                               DLTensor* values_in,
+                               DLTensor* keys_out,
+                               DLTensor* values_out,
+                               bool for_scatter) {
+  const auto size = keys_in->shape[0];
+  thrust::device_ptr<KeyType> keys_in_ptr(static_cast<KeyType *>(keys_in->data));
+  thrust::device_ptr<ValueType> values_in_ptr(static_cast<ValueType *>(values_in->data));
+  thrust::device_ptr<KeyType> keys_out_ptr(static_cast<KeyType *>(keys_out->data));
+  thrust::device_ptr<ValueType> values_out_ptr(static_cast<ValueType *>(values_out->data));
+
+  if (for_scatter) {
+    thrust::transform(keys_in_ptr, keys_in_ptr + size, keys_out_ptr, [size] __device__(KeyType k) {
+      if (k < 0) return k + static_cast<KeyType>(size);
+      return k;
+    });
+  } else {
+    thrust::copy(keys_in_ptr, keys_in_ptr + size, keys_out_ptr);
+  }
+  thrust::copy(values_in_ptr, values_in_ptr + size, values_out_ptr);
+
+  thrust::stable_sort_by_key(keys_out_ptr, keys_out_ptr + size, values_out_ptr);
+}
+
+TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  ICHECK_GE(args.num_args, 5);
+  DLTensor* keys_in = args[0];
+  DLTensor* values_in = args[1];
+  DLTensor* keys_out = args[2];
+  DLTensor* values_out = args[3];
+  bool for_scatter = args[4];
+
+  auto key_dtype = DLDataType2String(keys_in->dtype);
+  auto value_dtype = DLDataType2String(values_in->dtype);
+
+  if (key_dtype == "int32") {
+    if (value_dtype == "int32") {
+      thrust_stable_sort_by_key<int, int>(keys_in, values_in, keys_out, values_out,
+                                          for_scatter);
+    } else if (value_dtype == "float32") {
+      thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
+                                            for_scatter);
+    } else {
+      LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
+    }
+  } else if (key_dtype == "int64") {
+    if (value_dtype == "int32") {
+      thrust_stable_sort_by_key<int64_t, int>(keys_in, values_in, keys_out, values_out,
+                                              for_scatter);
+    } else if (value_dtype == "float32") {
+      thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
+                                                for_scatter);
+    } else {
+      LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
+    }
+  } else if (key_dtype == "float32") {
+    if (value_dtype == "int32") {
+      thrust_stable_sort_by_key<float, int>(keys_in, values_in, keys_out, values_out,
+                                            for_scatter);
+    } else if (value_dtype == "float32") {
+      thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
+                                              for_scatter);
+    } else {
+      LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
+    }
+  } else {
+    LOG(FATAL) << "Unsupported key dtype: " << key_dtype;
+  }
+});
+
 } // namespace contrib
 } // namespace tvm

tests/python/contrib/test_sort.py

Lines changed: 34 additions & 0 deletions
@@ -17,6 +17,7 @@
 import tvm
 import tvm.testing
 from tvm import te
+from tvm.topi.cuda import stable_sort_by_key_thrust, is_thrust_available
 import numpy as np


@@ -90,6 +91,39 @@ def test_sort_np():
     tvm.testing.assert_allclose(c.asnumpy(), np_out, rtol=1e-5)


+def test_thrust_stable_sort_by_key():
+    if not is_thrust_available():
+        print("skip because thrust is not enabled...")
+        return
+
+    size = 6
+    keys = te.placeholder((size,), name="keys", dtype="int32")
+    values = te.placeholder((size,), name="values", dtype="int32")
+
+    keys_out, values_out = stable_sort_by_key_thrust(keys, values)
+
+    ctx = tvm.gpu(0)
+    target = "cuda"
+    s = te.create_schedule([keys_out.op, values_out.op])
+    f = tvm.build(s, [keys, values, keys_out, values_out], target)
+
+    keys_np = np.array([1, 4, 2, 8, 2, 7], np.int32)
+    values_np = np.random.randint(0, 10, size=(size,)).astype(np.int32)
+    keys_np_out = np.zeros(keys_np.shape, np.int32)
+    values_np_out = np.zeros(values_np.shape, np.int32)
+    keys_in = tvm.nd.array(keys_np, ctx)
+    values_in = tvm.nd.array(values_np, ctx)
+    keys_out = tvm.nd.array(keys_np_out, ctx)
+    values_out = tvm.nd.array(values_np_out, ctx)
+    f(keys_in, values_in, keys_out, values_out)
+
+    ref_keys_out = np.sort(keys_np)
+    ref_values_out = np.array([values_np[i] for i in np.argsort(keys_np)])
+    tvm.testing.assert_allclose(keys_out.asnumpy(), ref_keys_out, rtol=1e-5)
+    tvm.testing.assert_allclose(values_out.asnumpy(), ref_values_out, rtol=1e-5)
+
+
 if __name__ == "__main__":
     test_sort()
     test_sort_np()
+    test_thrust_stable_sort_by_key()
