|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# pylint: disable=invalid-name |
| 18 | +"""searchsorted operator for GPU""" |
| 19 | +import tvm |
| 20 | +from tvm import te |
| 21 | +from .. import utils |
| 22 | +from ..searchsorted import binary_search |
| 23 | + |
| 24 | + |
| 25 | +def searchsorted(sorted_sequence, values, right, out_dtype="int64"): |
| 26 | + """Find indices where elements should be inserted to maintain order. |
| 27 | + If `sorted_sequence` is N-dimensional, the innermost dimension of |
| 28 | + `values` are searched in the corresponding dimension of `sorted_sequence`. |
| 29 | +
|
| 30 | + Parameters |
| 31 | + ---------- |
| 32 | + sorted_sequence : te.Tensor |
| 33 | + N-D or 1-D Tensor, containing monotonically increasing sequence |
| 34 | + on the innermost dimension. |
| 35 | +
|
| 36 | + values : te.Tensor |
| 37 | + N-D Tensor containing the search values. When `sorted_sequence` is 1-D, |
| 38 | + the shape of `values` can be arbitrary. Otherwise, ranks of `sorted_sequence` |
| 39 | + and `values` must be the same, and outer N-1 axes must have the same size. |
| 40 | +
|
| 41 | + right : bool, optional |
| 42 | + Controls which index is returned if a value lands exactly on one of sorted values. If |
| 43 | + False, the index of the first suitable location found is given. If true, return the |
| 44 | + last such index. If there is no suitable index, return either 0 or N (where N is the |
| 45 | + size of the innermost dimension). |
| 46 | +
|
| 47 | + dtype : string, optional |
| 48 | + The data type of the output indices. |
| 49 | +
|
| 50 | + Returns |
| 51 | + ------- |
| 52 | + indices : te.Tensor |
| 53 | + Tensor with same shape as values, representing the indices of |
| 54 | + elements of `values` if they are inserted in `sorted_sequence`. |
| 55 | + """ |
| 56 | + |
| 57 | + def ir(sorted_sequence, values, indices): |
| 58 | + ib = tvm.tir.ir_builder.create() |
| 59 | + sorted_sequence_shape = sorted_sequence.shape |
| 60 | + values_shape = values.shape |
| 61 | + num_search = utils.prod(values_shape) |
| 62 | + search_range = sorted_sequence_shape[-1] |
| 63 | + |
| 64 | + sorted_sequence = ib.buffer_ptr(sorted_sequence) |
| 65 | + values = ib.buffer_ptr(values) |
| 66 | + indices = ib.buffer_ptr(indices) |
| 67 | + |
| 68 | + max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads) |
| 69 | + bx = te.thread_axis("blockIdx.x") |
| 70 | + tx = te.thread_axis("threadIdx.x") |
| 71 | + ib.scope_attr( |
| 72 | + bx, "thread_extent", tvm.tir.indexdiv(num_search + max_threads - 1, max_threads) |
| 73 | + ) |
| 74 | + ib.scope_attr(tx, "thread_extent", max_threads) |
| 75 | + tid = bx * max_threads + tx |
| 76 | + |
| 77 | + with ib.if_scope(tid < num_search): |
| 78 | + if len(sorted_sequence_shape) == 1: |
| 79 | + sequence_offset = 0 |
| 80 | + else: |
| 81 | + sequence_id = tid // values_shape[-1] |
| 82 | + sequence_offset = sequence_id * search_range |
| 83 | + |
| 84 | + indices[tid] = binary_search( |
| 85 | + ib, |
| 86 | + sequence_offset, |
| 87 | + search_range, |
| 88 | + sorted_sequence, |
| 89 | + values[tid], |
| 90 | + right, |
| 91 | + out_dtype, |
| 92 | + ) |
| 93 | + |
| 94 | + return ib.get() |
| 95 | + |
| 96 | + return te.extern( |
| 97 | + values.shape, |
| 98 | + [sorted_sequence, values], |
| 99 | + lambda ins, outs: ir(ins[0], ins[1], outs[0]), |
| 100 | + name="searchsorted", |
| 101 | + dtype=out_dtype, |
| 102 | + ) |
0 commit comments