diff --git a/python/tvm/topi/cuda/argwhere.py b/python/tvm/topi/cuda/argwhere.py index 588680fa83b9c..5dc6808e6af80 100644 --- a/python/tvm/topi/cuda/argwhere.py +++ b/python/tvm/topi/cuda/argwhere.py @@ -27,7 +27,6 @@ from .sort import topk, topk_thrust, argsort, argsort_thrust from .. import tag from ..transform import strided_slice, adv_index, squeeze, dynamic_strided_slice1 -from ..utils import const_vector logger = logging.getLogger("topi") @@ -47,17 +46,6 @@ def _get_sort_func(mode=0): return ret -def _create_end(data, out, end): - ib = tvm.tir.ir_builder.create() - end = tvm.tir.const(end, dtype=out.dtype) - out_ptr = ib.buffer_ptr(out) - bx = te.thread_axis("blockIdx.x") - ib.scope_attr(bx, "thread_extent", 1) - out_ptr[0] = data.shape[0] - out_ptr[1] = end - return ib.get() - - def argwhere_1d_ir(condition, out): """Low level IR for argwhere 1D @@ -247,19 +235,8 @@ def argwhere_2d(output_shape, condition): out2 = sort_func(out1, axis=0, dtype="int32") out3 = squeeze(out2) - return adv_index(out, [out3]) + out = adv_index(out, [out3]) else: - # out_shape = [2] - # out_buf = tvm.tir.decl_buffer(out_shape, "int32", "strided_slice_out_buf") - # end = te.extern( - # [out_shape], - # [out], - # lambda ins, outs: _create_end(ins[0], outs[0], 2), - # dtype="int32", - # out_buffers=[out_buf], - # name="strided_slice_gpu_end0", - # tag="strided_slice_gpu_end0", - # ) out1 = dynamic_strided_slice1(out, [0, 1], [out.shape[0], 2], [1, 1]) out2 = sort_func(out1, axis=0, dtype="int32") out3 = squeeze(out2) @@ -268,28 +245,8 @@ def argwhere_2d(output_shape, condition): out1 = dynamic_strided_slice1(out, [0, 0], [out.shape[0], 1], [1, 1]) out2 = sort_func(out1, axis=0, dtype="int32") out3 = squeeze(out2) - return adv_index(out, [out3]) - # out1 = dynamic_strided_slice1(out, [0, 1], [-1, -1]) - # out1 = strided_slice(out, const_vector([0, 1]), end) - # out2 = sort_func(out1, axis=0, dtype="int32") - # out3 = squeeze(out2) - # out = adv_index(out, [out3]) - - # out_buf = tvm.tir.decl_buffer(out_shape, "int32", "strided_slice_out_buf") - # end = te.extern( - # [out_shape], - # [out], - # lambda ins, outs: _create_end(ins[0], outs[0], 1), - # dtype="int32", - # out_buffers=[out_buf], - # name="strided_slice_gpu_end1", - # tag="strided_slice_gpu_end1", - # ) - # out1 = strided_slice(out, const_vector([0, 0]), end) - # out2 = sort_func(out1, axis=0, dtype="int32") - # out3 = squeeze(out2) - - # return adv_index(out, [out3]) + out = adv_index(out, [out3]) + return out def argwhere_3d_ir(condition, out): @@ -382,18 +339,25 @@ def argwhere_3d(output_shape, condition): tag="argwhere3d_gpu", ) - if out.shape[0] <= 1: + if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1: return out # sort the output from the least significant to the most significant # column. sort_func = _get_sort_func(1) - for i in reversed(range(3)): - out1 = strided_slice(out, [0, i], [out.shape[0], i + 1]) - out2 = sort_func(out1, axis=0, dtype="int32") - out3 = squeeze(out2) - out = adv_index(out, [out3]) + if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)): + for i in reversed(range(3)): + out1 = strided_slice(out, [0, i], [out.shape[0], i + 1]) + out2 = sort_func(out1, axis=0, dtype="int32") + out3 = squeeze(out2) + out = adv_index(out, [out3]) + else: + for i in reversed(range(3)): + out1 = dynamic_strided_slice1(out, [0, i], [out.shape[0], i + 1], [1, 1]) + out2 = sort_func(out1, axis=0, dtype="int32") + out3 = squeeze(out2) + out = adv_index(out, [out3]) return out @@ -490,17 +454,24 @@ def argwhere_4d(output_shape, condition): tag="argwhere4d_gpu", ) - if out.shape[0] <= 1: + if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1: return out # sort the output from the least significant to the most significant # column. sort_func = _get_sort_func(1) - for i in reversed(range(4)): - out1 = strided_slice(out, [0, i], [out.shape[0], i + 1]) - out2 = sort_func(out1, axis=0, dtype="int32") - out3 = squeeze(out2) - out = adv_index(out, [out3]) + if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)): + for i in reversed(range(4)): + out1 = strided_slice(out, [0, i], [out.shape[0], i + 1]) + out2 = sort_func(out1, axis=0, dtype="int32") + out3 = squeeze(out2) + out = adv_index(out, [out3]) + else: + for i in reversed(range(4)): + out1 = dynamic_strided_slice1(out, [0, i], [out.shape[0], i + 1], [1, 1]) + out2 = sort_func(out1, axis=0, dtype="int32") + out3 = squeeze(out2) + out = adv_index(out, [out3]) return out @@ -601,17 +572,24 @@ def argwhere_5d(output_shape, condition): tag="argwhere5d_gpu", ) - if out.shape[0] <= 1: + if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1: return out # sort the output from the least significant to the most significant # column. sort_func = _get_sort_func(1) - for i in reversed(range(5)): - out1 = strided_slice(out, [0, i], [out.shape[0], i + 1]) - out2 = sort_func(out1, axis=0, dtype="int32") - out3 = squeeze(out2) - out = adv_index(out, [out3]) + if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)): + for i in reversed(range(5)): + out1 = strided_slice(out, [0, i], [out.shape[0], i + 1]) + out2 = sort_func(out1, axis=0, dtype="int32") + out3 = squeeze(out2) + out = adv_index(out, [out3]) + else: + for i in reversed(range(5)): + out1 = dynamic_strided_slice1(out, [0, i], [out.shape[0], i + 1], [1, 1]) + out2 = sort_func(out1, axis=0, dtype="int32") + out3 = squeeze(out2) + out = adv_index(out, [out3]) return out diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 9cf1dc62eaf9e..e64e9a58554b0 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -218,6 +218,7 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"): strides = [] return cpp.strided_slice(a, begin, end, strides, slice_mode) + def dynamic_strided_slice1(a, begin, end, strides): return cpp.dynamic_strided_slice1(a, begin, end, strides) diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 23756fb307dee..7c4a8ef92724b 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -46,7 +46,6 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { } PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { - LOG(INFO) << AsText(func, false); auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 8e17e4b8b027a..b9546e4ade15c 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -226,11 +226,10 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): @tvm.testing.uses_gpu def test_any_argwhere(): - # verify_any_argwhere(any_dims(1), (5,)) + verify_any_argwhere(any_dims(1), (5,)) verify_any_argwhere(any_dims(2), (5, 5)) verify_any_argwhere(any_dims(2), (5, 5), "int32") verify_any_argwhere(any_dims(2), (5, 5), "int8") - """ verify_any_argwhere(any_dims(3), (5, 5, 5)) verify_any_argwhere(any_dims(4), (5, 5, 5, 5)) verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5)) @@ -242,7 +241,6 @@ def test_any_argwhere(): verify_any_argwhere(any_dims(3), (5, 5, 5), "int8") verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int8") verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8") - """ def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape): @@ -814,6 +812,7 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False): check_result(in_vals, mod, ref_out) + @tvm.testing.uses_gpu def test_any_topk(): verify_any_topk(any_dims(1), 5, (10,), "float32") @@ -1362,6 +1361,4 @@ def test_any_where(): if __name__ == "__main__": - #pytest.main([__file__]) - #test_any_topk() - test_any_argwhere() + pytest.main([__file__])