Skip to content

Commit b70c662

Browse files
committed
[Relay][Strategy] Use x86 dense schedules for arm_cpu
Currently the fallback used when compiling a dense operation with targets such as `llvm -device=arm_cpu` is `dense.generic`. This results in very poor performance. Although apache#13775 meant that x86 schedules are used in cases where no strategy is provided by arm_cpu, the dense strategy is registered due to the existence of specialized schedules for arm_cpu e.g. a schedule for embedded devices. This commit ensures x86 schedules are used in place of a generic schedule, which yields much better performance. The commit also follows the same approach for the `dense.generic` schedule as the x86 strategy. This will only be used when the auto-scheduler is enabled. A test has been added to check that the intended schedules are picked when compiling with `arm_cpu`. Change-Id: I8697f630d4acfab71a9626cf9e0dc3086987f163
1 parent d1f7ef4 commit b70c662

File tree

2 files changed

+79
-21
lines changed

2 files changed

+79
-21
lines changed

python/tvm/relay/op/strategy/arm_cpu.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -557,33 +557,53 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
557557
wrap_topi_schedule(topi.arm_cpu.schedule_dense_dsp),
558558
name="dense_dsp.arm_cpu",
559559
)
560-
else:
561-
# For dynamic matrix-vector multiply we use a hand written kernel.
562-
if (
563-
isinstance(inputs[0].shape[0], (int, tir.IntImm))
564-
and inputs[0].shape[0] == 1
565-
and (
566-
topi.utils.is_dynamic_shape(inputs[0].shape)
567-
or topi.utils.is_dynamic_shape(inputs[1].shape)
568-
)
569-
):
570-
strategy.add_implementation(
571-
wrap_compute_dense(topi.x86.dense_dynamic),
572-
wrap_topi_schedule(topi.x86.schedule_dense_dynamic),
573-
name="dense_dynamic.x86",
574-
plevel=20,
575-
)
576-
return strategy
577-
logger.warning("dense is not optimized for arm cpu.")
560+
return strategy
561+
562+
# For dynamic matrix-vector multiply we use a hand written kernel.
563+
if (
564+
isinstance(inputs[0].shape[0], (int, tir.IntImm))
565+
and inputs[0].shape[0] == 1
566+
and (
567+
topi.utils.is_dynamic_shape(inputs[0].shape)
568+
or topi.utils.is_dynamic_shape(inputs[1].shape)
569+
)
570+
):
571+
strategy.add_implementation(
572+
wrap_compute_dense(topi.x86.dense_dynamic),
573+
wrap_topi_schedule(topi.x86.schedule_dense_dynamic),
574+
name="dense_dynamic.x86",
575+
plevel=20,
576+
)
577+
return strategy
578+
579+
need_auto_scheduler_layout = is_auto_scheduler_enabled()
580+
need_meta_schedule_layout = is_meta_schedule_enabled()
581+
if need_auto_scheduler_layout or need_meta_schedule_layout:
578582
strategy.add_implementation(
579583
wrap_compute_dense(
580584
topi.nn.dense,
581-
need_auto_scheduler_layout=is_auto_scheduler_enabled(),
582-
need_meta_schedule_layout=is_meta_schedule_enabled(),
585+
need_auto_scheduler_layout=need_auto_scheduler_layout,
586+
need_meta_schedule_layout=need_meta_schedule_layout,
583587
),
584-
wrap_topi_schedule(topi.generic.schedule_dense),
588+
naive_schedule,
585589
name="dense.generic",
590+
plevel=11,
586591
)
592+
593+
# Fallback to x86 schedules as there is currently no arm_cpu schedule for dense
594+
strategy.add_implementation(
595+
wrap_compute_dense(topi.x86.dense_nopack),
596+
wrap_topi_schedule(topi.x86.schedule_dense_nopack),
597+
name="dense_nopack.x86",
598+
plevel=5,
599+
)
600+
strategy.add_implementation(
601+
wrap_compute_dense(topi.x86.dense_pack),
602+
wrap_topi_schedule(topi.x86.schedule_dense_pack),
603+
name="dense_pack.x86",
604+
plevel=10,
605+
)
606+
587607
return strategy
588608

589609

tests/python/relay/strategy/test_select_implementation.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
# under the License.
1717

1818
""" Tests strategy selection for Relay ops """
19+
1920
import pytest
21+
import numpy as np
22+
2023
import tvm
2124
from tvm import relay
2225
from tvm import te
@@ -52,5 +55,40 @@ def test_concatenate(target, expected_implementation):
5255
assert impl.name == expected_implementation
5356

5457

58+
@pytest.mark.parametrize(
59+
"target,expected_valid_impl,expected_impl",
60+
[("llvm -device=arm_cpu", ["dense_pack.x86", "dense_nopack.x86"], "dense_pack.x86")],
61+
)
62+
def test_dense(target, expected_valid_impl, expected_impl):
63+
target = tvm.target.Target(target)
64+
65+
data_shape = (30, 40)
66+
weight_shape = (30, 40)
67+
dtype = "float32"
68+
69+
out = relay.nn.dense(
70+
relay.var("data", shape=data_shape, dtype=dtype),
71+
relay.var("weight", shape=weight_shape, dtype=dtype),
72+
out_dtype=dtype,
73+
)
74+
out = run_infer_type(out)
75+
76+
with target:
77+
args = [
78+
out.op,
79+
out.attrs,
80+
[te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)],
81+
out.checked_type,
82+
target,
83+
]
84+
valid_impl = relay.backend.te_compiler.get_valid_implementations(*args)
85+
selected_impl, _ = relay.backend.te_compiler.select_implementation(*args, use_autotvm=False)
86+
87+
assert len(valid_impl) == len(expected_valid_impl)
88+
for impl in valid_impl:
89+
assert impl.name in expected_valid_impl
90+
assert selected_impl.name == expected_impl
91+
92+
5593
if __name__ == "__main__":
5694
tvm.testing.main()

0 commit comments

Comments
 (0)