python/tvm/relay/op/strategy/arm_cpu.py: 6 additions & 4 deletions
@@ -211,21 +211,23 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
name="conv2d_nhwc_dsp.arm_cpu",
)
elif kernel_layout == "HWIO":
+ is_aarch64 = target.features.is_aarch64
has_asimd = target.features.has_asimd
has_dot_prod = target.features.has_dotprod

if has_dot_prod and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_native),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
name="conv2d_NHWC_quantized_native.arm_cpu",
)
- if has_asimd and data.dtype in ["int8", "uint8"]:
+ if is_aarch64 and has_asimd and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
name="conv2d_NHWC_quantized_interleaved.arm_cpu",
)
- if (not has_asimd) or (data.dtype not in ["int8", "uint8"]):
+ if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
# TODO(@giuseros)
# This strategy errors out for quantized data types when tuning.
# Let's use this only for non-aarch64 or non-quantized cases
@@ -285,7 +287,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
- if target.features.has_asimd:
+ if target.features.is_aarch64 and target.features.has_asimd:
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc),
@@ -298,7 +300,6 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
# The int8 implementation DOES need the DSP unit (for SXTB16), but it is not
# possible to use the DSP unit to speed up a NHWC depthwise convolution (though
# an NCHW convolution would benefit).

elif (
dilation_w == dilation_h == 1
and kernel.shape[3] == 1 # channel_multiplier == 1
@@ -308,6 +309,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
or (data.shape[3] % 2 == 0 and data.dtype == "int16")
)
and (padding != "SAME" or data.shape[1] % stride_h == data.shape[2] % stride_w == 0)
+ and target.kind.name == "c"
# Ideally we should check that kernel is a Relay constant, but strategy functions
# don't have access to the data needed to check this.
):
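For reference, the change above keys on the target feature flags rather than on the `-mattr` string alone: `has_asimd` can also be reported on a 32-bit ARMv8 target compiled with `+neon`, while `is_aarch64` is only true for 64-bit triples. A minimal sketch, assuming a TVM build that exposes the same `target.features` view used in the diff, for inspecting these flags on the targets exercised by the new tests:

```python
# Minimal sketch: print the arm_cpu feature flags the updated strategy checks.
# Assumes the TVM build exposes target.features (is_aarch64 / has_asimd /
# has_dotprod), as the strategy code above already relies on.
import tvm

targets = [
    "llvm -device=arm_cpu",
    "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
    "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon",
]

for target_str in targets:
    features = tvm.target.Target(target_str).features
    print(
        f"{target_str}: "
        f"is_aarch64={features.is_aarch64}, "
        f"has_asimd={features.has_asimd}, "
        f"has_dotprod={features.has_dotprod}"
    )
```

Only when `is_aarch64` is set do the quantized interleaved conv2d and the Arm depthwise NHWC schedules remain selectable; the `armv8l` case falls back to the spatial-pack and generic implementations, which is what the tests below assert.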
tests/python/relay/strategy/test_select_implementation.py: 97 additions & 0 deletions
@@ -52,5 +52,102 @@ def test_concatenate(target, expected_implementation):
assert impl.name == expected_implementation


@pytest.mark.parametrize(
"target,expected_impl",
[
("llvm -device=arm_cpu", "conv2d_nhwc_spatial_pack.arm_cpu"),
(
"llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
"conv2d_NHWC_quantized_interleaved.arm_cpu",
),
(
"llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon",
"conv2d_nhwc_spatial_pack.arm_cpu",
),
],
)
def test_int8_conv2d(target, expected_impl):
target = tvm.target.Target(target)

dtype = "int8"
data_shape = (1, 1, 1, 4)
weight_shape = (1, 1, 4, 4)
data_layout = "NHWC"
kernel_layout = "HWIO"
channels = 4
kernel_size = (1, 1)

out = relay.nn.conv2d(
relay.var("data", shape=data_shape, dtype=dtype),
relay.var("weight", shape=weight_shape, dtype=dtype),
kernel_size=kernel_size,
channels=channels,
data_layout=data_layout,
kernel_layout=kernel_layout,
)
out = run_infer_type(out)

with target:
impl, _ = relay.backend.te_compiler.select_implementation(
out.op,
out.attrs,
[te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)],
out.checked_type,
target,
)

assert impl.name == expected_impl


@pytest.mark.parametrize(
"target,expected_impl",
[
("llvm -device=arm_cpu", "depthwise_conv2d_nhwc.generic"),
(
"llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
"depthwise_conv2d_nhwc.arm_cpu",
),
(
"llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon",
"depthwise_conv2d_nhwc.generic",
),
("c -device=arm_cpu -mcpu=cortex-m55", "depthwise_conv2d_nhwc_dsp.arm_cpu"),
],
)
def test_int8_depthwise_conv2d(target, expected_impl):
target = tvm.target.Target(target)

dtype = "int8"
out_dtype = "int32"
data_shape = (2, 2, 4, 8)
weight_shape = (2, 2, 8, 1)
data_layout = "NHWC"
kernel_layout = "HWOI"
groups = 8
kernel_size = (2, 2)

out = relay.nn.conv2d(
relay.var("data", shape=data_shape, dtype=dtype),
relay.var("weight", shape=weight_shape, dtype=dtype),
kernel_size=kernel_size,
data_layout=data_layout,
kernel_layout=kernel_layout,
groups=groups,
out_dtype=out_dtype,
)
out = run_infer_type(out)

with target:
impl, _ = relay.backend.te_compiler.select_implementation(
out.op,
out.attrs,
[te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)],
out.checked_type,
target,
)

assert impl.name == expected_impl


if __name__ == "__main__":
tvm.testing.main()
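The Cortex-M55 case in the depthwise test relies on the other change in `arm_cpu.py`: the DSP depthwise schedule (`depthwise_conv2d_nhwc_dsp.arm_cpu`, per the test expectations) is now only offered when the target kind is `c`. A small sketch, using the same target strings as the parametrization above, of the distinction that guard relies on:

```python
# Minimal sketch: the new `target.kind.name == "c"` guard separates the C
# codegen target used for Cortex-M from the llvm-based targets in the other
# test cases. Target strings are taken from the parametrization above.
import tvm

for target_str in [
    "c -device=arm_cpu -mcpu=cortex-m55",
    "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
]:
    target = tvm.target.Target(target_str)
    print(f"{target_str} -> kind={target.kind.name}")  # "c" vs "llvm"
```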