Skip to content

Commit aa67a6a

Browse files
authored
[Hexagon] Add USMP tests (#11279)
* Add USMP tests * Address Chris comments * Address Chris comment on assert * trigger
1 parent 2023a20 commit aa67a6a

17 files changed

+229
-59
lines changed

python/tvm/testing/usmp.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
""" This file contains USMP tests harnesses."""
18+
19+
import tvm
20+
21+
22+
def is_tvm_backendallocworkspace_calls(mod: tvm.runtime.module) -> bool:
23+
"""TVMBackendAllocWorkspace call check.
24+
25+
This checker checks whether any c-source produced has TVMBackendAllocWorkspace calls.
26+
If USMP is invoked, none of them should have TVMBAW calls
27+
"""
28+
dso_modules = mod._collect_dso_modules()
29+
for dso_mod in dso_modules:
30+
if dso_mod.type_key not in ["c", "llvm"]:
31+
assert (
32+
False
33+
), 'Current AoT codegen flow should only produce type "c" or "llvm" runtime modules'
34+
35+
source = dso_mod.get_source()
36+
if source.count("TVMBackendAllocWorkspace") != 0:
37+
return True
38+
39+
return False

tests/python/contrib/test_hexagon/conftest.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@
2121
import os
2222
import random
2323
import socket
24-
from typing import Optional
24+
from typing import Optional, Union
2525

2626
import pytest
2727

2828
import tvm
2929
import tvm.rpc.tracker
30-
from tvm.contrib.hexagon.build import HexagonLauncher
30+
from tvm.contrib.hexagon.build import HexagonLauncher, HexagonLauncherRPC
31+
from tvm.contrib.hexagon.session import Session
3132

3233
HEXAGON_TOOLCHAIN = "HEXAGON_TOOLCHAIN"
3334
TVM_TRACKER_HOST = "TVM_TRACKER_HOST"
@@ -84,7 +85,7 @@ def android_serial_number() -> Optional[str]:
8485
previous_port = None
8586

8687

87-
def get_free_port():
88+
def get_free_port() -> int:
8889

8990
global previous_port
9091
if previous_port is None:
@@ -100,7 +101,7 @@ def get_free_port():
100101

101102

102103
@pytest.fixture(scope="session")
103-
def _tracker_info() -> (str, int):
104+
def _tracker_info() -> Union[str, int]:
104105
env_tracker_host = os.getenv(TVM_TRACKER_HOST, default="")
105106
env_tracker_port = os.getenv(TVM_TRACKER_PORT, default="")
106107

@@ -156,7 +157,9 @@ def adb_server_socket() -> str:
156157

157158

158159
@tvm.testing.fixture
159-
def hexagon_launcher(request, android_serial_number, rpc_server_port, adb_server_socket):
160+
def hexagon_launcher(
161+
request, android_serial_number, rpc_server_port, adb_server_socket
162+
) -> HexagonLauncherRPC:
160163
if android_serial_number is None:
161164
yield None
162165
else:
@@ -181,7 +184,7 @@ def hexagon_launcher(request, android_serial_number, rpc_server_port, adb_server
181184

182185

183186
@tvm.testing.fixture
184-
def hexagon_session(hexagon_launcher):
187+
def hexagon_session(hexagon_launcher) -> Session:
185188
if hexagon_launcher is None:
186189
yield None
187190
else:

tests/python/contrib/test_hexagon/test_launcher.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@
2323
from tvm import te
2424
from tvm import relay
2525
from tvm.relay.backend import Executor, Runtime
26+
from tvm.contrib.hexagon.session import Session
2627

2728
from .conftest import requires_hexagon_toolchain
2829

2930

3031
@requires_hexagon_toolchain
31-
def test_add(hexagon_session):
32+
def test_add(hexagon_session: Session):
3233
dtype = "int8"
3334
A = tvm.te.placeholder((2,), dtype=dtype)
3435
B = tvm.te.placeholder((1,), dtype=dtype)
@@ -53,7 +54,7 @@ def test_add(hexagon_session):
5354

5455

5556
@requires_hexagon_toolchain
56-
def test_add_vtcm(hexagon_session):
57+
def test_add_vtcm(hexagon_session: Session):
5758
dtype = "int8"
5859
A = tvm.te.placeholder((2,), dtype=dtype)
5960
B = tvm.te.placeholder((1,), dtype=dtype)
@@ -122,7 +123,7 @@ def test_matmul(self, hexagon_session, M, N, K):
122123

123124

124125
@requires_hexagon_toolchain
125-
def test_graph_executor(hexagon_session):
126+
def test_graph_executor(hexagon_session: Session):
126127
dtype = "float32"
127128
data = relay.var("data", relay.TensorType((1, 64, 64, 3), dtype))
128129
weight = relay.var("weight", relay.TensorType((5, 5, 3, 8), dtype))
@@ -178,7 +179,7 @@ def test_graph_executor(hexagon_session):
178179

179180

180181
@requires_hexagon_toolchain
181-
def test_graph_executor_multiple_conv2d(hexagon_session):
182+
def test_graph_executor_multiple_conv2d(hexagon_session: Session):
182183
dtype = "float32"
183184
input_shape = (1, 8, 8, 3)
184185
w1_shape = (5, 5, 3, 1)
@@ -255,7 +256,7 @@ def test_graph_executor_multiple_conv2d(hexagon_session):
255256

256257

257258
@requires_hexagon_toolchain
258-
def test_aot_executor(hexagon_session, aot_host_target, aot_target):
259+
def test_aot_executor(hexagon_session: Session, aot_host_target, aot_target):
259260
dtype = "float32"
260261
input_shape = (1, 128, 128, 3)
261262
w_shape = (5, 5, 3, 8)
@@ -314,7 +315,7 @@ def test_aot_executor(hexagon_session, aot_host_target, aot_target):
314315

315316

316317
@requires_hexagon_toolchain
317-
def test_aot_executor_multiple_conv2d(hexagon_session, aot_host_target, aot_target):
318+
def test_aot_executor_multiple_conv2d(hexagon_session: Session, aot_host_target, aot_target):
318319
dtype = "float32"
319320
input_shape = (1, 8, 8, 3)
320321
w1_shape = (5, 5, 3, 1)

tests/python/contrib/test_hexagon/test_models.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,17 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import os
1918
import sys
2019
import pytest
2120
import numpy as np
2221

2322
import tvm.testing
24-
from tvm import te
2523
from tvm import relay
2624
from tvm.relay.backend import Executor, Runtime
25+
from tvm.contrib.hexagon.session import Session
2726

2827
from .conftest import requires_hexagon_toolchain
2928

30-
MOBILENET_MODEL = ""
31-
3229

3330
def get_mobilenet():
3431
"""Download and import mobilenet model with ONNX"""
@@ -42,7 +39,7 @@ def get_mobilenet():
4239

4340

4441
@requires_hexagon_toolchain
45-
def test_mobilenet(hexagon_session):
42+
def test_mobilenet(hexagon_session: Session):
4643
dtype = "float32"
4744
onnx_model = get_mobilenet()
4845

@@ -88,8 +85,11 @@ def test_mobilenet(hexagon_session):
8885
tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, atol=1e-5)
8986

9087

88+
enable_usmp = tvm.testing.parameter(False, True)
89+
90+
9191
@requires_hexagon_toolchain
92-
def test_mobilenet_aot(hexagon_session, aot_host_target, aot_target):
92+
def test_mobilenet_aot(hexagon_session: Session, aot_host_target, aot_target, enable_usmp):
9393
if hexagon_session._launcher._serial_number == "simulator":
9494
pytest.skip(msg="Skip on simulator due to long runtime.")
9595

@@ -104,7 +104,8 @@ def test_mobilenet_aot(hexagon_session, aot_host_target, aot_target):
104104
inputs = {input_name: data_in}
105105

106106
target_llvm = tvm.target.Target("llvm")
107-
with tvm.transform.PassContext(opt_level=3):
107+
config = {"tir.usmp.enable": enable_usmp}
108+
with tvm.transform.PassContext(opt_level=3, config=config):
108109
hexagon_lowered = tvm.relay.build(
109110
relay_mod,
110111
tvm.target.Target(aot_target, host=aot_host_target),
@@ -113,6 +114,12 @@ def test_mobilenet_aot(hexagon_session, aot_host_target, aot_target):
113114
params=params,
114115
)
115116

117+
aot_mod = hexagon_session.get_executor_from_factory(hexagon_lowered)
118+
aot_mod.set_input(**inputs)
119+
aot_mod.run()
120+
hexagon_output = aot_mod.get_output(0).numpy()
121+
122+
with tvm.transform.PassContext(opt_level=3):
116123
llvm_lowered = tvm.relay.build(
117124
relay_mod,
118125
tvm.target.Target(target_llvm, host=target_llvm),
@@ -121,11 +128,6 @@ def test_mobilenet_aot(hexagon_session, aot_host_target, aot_target):
121128
params=params,
122129
)
123130

124-
aot_mod = hexagon_session.get_executor_from_factory(hexagon_lowered)
125-
aot_mod.set_input(**inputs)
126-
aot_mod.run()
127-
hexagon_output = aot_mod.get_output(0).numpy()
128-
129131
llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
130132
llvm_graph_mod.set_input(**inputs)
131133
llvm_graph_mod.run()

tests/python/contrib/test_hexagon/test_thread_pool.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import tvm
2222
import tvm.contrib.hexagon
23+
from tvm.contrib.hexagon.session import Session
2324
import tvm.script
2425
import tvm.testing
2526
from tvm import te
@@ -53,7 +54,7 @@ def elemwise_sum_parallel(a: T.handle, b: T.handle, c: T.handle, n: T.int32):
5354
C[vi] = A[vi] + B[vi]
5455

5556

56-
def generate_add_test_data(hexagon_session, n=128 * 1024):
57+
def generate_add_test_data(hexagon_session: Session, n=128 * 1024):
5758
a = tvm.nd.array(np.random.uniform(size=n).astype("float32"), hexagon_session.device)
5859
b = tvm.nd.array(np.random.uniform(size=n).astype("float32"), hexagon_session.device)
5960
c = tvm.nd.array(np.zeros(n, dtype="float32"), hexagon_session.device)
@@ -85,7 +86,7 @@ def test_speedup(hexagon_session, capsys):
8586

8687

8788
@requires_hexagon_toolchain
88-
def test_elemwise_sum_parallel(hexagon_session):
89+
def test_elemwise_sum_parallel(hexagon_session: Session):
8990
if hexagon_session is None:
9091
pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.")
9192

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import sys
19+
import pytest
20+
import numpy as np
21+
22+
import tvm.testing
23+
from tvm import te
24+
from tvm import relay
25+
from tvm.relay.backend import Executor, Runtime
26+
from tvm.contrib.hexagon.session import Session
27+
from tvm.testing.usmp import is_tvm_backendallocworkspace_calls
28+
29+
from .conftest import requires_hexagon_toolchain
30+
31+
usmp_enabled = tvm.testing.parameter(False, True)
32+
33+
34+
@requires_hexagon_toolchain
35+
def test_conv2d(hexagon_session: Session, aot_host_target, aot_target, usmp_enabled):
36+
dtype = "float32"
37+
input_shape = (1, 8, 8, 3)
38+
w1_shape = (5, 5, 3, 1)
39+
w2_shape = (5, 5, 1, 3)
40+
data = relay.var("data", relay.TensorType(input_shape, dtype))
41+
weight1 = relay.var("weight1", relay.TensorType(w1_shape, dtype))
42+
weight2 = relay.var("weight2", relay.TensorType(w2_shape, dtype))
43+
y1 = relay.nn.conv2d(
44+
data,
45+
weight1,
46+
padding=(2, 2),
47+
kernel_size=(5, 5),
48+
data_layout="NHWC",
49+
kernel_layout="HWIO",
50+
out_dtype="float32",
51+
)
52+
y2 = relay.nn.conv2d(
53+
y1,
54+
weight2,
55+
padding=(2, 2),
56+
kernel_size=(5, 5),
57+
data_layout="NHWC",
58+
kernel_layout="HWIO",
59+
out_dtype="float32",
60+
)
61+
f = relay.Function([data, weight1, weight2], y2)
62+
relay_mod = tvm.IRModule.from_expr(f)
63+
relay_mod = relay.transform.InferType()(relay_mod)
64+
65+
weight1_data = np.random.rand(w1_shape[0], w1_shape[1], w1_shape[2], w1_shape[3]).astype(
66+
dtype=dtype
67+
)
68+
weight2_data = np.random.rand(w2_shape[0], w2_shape[1], w2_shape[2], w2_shape[3]).astype(
69+
dtype=dtype
70+
)
71+
input_data = np.random.rand(
72+
input_shape[0], input_shape[1], input_shape[2], input_shape[3]
73+
).astype(dtype=dtype)
74+
75+
params = {"weight1": weight1_data, "weight2": weight2_data}
76+
inputs = {"data": input_data}
77+
78+
with tvm.transform.PassContext(opt_level=3, config={"tir.usmp.enable": usmp_enabled}):
79+
lowered = tvm.relay.build(
80+
relay_mod,
81+
params=params,
82+
target=tvm.target.Target(aot_target, host=aot_host_target),
83+
runtime=Runtime("cpp"),
84+
executor=Executor("aot", {"unpacked-api": False, "interface-api": "packed"}),
85+
)
86+
87+
assert is_tvm_backendallocworkspace_calls(lowered.lib) != usmp_enabled
88+
89+
aot_mod = hexagon_session.get_executor_from_factory(lowered)
90+
aot_mod.set_input(**inputs)
91+
aot_mod.run()
92+
hexagon_output = aot_mod.get_output(0).numpy()
93+
94+
target_llvm = tvm.target.Target("llvm")
95+
with tvm.transform.PassContext(opt_level=3):
96+
llvm_lowered = tvm.relay.build(
97+
relay_mod,
98+
tvm.target.Target(target_llvm, host=target_llvm),
99+
runtime=Runtime("cpp"),
100+
executor=Executor("graph"),
101+
)
102+
103+
llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
104+
llvm_graph_mod.set_input(**params)
105+
llvm_graph_mod.run(**inputs)
106+
expected_output = llvm_graph_mod.get_output(0).numpy()
107+
108+
tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, atol=1e-5)
109+
110+
111+
if __name__ == "__main__":
112+
sys.exit(pytest.main(sys.argv))

tests/python/contrib/test_hexagon/topi/test_batch_matmul.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import tvm
2323
from tvm import topi
2424
from tvm import te
25+
from tvm.contrib.hexagon.session import Session
2526
import tvm.topi.testing
2627
from tvm.topi.utils import get_const_tuple
2728

@@ -46,7 +47,7 @@ class TestMatMulFloat:
4647

4748
# TODO(mehrdadh): add dynamic testing
4849
@requires_hexagon_toolchain
49-
def test_batch_matmul(self, hexagon_session, x_batch, y_batch, M, N, K, dtype):
50+
def test_batch_matmul(self, hexagon_session: Session, x_batch, y_batch, M, N, K, dtype):
5051
if dtype == "float16":
5152
pytest.xfail("float16 is not supported.")
5253

@@ -98,7 +99,7 @@ class TestMatMulInt8:
9899
)
99100

100101
@requires_hexagon_toolchain
101-
def test_batch_matmul_int8(self, hexagon_session, x_batch, y_batch, M, N, K):
102+
def test_batch_matmul_int8(self, hexagon_session: Session, x_batch, y_batch, M, N, K):
102103
dtype = "int8"
103104
out_dtype = "int8"
104105
assert x_batch == y_batch or x_batch == 1 or y_batch == 1

0 commit comments

Comments
 (0)