Skip to content

Commit

Permalink
all tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Dec 2, 2020
1 parent 12ce21b commit 7fffa99
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 71 deletions.
106 changes: 42 additions & 64 deletions python/tvm/topi/cuda/argwhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from .sort import topk, topk_thrust, argsort, argsort_thrust
from .. import tag
from ..transform import strided_slice, adv_index, squeeze, dynamic_strided_slice1
from ..utils import const_vector

logger = logging.getLogger("topi")

Expand All @@ -47,17 +46,6 @@ def _get_sort_func(mode=0):
return ret


def _create_end(data, out, end):
    """Low level IR that fills ``out`` with a two-element "end" vector.

    Element 0 of ``out`` is set to ``data.shape[0]`` and element 1 is set to
    ``end`` converted to a constant of ``out.dtype``.

    Parameters
    ----------
    data : Buffer
        Input buffer; only its leading extent ``shape[0]`` is read.

    out : Buffer
        Output buffer; positions 0 and 1 are written.

    end : int
        Value stored at position 1 of ``out`` (presumably a Python scalar,
        since it is passed through ``tvm.tir.const`` -- confirm with callers).

    Returns
    -------
    stmt : Stmt
        The statement assembled by the IR builder.
    """
    builder = tvm.tir.ir_builder.create()
    out_buf = builder.buffer_ptr(out)

    # Bind blockIdx.x with an extent of 1 before the two stores are emitted.
    block_x = te.thread_axis("blockIdx.x")
    builder.scope_attr(block_x, "thread_extent", 1)

    out_buf[0] = data.shape[0]
    out_buf[1] = tvm.tir.const(end, dtype=out.dtype)
    return builder.get()


def argwhere_1d_ir(condition, out):
"""Low level IR for argwhere 1D
Expand Down Expand Up @@ -247,19 +235,8 @@ def argwhere_2d(output_shape, condition):
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)

return adv_index(out, [out3])
out = adv_index(out, [out3])
else:
# out_shape = [2]
# out_buf = tvm.tir.decl_buffer(out_shape, "int32", "strided_slice_out_buf")
# end = te.extern(
# [out_shape],
# [out],
# lambda ins, outs: _create_end(ins[0], outs[0], 2),
# dtype="int32",
# out_buffers=[out_buf],
# name="strided_slice_gpu_end0",
# tag="strided_slice_gpu_end0",
# )
out1 = dynamic_strided_slice1(out, [0, 1], [out.shape[0], 2], [1, 1])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
Expand All @@ -268,28 +245,8 @@ def argwhere_2d(output_shape, condition):
out1 = dynamic_strided_slice1(out, [0, 0], [out.shape[0], 1], [1, 1])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
return adv_index(out, [out3])
# out1 = dynamic_strided_slice1(out, [0, 1], [-1, -1])
# out1 = strided_slice(out, const_vector([0, 1]), end)
# out2 = sort_func(out1, axis=0, dtype="int32")
# out3 = squeeze(out2)
# out = adv_index(out, [out3])

# out_buf = tvm.tir.decl_buffer(out_shape, "int32", "strided_slice_out_buf")
# end = te.extern(
# [out_shape],
# [out],
# lambda ins, outs: _create_end(ins[0], outs[0], 1),
# dtype="int32",
# out_buffers=[out_buf],
# name="strided_slice_gpu_end1",
# tag="strided_slice_gpu_end1",
# )
# out1 = strided_slice(out, const_vector([0, 0]), end)
# out2 = sort_func(out1, axis=0, dtype="int32")
# out3 = squeeze(out2)

# return adv_index(out, [out3])
out = adv_index(out, [out3])
return out


def argwhere_3d_ir(condition, out):
Expand Down Expand Up @@ -382,18 +339,25 @@ def argwhere_3d(output_shape, condition):
tag="argwhere3d_gpu",
)

if out.shape[0] <= 1:
if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
return out

# sort the output from the least significant to the most significant
# column.
sort_func = _get_sort_func(1)
for i in reversed(range(3)):
out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])

if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
for i in reversed(range(3)):
out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])
else:
for i in reversed(range(3)):
out1 = dynamic_strided_slice1(out, [0, i], [out.shape[0], i + 1], [1, 1])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])
return out


Expand Down Expand Up @@ -490,17 +454,24 @@ def argwhere_4d(output_shape, condition):
tag="argwhere4d_gpu",
)

if out.shape[0] <= 1:
if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
return out

# sort the output from the least significant to the most significant
# column.
sort_func = _get_sort_func(1)
for i in reversed(range(4)):
out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])
if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
for i in reversed(range(4)):
out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])
else:
for i in reversed(range(4)):
out1 = dynamic_strided_slice1(out, [0, i], [out.shape[0], i + 1], [1, 1])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])

return out

Expand Down Expand Up @@ -601,17 +572,24 @@ def argwhere_5d(output_shape, condition):
tag="argwhere5d_gpu",
)

if out.shape[0] <= 1:
if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
return out

# sort the output from the least significant to the most significant
# column.
sort_func = _get_sort_func(1)
for i in reversed(range(5)):
out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])
if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
for i in reversed(range(5)):
out1 = strided_slice(out, [0, i], [out.shape[0], i + 1])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])
else:
for i in reversed(range(5)):
out1 = dynamic_strided_slice1(out, [0, i], [out.shape[0], i + 1], [1, 1])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])

return out

Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"):
strides = []
return cpp.strided_slice(a, begin, end, strides, slice_mode)


def dynamic_strided_slice1(a, begin, end, strides):
    """Forward to the C++ ``dynamic_strided_slice1`` implementation.

    Unlike ``strided_slice``, ``strides`` has no default here and every
    argument is handed to the C++ side unchanged.

    Parameters
    ----------
    a : tvm.te.Tensor
        The tensor to be sliced.

    begin : list
        Per-axis start indices (call sites pass plain ints, e.g. ``[0, 1]``).

    end : list
        Per-axis end indices. Call sites pass entries such as
        ``a.shape[0]``, so these are presumably allowed to be non-constant
        expressions -- confirm against the C++ signature.

    strides : list
        Per-axis strides (call sites pass ``[1, 1]``).

    Returns
    -------
    ret : tvm.te.Tensor
        Whatever ``cpp.dynamic_strided_slice1`` returns for these arguments.
    """
    sliced = cpp.dynamic_strided_slice1(a, begin, end, strides)
    return sliced

Expand Down
1 change: 0 additions & 1 deletion src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
}

PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) {
LOG(INFO) << AsText(func, false);
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";

Expand Down
9 changes: 3 additions & 6 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,10 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"):

@tvm.testing.uses_gpu
def test_any_argwhere():
# verify_any_argwhere(any_dims(1), (5,))
verify_any_argwhere(any_dims(1), (5,))
verify_any_argwhere(any_dims(2), (5, 5))
verify_any_argwhere(any_dims(2), (5, 5), "int32")
verify_any_argwhere(any_dims(2), (5, 5), "int8")
"""
verify_any_argwhere(any_dims(3), (5, 5, 5))
verify_any_argwhere(any_dims(4), (5, 5, 5, 5))
verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5))
Expand All @@ -242,7 +241,6 @@ def test_any_argwhere():
verify_any_argwhere(any_dims(3), (5, 5, 5), "int8")
verify_any_argwhere(any_dims(4), (5, 5, 5, 5), "int8")
verify_any_argwhere(any_dims(5), (5, 5, 5, 5, 5), "int8")
"""


def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape):
Expand Down Expand Up @@ -814,6 +812,7 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False):

check_result(in_vals, mod, ref_out)


@tvm.testing.uses_gpu
def test_any_topk():
verify_any_topk(any_dims(1), 5, (10,), "float32")
Expand Down Expand Up @@ -1362,6 +1361,4 @@ def test_any_where():


if __name__ == "__main__":
#pytest.main([__file__])
#test_any_topk()
test_any_argwhere()
pytest.main([__file__])

0 comments on commit 7fffa99

Please sign in to comment.