|
20 | 20 | from tvm import te
|
21 | 21 | from ..scatter import _verify_scatter_nd_inputs
|
22 | 22 | from .nms import atomic_add
|
| 23 | +from .sort import stable_sort_by_key_thrust, is_thrust_available |
23 | 24 |
|
24 | 25 |
|
25 | 26 | def ceil_div(a, b):
|
@@ -416,6 +417,97 @@ def gen_ir_4d(data, indices, updates, axis, out, update_func):
|
416 | 417 | return ib.get()
|
417 | 418 |
|
418 | 419 |
|
def gen_scatter_1d_thrust(data, indices_sorted, updates_sorted, axis, out, _):
    """Build the scatter IR for rank-1 inputs from pre-sorted indices.

    The generated code runs in two stages. The first copies ``data`` into
    ``out``. The second assigns one thread to every entry of
    ``indices_sorted``; a thread writes its update only if it handles the last
    entry, or if the following sorted index differs from its own. Within a run
    of equal indices, therefore, only the final entry of the run is written.

    The indices and the updates are expected to have been sorted together,
    keyed by index (e.g. with thrust's sort_by_key). That sort should be
    stable so that the produced output is deterministic.

    Parameters
    ----------
    data : tir.Tensor
        The input data to the operator.

    indices_sorted : tir.Tensor
        The index locations to update, in sorted order.

    updates_sorted : tir.Tensor
        The values to update, reordered to follow ``indices_sorted``.

    axis : int
        The axis to scatter on. It must be 0 for this function.

    out : tir.Tensor
        The output tensor.

    _ : object
        Ignored by this function; it only occupies the sixth positional slot.

    Returns
    -------
    ret : tir
        The computational ir.
    """
    assert axis == 0

    data_len = data.shape[0]
    builder = tvm.tir.ir_builder.create()

    dst = builder.buffer_ptr(out)
    src = builder.buffer_ptr(data)

    threads_per_block = int(tvm.target.Target.current(allow_none=False).max_num_threads)

    def launch(extent):
        # Bind threadIdx.x / blockIdx.x so that `extent` items are covered,
        # and hand back the flattened thread id.
        num_blocks = ceil_div(extent, threads_per_block)
        thread_x = te.thread_axis("threadIdx.x")
        block_x = te.thread_axis("blockIdx.x")
        builder.scope_attr(thread_x, "thread_extent", threads_per_block)
        builder.scope_attr(block_x, "thread_extent", num_blocks)
        return block_x * threads_per_block + thread_x

    # Stage 1: initialise the output with a copy of the input data.
    with builder.new_scope():
        gid = launch(data_len)
        with builder.if_scope(gid < data_len):
            dst[gid] = src[gid]

    sorted_idx = builder.buffer_ptr(indices_sorted)
    sorted_upd = builder.buffer_ptr(updates_sorted)

    num_updates = indices_sorted.shape[0]

    def write(position, value):
        # A negative position is offset by the data length before storing.
        with builder.if_scope(position < 0):
            dst[position + data_len] = value
        with builder.else_scope():
            dst[position] = value

    # Stage 2: one thread per sorted (index, update) pair.
    with builder.new_scope():
        gid = launch(num_updates)
        last = num_updates - 1

        with builder.if_scope(gid == last):
            # Nothing follows the final entry, so it is always written.
            write(sorted_idx[gid], sorted_upd[gid])

        with builder.else_scope():
            with builder.if_scope(gid < last):
                current = sorted_idx[gid]
                following = sorted_idx[gid + 1]

                # A differing neighbour means this entry closes the run of
                # equal indices, so this thread is the one that writes.
                with builder.if_scope(current != following):
                    write(current, sorted_upd[gid])

    return builder.get()
| 509 | + |
| 510 | + |
419 | 511 | def scatter(data, indices, updates, axis=0):
|
420 | 512 | """Update data at positions defined by indices with values in updates
|
421 | 513 |
|
@@ -458,9 +550,21 @@ def update_func(dst_ptr, dst_index, update):
|
458 | 550 |
|
459 | 551 | out_shape = data.shape
|
460 | 552 | out_buf = tvm.tir.decl_buffer(out_shape, data.dtype, "out_buf")
|
| 553 | + |
| 554 | + in_bufs = [data] |
| 555 | + |
| 556 | + if rank == 1 and is_thrust_available(): |
| 557 | + ir_funcs[1] = gen_scatter_1d_thrust |
| 558 | + indices_sorted, updates_sorted = stable_sort_by_key_thrust( |
| 559 | + indices, updates, for_scatter=True |
| 560 | + ) |
| 561 | + in_bufs += [indices_sorted, updates_sorted] |
| 562 | + else: |
| 563 | + in_bufs += [indices, updates] |
| 564 | + |
461 | 565 | out = te.extern(
|
462 | 566 | [out_shape],
|
463 |
| - [data, indices, updates], |
| 567 | + in_bufs, |
464 | 568 | lambda ins, outs: ir_funcs[rank](ins[0], ins[1], ins[2], axis, outs[0], update_func),
|
465 | 569 | dtype=data.dtype,
|
466 | 570 | out_buffers=[out_buf],
|
|
0 commit comments