Skip to content

Commit

Permalink
[UnitTest] Bugfix, applying requires_* markers to parametrized targets.
Browse files Browse the repository at this point in the history
Initial implementation did work correctly with
@tvm.testing.parametrize_targets.

Also, went through all cases where "target" is used to parametrize on
something other than a target string, and renamed.
  • Loading branch information
Lunderberg committed Jul 28, 2021
1 parent 844546e commit 8ee0702
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 28 deletions.
67 changes: 45 additions & 22 deletions python/tvm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,33 +794,56 @@ def _auto_parametrize_target(metafunc):
file.
"""
if "target" in metafunc.fixturenames:
for mark in metafunc.definition.iter_markers("parametrize"):
args = [arg.strip() for arg in mark.args[0].split(",") if arg.strip()]

if "target" in args:
# Test explicitly called @pytest.mark.parametrize to
# set the targets. Make sure that it gets the
# appropriate additional marks for that target.
param_sets = mark.args[1]

if len(args) == 1:
targets = param_sets
else:
target_i = args.index("target")
targets = [param_set[target_i] for param_set in param_sets]
def update_parametrize_target_arg(
argnames,
argvalues,
*args,
**kwargs,
):
args = [arg.strip() for arg in argnames.split(",") if arg.strip()]
if "target" in args:
if len(args) == 1:
targets = argvalues
param_sets = [(target,) for target in targets]
else:
target_i = args.index("target")
targets = [param_set[target_i] for param_set in argvalues]
param_sets = argvalues

new_param_sets = [
try:
argvalues[:] = [
pytest.param(*param_set, marks=_target_to_requirement(target))
for target, param_set in zip(targets, param_sets)
]
param_sets[:] = new_param_sets
break
except TypeError as e:
pyfunc = metafunc.definition.function
filename = pyfunc.__code__.co_filename
line_number = pyfunc.__code__.co_firstlineno
msg = (
f"Unit test {metafunc.function.__name__} ({filename}:{line_number}) "
"is parametrized using a tuple of parameters instead of a list "
"of parameters."
)
raise TypeError(msg) from e

else:
# No existing parametrization, time to add one, checking
# if the function is marked with either excluded or known
# failing targets.
if "target" in metafunc.fixturenames:
# Update any explicit use of @pytest.mark.parmaetrize to
# parametrize over targets. This adds the appropriate
# @tvm.testing.requires_* markers for each target.
for mark in metafunc.definition.iter_markers("parametrize"):
update_parametrize_target_arg(*mark.args, **mark.kwargs)

# Check if any explicit parametrizations exist, and apply one
# if they do not. If the function is marked with either
# excluded or known failing targets, use these to determine
# the targets to be used.
parametrized_args = [
arg.strip()
for mark in metafunc.definition.iter_markers("parametrize")
for arg in mark.args[0].split(",")
]
if "target" not in parametrized_args:
excluded_targets = getattr(metafunc.function, "tvm_excluded_targets", [])
xfail_targets = getattr(metafunc.function, "tvm_known_failing_targets", [])
metafunc.parametrize(
Expand Down Expand Up @@ -873,7 +896,7 @@ def parametrize_targets(*args):
if len(args) == 1 and callable(args[0]):
return args

return pytest.mark.parametrize("target", args, scope="session")
return pytest.mark.parametrize("target", list(args), scope="session")


def exclude_targets(*args):
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_micro_model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ def validate_graph_json(extract_dir, factory):

@tvm.testing.requires_micro
@pytest.mark.parametrize(
"target",
"exe_target",
[
("graph", tvm.target.target.micro("host")),
("aot", tvm.target.target.micro("host", options="-executor=aot")),
],
)
def test_export_model_library_format_c(target):
executor, _target = target
def test_export_model_library_format_c(exe_target):
executor, _target = exe_target
with utils.TempDirectory.set_keep_for_debug(True):
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
relay_mod = tvm.parser.fromtext(
Expand Down Expand Up @@ -264,14 +264,14 @@ def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), float32], %c : Tensor[

@tvm.testing.requires_micro
@pytest.mark.parametrize(
"target",
"exe_target",
[
("graph", tvm.target.target.micro("host")),
("aot", tvm.target.target.micro("host", options="-executor=aot")),
],
)
def test_export_model_library_format_workspace(target):
executor, _target = target
def test_export_model_library_format_workspace(exe_target):
executor, _target = exe_target
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
relay_mod = tvm.parser.fromtext(
"""
Expand Down
30 changes: 30 additions & 0 deletions tests/python/unittest/test_tvm_testing_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,5 +180,35 @@ def test_num_uses_cached(self):
assert self.num_uses_broken_cached_fixture == 0


class TestAutomaticMarks:
@staticmethod
def check_marks(request, target):
parameter = tvm.testing._pytest_target_params([target])[0]
required_marks = [decorator.mark for decorator in parameter.marks]
applied_marks = list(request.node.iter_markers())

for required_mark in required_marks:
assert required_mark in applied_marks

def test_automatic_fixture(self, request, target):
self.check_marks(request, target)

@tvm.testing.parametrize_targets
def test_bare_parametrize(self, request, target):
self.check_marks(request, target)

@tvm.testing.parametrize_targets("llvm", "cuda", "vulkan")
def test_explicit_parametrize(self, request, target):
self.check_marks(request, target)

@pytest.mark.parametrize("target", ["llvm", "cuda", "vulkan"])
def test_pytest_mark(self, request, target):
self.check_marks(request, target)

@pytest.mark.parametrize("target,other_param", [("llvm", 0), ("cuda", 1), ("vulkan", 2)])
def test_pytest_mark_covariant(self, request, target, other_param):
self.check_marks(request, target)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))

0 comments on commit 8ee0702

Please sign in to comment.