Skip to content

Commit d25f326

Browse files
Lunderbergylc
authored andcommitted
[UnitTests] Expose TVM pytest helpers as plugin (apache#8532)
* [UnitTests] Expose TVM pytest helpers as plugin Previously, pytest helper utilities such as automatic parametrization of `target`/`dev`, or `tvm.testing.parameter` were only available for tests within the `${TVM_HOME}/tests` directory. This PR extracts the helper utilities into an importable plugin, which can be used in external tests (e.g. one-off debugging). * [UnitTests] Refactor the plugin-specific logic out into plugin.py. * [UnitTests] Moved marker definition out to global variable.
1 parent 50d6369 commit d25f326

File tree

6 files changed

+328
-275
lines changed

6 files changed

+328
-275
lines changed

conftest.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,5 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
import pytest
18-
from pytest import ExitCode
1917

20-
import tvm
21-
import tvm.testing
22-
23-
24-
def pytest_configure(config):
25-
print("enabled targets:", "; ".join(map(lambda x: x[0], tvm.testing.enabled_targets())))
26-
print("pytest marker:", config.option.markexpr)
27-
28-
29-
@pytest.fixture
30-
def dev(target):
31-
return tvm.device(target)
32-
33-
34-
def pytest_generate_tests(metafunc):
35-
tvm.testing._auto_parametrize_target(metafunc)
36-
tvm.testing._parametrize_correlated_parameters(metafunc)
37-
38-
39-
def pytest_collection_modifyitems(config, items):
40-
tvm.testing._count_num_fixture_uses(items)
41-
tvm.testing._remove_global_fixture_definitions(items)
42-
43-
44-
def pytest_sessionfinish(session, exitstatus):
45-
# Don't exit with an error if we select a subset of tests that doesn't
46-
# include anything
47-
if session.config.option.markexpr != "":
48-
if exitstatus == ExitCode.NO_TESTS_COLLECTED:
49-
session.exitstatus = ExitCode.OK
18+
pytest_plugins = ["tvm.testing.plugin"]

pytest.ini

Lines changed: 0 additions & 26 deletions
This file was deleted.

python/tvm/testing/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
1718
# pylint: disable=redefined-builtin, wildcard-import
1819
"""Utility Python functions for TVM testing"""
1920
from .utils import assert_allclose, assert_prim_expr_equal, check_bool_expr_is_true
@@ -23,9 +24,7 @@
2324
from .utils import known_failing_targets, requires_cuda, requires_cudagraph
2425
from .utils import requires_gpu, requires_llvm, requires_rocm, requires_rpc
2526
from .utils import requires_tensorcore, requires_metal, requires_micro, requires_opencl
26-
from .utils import _auto_parametrize_target, _count_num_fixture_uses
27-
from .utils import _remove_global_fixture_definitions, _parametrize_correlated_parameters
28-
from .utils import _pytest_target_params, identity_after, terminate_self
27+
from .utils import identity_after, terminate_self
2928

3029
from ._ffi_api import nop, echo, device_test, run_check_signal, object_use_count
3130
from ._ffi_api import test_wrap_callback, test_raise_error_callback, test_check_eq_callback

python/tvm/testing/plugin.py

Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Pytest plugin for using tvm testing extensions.
19+
20+
TVM provides utilities for testing across all supported targets, and
21+
to more easily parametrize across many inputs. For more information
22+
on usage of these features, see documentation in the tvm.testing
23+
module.
24+
25+
These are enabled by default in all pytests provided by tvm, but may
26+
be useful externally for one-off testing. To enable, add the
27+
following line to the test script, or to the conftest.py in the same
28+
directory as the test scripts.
29+
30+
pytest_plugins = ['tvm.testing.plugin']
31+
32+
"""
33+
34+
import collections
35+
36+
import pytest
37+
import _pytest
38+
39+
import tvm
40+
from tvm.testing import utils
41+
42+
43+
MARKERS = {
44+
"gpu": "mark a test as requiring a gpu",
45+
"tensorcore": "mark a test as requiring a tensorcore",
46+
"cuda": "mark a test as requiring cuda",
47+
"opencl": "mark a test as requiring opencl",
48+
"rocm": "mark a test as requiring rocm",
49+
"vulkan": "mark a test as requiring vulkan",
50+
"metal": "mark a test as requiring metal",
51+
"llvm": "mark a test as requiring llvm",
52+
}
53+
54+
55+
def pytest_configure(config):
56+
"""Runs at pytest configure time, defines marks to be used later."""
57+
58+
for markername, desc in MARKERS.items():
59+
config.addinivalue_line("markers", "{}: {}".format(markername, desc))
60+
61+
print("enabled targets:", "; ".join(map(lambda x: x[0], utils.enabled_targets())))
62+
print("pytest marker:", config.option.markexpr)
63+
64+
65+
def pytest_generate_tests(metafunc):
66+
"""Called once per unit test, modifies/parametrizes it as needed."""
67+
_parametrize_correlated_parameters(metafunc)
68+
_auto_parametrize_target(metafunc)
69+
70+
71+
def pytest_collection_modifyitems(config, items):
72+
"""Called after all tests are chosen, currently used for bookkeeping."""
73+
# pylint: disable=unused-argument
74+
_count_num_fixture_uses(items)
75+
_remove_global_fixture_definitions(items)
76+
77+
78+
@pytest.fixture
79+
def dev(target):
80+
"""Give access to the device to tests that need it."""
81+
return tvm.device(target)
82+
83+
84+
def pytest_sessionfinish(session, exitstatus):
85+
# Don't exit with an error if we select a subset of tests that doesn't
86+
# include anything
87+
if session.config.option.markexpr != "":
88+
if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED:
89+
session.exitstatus = pytest.ExitCode.OK
90+
91+
92+
def _auto_parametrize_target(metafunc):
93+
"""Automatically applies parametrize_targets
94+
95+
Used if a test function uses the "target" fixture, but isn't
96+
already marked with @tvm.testing.parametrize_targets. Intended
97+
for use in the pytest_generate_tests() handler of a conftest.py
98+
file.
99+
100+
"""
101+
102+
def update_parametrize_target_arg(
103+
argnames,
104+
argvalues,
105+
*args,
106+
**kwargs,
107+
):
108+
args = [arg.strip() for arg in argnames.split(",") if arg.strip()]
109+
if "target" in args:
110+
target_i = args.index("target")
111+
112+
new_argvalues = []
113+
for argvalue in argvalues:
114+
115+
if isinstance(argvalue, _pytest.mark.structures.ParameterSet):
116+
# The parametrized value is already a
117+
# pytest.param, so track any marks already
118+
# defined.
119+
param_set = argvalue.values
120+
target = param_set[target_i]
121+
additional_marks = argvalue.marks
122+
elif len(args) == 1:
123+
# Single value parametrization, argvalue is a list of values.
124+
target = argvalue
125+
param_set = (target,)
126+
additional_marks = []
127+
else:
128+
# Multiple correlated parameters, argvalue is a list of tuple of values.
129+
param_set = argvalue
130+
target = param_set[target_i]
131+
additional_marks = []
132+
133+
new_argvalues.append(
134+
pytest.param(
135+
*param_set, marks=_target_to_requirement(target) + additional_marks
136+
)
137+
)
138+
139+
try:
140+
argvalues[:] = new_argvalues
141+
except TypeError as err:
142+
pyfunc = metafunc.definition.function
143+
filename = pyfunc.__code__.co_filename
144+
line_number = pyfunc.__code__.co_firstlineno
145+
msg = (
146+
f"Unit test {metafunc.function.__name__} ({filename}:{line_number}) "
147+
"is parametrized using a tuple of parameters instead of a list "
148+
"of parameters."
149+
)
150+
raise TypeError(msg) from err
151+
152+
if "target" in metafunc.fixturenames:
153+
# Update any explicit use of @pytest.mark.parmaetrize to
154+
# parametrize over targets. This adds the appropriate
155+
# @tvm.testing.requires_* markers for each target.
156+
for mark in metafunc.definition.iter_markers("parametrize"):
157+
update_parametrize_target_arg(*mark.args, **mark.kwargs)
158+
159+
# Check if any explicit parametrizations exist, and apply one
160+
# if they do not. If the function is marked with either
161+
# excluded or known failing targets, use these to determine
162+
# the targets to be used.
163+
parametrized_args = [
164+
arg.strip()
165+
for mark in metafunc.definition.iter_markers("parametrize")
166+
for arg in mark.args[0].split(",")
167+
]
168+
if "target" not in parametrized_args:
169+
excluded_targets = getattr(metafunc.function, "tvm_excluded_targets", [])
170+
xfail_targets = getattr(metafunc.function, "tvm_known_failing_targets", [])
171+
metafunc.parametrize(
172+
"target",
173+
_pytest_target_params(None, excluded_targets, xfail_targets),
174+
scope="session",
175+
)
176+
177+
178+
def _count_num_fixture_uses(items):
179+
# Helper function, counts the number of tests that use each cached
180+
# fixture. Should be called from pytest_collection_modifyitems().
181+
for item in items:
182+
is_skipped = item.get_closest_marker("skip") or any(
183+
mark.args[0] for mark in item.iter_markers("skipif")
184+
)
185+
if is_skipped:
186+
continue
187+
188+
for fixturedefs in item._fixtureinfo.name2fixturedefs.values():
189+
# Only increment the active fixturedef, in a name has been overridden.
190+
fixturedef = fixturedefs[-1]
191+
if hasattr(fixturedef.func, "num_tests_use_this_fixture"):
192+
fixturedef.func.num_tests_use_this_fixture[0] += 1
193+
194+
195+
def _remove_global_fixture_definitions(items):
196+
# Helper function, removes fixture definitions from the global
197+
# variables of the modules they were defined in. This is intended
198+
# to improve readability of error messages by giving a NameError
199+
# if a test function accesses a pytest fixture but doesn't include
200+
# it as an argument. Should be called from
201+
# pytest_collection_modifyitems().
202+
203+
modules = set(item.module for item in items)
204+
205+
for module in modules:
206+
for name in dir(module):
207+
obj = getattr(module, name)
208+
if hasattr(obj, "_pytestfixturefunction") and isinstance(
209+
obj._pytestfixturefunction, _pytest.fixtures.FixtureFunctionMarker
210+
):
211+
delattr(module, name)
212+
213+
214+
def _pytest_target_params(targets, excluded_targets=None, xfail_targets=None):
215+
# Include unrunnable targets here. They get skipped by the
216+
# pytest.mark.skipif in _target_to_requirement(), showing up as
217+
# skipped tests instead of being hidden entirely.
218+
if targets is None:
219+
if excluded_targets is None:
220+
excluded_targets = set()
221+
222+
if xfail_targets is None:
223+
xfail_targets = set()
224+
225+
target_marks = []
226+
for t in utils._get_targets():
227+
# Excluded targets aren't included in the params at all.
228+
if t["target_kind"] not in excluded_targets:
229+
230+
# Known failing targets are included, but are marked
231+
# as expected to fail.
232+
extra_marks = []
233+
if t["target_kind"] in xfail_targets:
234+
extra_marks.append(
235+
pytest.mark.xfail(
236+
reason='Known failing test for target "{}"'.format(t["target_kind"])
237+
)
238+
)
239+
240+
target_marks.append((t["target"], extra_marks))
241+
242+
else:
243+
target_marks = [(target, []) for target in targets]
244+
245+
return [
246+
pytest.param(target, marks=_target_to_requirement(target) + extra_marks)
247+
for target, extra_marks in target_marks
248+
]
249+
250+
251+
def _target_to_requirement(target):
252+
if isinstance(target, str):
253+
target = tvm.target.Target(target)
254+
255+
# mapping from target to decorator
256+
if target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []):
257+
return utils.requires_cudnn()
258+
if target.kind.name == "cuda":
259+
return utils.requires_cuda()
260+
if target.kind.name == "rocm":
261+
return utils.requires_rocm()
262+
if target.kind.name == "vulkan":
263+
return utils.requires_vulkan()
264+
if target.kind.name == "nvptx":
265+
return utils.requires_nvptx()
266+
if target.kind.name == "metal":
267+
return utils.requires_metal()
268+
if target.kind.name == "opencl":
269+
return utils.requires_opencl()
270+
if target.kind.name == "llvm":
271+
return utils.requires_llvm()
272+
return []
273+
274+
275+
def _parametrize_correlated_parameters(metafunc):
276+
parametrize_needed = collections.defaultdict(list)
277+
278+
for name, fixturedefs in metafunc.definition._fixtureinfo.name2fixturedefs.items():
279+
fixturedef = fixturedefs[-1]
280+
if hasattr(fixturedef.func, "parametrize_group") and hasattr(
281+
fixturedef.func, "parametrize_values"
282+
):
283+
group = fixturedef.func.parametrize_group
284+
values = fixturedef.func.parametrize_values
285+
parametrize_needed[group].append((name, values))
286+
287+
for parametrize_group in parametrize_needed.values():
288+
if len(parametrize_group) == 1:
289+
name, values = parametrize_group[0]
290+
metafunc.parametrize(name, values, indirect=True)
291+
else:
292+
names = ",".join(name for name, values in parametrize_group)
293+
value_sets = zip(*[values for name, values in parametrize_group])
294+
metafunc.parametrize(names, value_sets, indirect=True)

0 commit comments

Comments
 (0)