Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/topi/adreno/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _schedule_reduce_adreno(op, sch, is_idx_reduce=False):
sch[temp_val_input].set_scope("local")

shape = get_const_tuple(sch_output.shape)
latest4 = shape[-1] == 4
latest4 = len(shape) > 0 and shape[-1] == 4
div4 = numpy.prod(shape) % 4 == 0

# Fuse and split the axis
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/topi/adreno/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,14 +537,14 @@ def bind_data_copy(stage, axis_to_vectorize=None):
stage.vectorize(iax3)
fused = stage.fuse(ax0, ax1, ax2, oax3)

ftc = numpy.prod(shape) / 4
ftc = numpy.prod(shape) // 4
div = get_div(ftc, 128)
block, thread = stage.split(fused, factor=div)

stage.bind(block, te.thread_axis("blockIdx.z"))
stage.bind(thread, te.thread_axis("threadIdx.z"))
else:
if shape[-1] == 4:
if len(shape) > 0 and shape[-1] == 4:
axes = stage.op.axis
fused = stage.fuse(*axes[:-1])
ftc = numpy.prod(shape[:-1])
Expand All @@ -557,7 +557,7 @@ def bind_data_copy(stage, axis_to_vectorize=None):
ftc = numpy.prod(shape)
vthread = get_div(ftc, 8)
fused = stage.fuse(*stage.op.axis)
ftc = ftc / vthread
ftc = ftc // vthread
# 1024 is a maximum work group size on the most Adreno GPU
num_thread = get_div(ftc, 1024 // vthread)
a, b = stage.split(fused, factor=num_thread)
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"arm_cpu": topi.arm_cpu.schedule_injective,
"gpu": topi.cuda.schedule_injective,
"hls": topi.hls.schedule_injective,
"adreno": topi.adreno.schedule_injective,
}

_reduce_schedule = {
Expand Down
18 changes: 18 additions & 0 deletions tests/python/relay/opencl_texture/test_reduction_texture.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,5 +177,23 @@ def test_max_global_pooling_block4(remote, target, dtype):
build_run_compare(remote, mod, {}, {"data": input_shape}, {"data": dtype}, target)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_sum_cast(remote, target, dtype):
shape = (10,)
A = relay.var("A", shape=shape)
w = relay.op.sum(A)
w = relay.cast(w, "int32")
mod = relay.Function([A], w)

shape_dict = {
"A": shape,
}
dtype_dict = {
"A": dtype,
}
build_run_compare(remote, mod, {}, shape_dict, dtype_dict, target)


if __name__ == "__main__":
tvm.testing.main()
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ def check_device(target):
foo(data_nd, begin_nd, end_nd, strides_nd, out_nd)
tvm.testing.assert_allclose(out_nd.numpy(), out_npy)

for target in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
for target in ["llvm", "opencl", "sdaccel", "aocl_sw_emu", "opencl --device=adreno"]:
check_device(target)


Expand Down