Skip to content

Commit

Permalink
[TOPI] [Hexagon] Uint8 Reshape and batch flatten slice ops (apache#12037)
Browse files Browse the repository at this point in the history

* [TOPI] [Hexagon] Uint8 Reshape and batch flatten slice ops

* Fix documentation
  • Loading branch information
abhikran-quic authored Jul 16, 2022
1 parent 895f79f commit c0e996e
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 21 deletions.
13 changes: 8 additions & 5 deletions python/tvm/topi/hexagon/slice_ops/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@ def reshape_compute(inp: te.Tensor, new_shape: tuple) -> te.Tensor:
return topi.transform.reshape(inp, new_shape)


def stir_schedule_nhwc_1024c(
def stir_sched_nhwc_2d_op(
out: te.Tensor,
inp: te.Tensor,
out_layout: str,
in_layout: str,
c_split: int,
) -> tir.Schedule:
"""Schedule for output layout: nhwc-1024c-2d"""
"""Schedule for output layout: nc-1024-2d, nc-2048-2d"""
reshape_func = te.create_prim_func([inp, out])
sch = tir.Schedule(reshape_func, debug_mask="all")
compute = sch.get_block("T_reshape")
Expand All @@ -57,7 +58,7 @@ def stir_schedule_nhwc_1024c(
jout, channel = sch.split(j, [None, inp.shape[3]])
height, width = sch.split(jout, [inp.shape[1], inp.shape[2]])
channelo, channeli = sch.split(channel, [None, 1024])
channelio, channelii = sch.split(channeli, [None, 64])
channelio, channelii = sch.split(channeli, [None, c_split])
sch.reorder(i, height, width, channelo, channelio, channelii)
sch.vectorize(channelii)
return sch
Expand Down Expand Up @@ -101,8 +102,10 @@ def reshape_stir_schedule(
sch : tvm.tir.Schedule
The STIR schedule for slice reshape compute
"""
if output_layout == "nhwc-8h2w32c2w-2d":
if output_layout in ["nhwc-8h2w32c2w-2d", "nhwc-8h8w32c-2d"]:
return stir_schedule_nhwc_8h2w32c2w(out, inp, output_layout, input_layout)
if output_layout == "nc-1024-2d":
return stir_schedule_nhwc_1024c(out, inp, output_layout, input_layout)
return stir_sched_nhwc_2d_op(out, inp, output_layout, input_layout, 64)
if output_layout == "nc-2048-2d":
return stir_sched_nhwc_2d_op(out, inp, output_layout, input_layout, 128)
raise RuntimeError(f"Unexpected layout '{output_layout}'")
21 changes: 21 additions & 0 deletions python/tvm/topi/hexagon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,21 @@ def nc_1024_2d(n, c):
return [n, c // 1024, te.AXIS_SEPARATOR, c % 1024]


def nhwc_2048c_2d(n, h, w, c):
    """Return index map for nhwc_2048 2d layout.

    The channel axis is blocked by 2048; the axes listed after
    te.AXIS_SEPARATOR address the element within one block.
    """
    c_outer = c // 2048
    c_inner = c % 2048
    return [n, h, w, c_outer, te.AXIS_SEPARATOR, c_inner]


def nc_2048_2d(n, c):
    """Return index map for nc_2048 2d layout.

    The channel axis is blocked by 2048; the axis after
    te.AXIS_SEPARATOR addresses the element within one block.
    """
    c_outer = c // 2048
    c_inner = c % 2048
    return [n, c_outer, te.AXIS_SEPARATOR, c_inner]


def nhwc_8h8w32c_2d(n, h, w, c):
    """Return index map for nhwc_8h8w32c 2d layout.

    Height and width are blocked by 8 and channels by 32; the axes
    listed after te.AXIS_SEPARATOR address the element inside one
    8x8x32 block.
    """
    h_outer, h_inner = h // 8, h % 8
    w_outer, w_inner = w // 8, w % 8
    c_outer, c_inner = c // 32, c % 32
    return [n, h_outer, w_outer, c_outer, te.AXIS_SEPARATOR, h_inner, w_inner, c_inner]


def iohw_16i32o2i_1d(height, width, in_channel, out_channel):
return [
in_channel // 32,
Expand Down Expand Up @@ -129,4 +144,10 @@ def get_layout_transform_fn(layout):
return nc_1024c_2d
if layout == "iohw-16i32o2i-1d":
return iohw_16i32o2i_1d
if layout == "nhwc-2048c-2d":
return nhwc_2048c_2d
if layout == "nc-2048-2d":
return nc_2048_2d
if layout == "nhwc-8h8w32c-2d":
return nhwc_8h8w32c_2d
raise RuntimeError(f"Unexpected layout '{layout}'")
12 changes: 11 additions & 1 deletion tests/python/contrib/test_hexagon/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,17 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
if new_layout == "nhwc-1024c-2d":
N, H, W, C = arr_np.shape
return arr_np.reshape([N, H, W, C // 1024, 1024])
raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
if new_layout == "nc-2048-2d":
N, C = arr_np.shape
return arr_np.reshape([N, C // 2048, 2048])
if new_layout == "nhwc-2048c-2d":
N, H, W, C = arr_np.shape
return arr_np.reshape([N, H, W, C // 2048, 2048])
if new_layout in ["nhwc-8h8w32c-2d"]:
n, h, w, c = arr_np.shape
return arr_np.reshape([n, h // 8, 8, w // 8, 8, c // 32, 32]).transpose(
0, 1, 3, 5, 2, 4, 6
)

if current_layout == "nc":
n, c = arr_np.shape
Expand Down
47 changes: 32 additions & 15 deletions tests/python/contrib/test_hexagon/topi/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,23 @@ def reshape_helper(
input_layout,
)
with tvm.transform.PassContext(opt_level=3):
print("output of tvm.lower", tvm.lower(tir_s.mod, name=func))
runtime_module = tvm.build(tir_s.mod, target=target, name=func)

mod = hexagon_session.load_module(runtime_module)

a_numpy = (np.random.uniform(-1, 1, input_shape)).astype(data_type)
a_numpy = (np.random.uniform(-10, 10, input_shape)).astype(data_type)
ref = np.reshape(a_numpy, output_shape)

input_np_transformed = transform_numpy(a_numpy, "nhwc", input_layout)
ref_np_transformed = transform_numpy(ref, "nhwc", output_layout)
input_axis_sep = [4]
if output_layout == "nhwc-8h2w32c2w-2d":
if output_layout in ["nhwc-8h2w32c2w-2d", "nhwc-8h8w32c-2d"]:
output_axis_sep = [4]
elif output_layout == "nc-1024-2d":
elif output_layout in ["nc-1024-2d", "nc-2048-2d"]:
output_axis_sep = [2]
else:
raise RuntimeError(f"Unexpected layout '{output_layout}'")

a_tvm = allocate_hexagon_array(
hexagon_session.device,
data=input_np_transformed,
Expand All @@ -86,26 +86,30 @@ def reshape_helper(
axis_separators=output_axis_sep,
mem_scope="global.vtcm",
)

mod(a_tvm, output)
np.testing.assert_allclose(output.numpy(), ref_np_transformed, atol=1e-07, rtol=0)


batch_flatten_tests = (
# float16 batch-flatten cases: (input_shape, output_shape, input_layout,
# output_layout, dtype). All use the 1024-channel-blocked layouts.
batch_flatten_fp16_tests = tuple(
    (in_shape, out_shape, "nhwc-1024c-2d", "nc-1024-2d", "float16")
    for in_shape, out_shape in (
        ([1, 1, 1, 2048], [1, 2048]),
        ([1, 2, 4, 2048], [1, 2 * 4 * 2048]),
        ([1, 8, 8, 1024], [1, 8 * 8 * 1024]),
        ([2, 4, 8, 1024], [2, 4 * 8 * 1024]),
    )
)


# uint8 batch-flatten cases: (input_shape, output_shape, input_layout,
# output_layout, dtype). All use the 2048-channel-blocked layouts.
batch_flatten_uint8_tests = tuple(
    (in_shape, out_shape, "nhwc-2048c-2d", "nc-2048-2d", "uint8")
    for in_shape, out_shape in (
        ([1, 1, 1, 2048], [1, 2048]),
        ([1, 2, 4, 2048], [1, 2 * 4 * 2048]),
    )
)


class BaseTestBatchFlatten:
(
input_shape,
output_shape,
input_layout,
output_layout,
data_type,
) = tvm.testing.parameters(*batch_flatten_tests)
(input_shape, output_shape, input_layout, output_layout, data_type,) = tvm.testing.parameters(
*batch_flatten_fp16_tests,
*batch_flatten_uint8_tests,
)


class TestBatchFlatten(BaseTestBatchFlatten):
Expand All @@ -132,11 +136,24 @@ def test_batch_flatten(
)


# float16 reshape cases: (input_shape, output_shape, input_layout,
# output_layout, dtype); layout is nhwc-8h2w32c2w-2d on both sides.
reshape_fp16_tests = tuple(
    (in_shape, out_shape, "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16")
    for in_shape, out_shape in (
        ([1, 8, 4, 64], [1, 8, 8, 32]),
        ([1, 16, 8, 128], [1, 16, 16, 64]),
    )
)


# uint8 reshape cases: (input_shape, output_shape, input_layout,
# output_layout, dtype); layout is nhwc-8h8w32c-2d on both sides.
reshape_uint8_tests = tuple(
    (in_shape, out_shape, "nhwc-8h8w32c-2d", "nhwc-8h8w32c-2d", "uint8")
    for in_shape, out_shape in (
        ([1, 8, 8, 128], [1, 8, 16, 64]),
        ([1, 16, 64, 128], [1, 16, 128, 64]),
    )
)


class BaseTestReshape(BaseTestBatchFlatten):
(input_shape, output_shape, input_layout, output_layout, data_type,) = tvm.testing.parameters(
*batch_flatten_tests,
([1, 8, 4, 64], [1, 8, 8, 32], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
([1, 16, 8, 128], [1, 16, 16, 64], "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", "float16"),
*batch_flatten_fp16_tests,
*batch_flatten_uint8_tests,
*reshape_fp16_tests,
*reshape_uint8_tests,
)


Expand Down

0 comments on commit c0e996e

Please sign in to comment.