Skip to content

Commit 89f2623

Browse files
committed
[TESTS] Refactor tests to run on either the GPU or CPU.
Much of the time spent in testing is duplicated work between CPU and GPU test nodes. The main reason is that there is no way to control which TVM devices are enabled at runtime, so tests that use LLVM will run on both GPU and CPU nodes. This patch adds an environment variable, TVM_TEST_DEVICES, which controls which TVM devices should be used by tests. Devices not in TVM_TEST_DEVICES can still be used, so tests must be careful to check that the desired device is enabled with `tvm.testing.device_enabled` or by enumerating all devices with `tvm.testing.enabled_devices`. All tests have been retrofitted with these checks. This patch also provides the decorator `@tvm.testing.gpu` to mark a test as possibly using the gpu. Tests that require the gpu can use `@tvm.testing.requires_gpu`. Tests without these flags will not be run on GPU nodes.
1 parent c958bc1 commit 89f2623

File tree

152 files changed

+2098
-1503
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

152 files changed

+2098
-1503
lines changed

apps/extension/tests/test_ext.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import tvm_ext
1818
import tvm
1919
import tvm._ffi.registry
20+
import tvm.testing
2021
from tvm import te
2122
import numpy as np
2223

@@ -32,7 +33,7 @@ def test_ext_dev():
3233
B = te.compute((n,), lambda *i: A(*i) + 1.0, name='B')
3334
s = te.create_schedule(B.op)
3435
def check_llvm():
35-
if not tvm.runtime.enabled("llvm"):
36+
if not tvm.testing.device_enabled("llvm"):
3637
return
3738
f = tvm.build(s, [A, B], "ext_dev", "llvm")
3839
ctx = tvm.ext_dev(0)
@@ -77,7 +78,7 @@ def test_extern_call():
7778
s = te.create_schedule(B.op)
7879

7980
def check_llvm():
80-
if not tvm.runtime.enabled("llvm"):
81+
if not tvm.testing.device_enabled("llvm"):
8182
return
8283
f = tvm.build(s, [A, B], "llvm")
8384
ctx = tvm.cpu(0)

conftest.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import tvm.testing
18+
from pytest import ExitCode
19+
20+
def pytest_configure(config):
21+
print("enabled targets:", "; ".join(map(lambda x: x[0], tvm.testing.enabled_targets())))
22+
print("pytest marker:", config.option.markexpr)
23+
24+
def pytest_sessionfinish(session, exitstatus):
25+
# Don't exit with an error if we select a subset of tests that doesn't
26+
# include anything
27+
if session.config.option.markexpr != '':
28+
if exitstatus == ExitCode.NO_TESTS_COLLECTED:
29+
session.exitstatus = ExitCode.OK

docs/contribute/code_guide.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,20 @@ Python Code Styles
8282
- Stick to language features as in ``python 3.5``
8383

8484

85+
Writing Python Tests
86+
--------------------
87+
We use `pytest <https://docs.pytest.org/en/stable/>`_ for all python testing. ``tests/python`` contains all the tests.
88+
89+
If you want your test to run over a variety of targets, use the :py:func:`tvm.testing.parametrize_targets` decorator. For example:
90+
91+
.. code:: python
92+
93+
@tvm.testing.parametrize_targets
94+
def test_mytest(target, ctx):
95+
...
96+
97+
will run `test_mytest` with `target="llvm"`, `target="cuda"`, and few others. This also ensures that your test is run on the correct hardware by the CI. If you only want to test against a couple targets use `@tvm.testing.parametrize_targets("target_1", "target_2")`. If you want to test on a single target, use the associated decorator from :py:func:`tvm.testing`. For example, CUDA tests use the `@tvm.testing.requires_cuda` decorator.
98+
8599
Handle Integer Constant Expression
86100
----------------------------------
87101
We often need to handle constant integer expressions in TVM. Before we do so, the first question we want to ask is that is it really necessary to get a constant integer. If symbolic expression also works and let the logic flow, we should use symbolic expression as much as possible. So the generated code works for shapes that are not known ahead of time.

python/tvm/relay/testing/config.py renamed to pytest.ini

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,13 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
"""Configuration about tests"""
18-
from __future__ import absolute_import as _abs
19-
20-
import os
21-
import tvm
22-
23-
24-
def ctx_list():
25-
"""Get context list for testcases"""
26-
device_list = os.environ.get("RELAY_TEST_TARGETS", "")
27-
device_list = (device_list.split(",") if device_list
28-
else ["llvm", "cuda"])
29-
device_list = set(device_list)
30-
res = [(device, tvm.context(device, 0)) for device in device_list]
31-
return [x for x in res if x[1].exist]
17+
[pytest]
18+
markers =
19+
gpu: mark a test as requiring a gpu
20+
tensorcore: mark a test as requiring a tensorcore
21+
cuda: mark a test as requiring cuda
22+
opencl: mark a test as requiring opencl
23+
rocm: mark a test as requiring rocm
24+
vulkan: mark a test as requiring vulkan
25+
metal: mark a test as requiring metal
26+
llvm: mark a test as requiring llvm

python/tvm/relay/testing/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import tvm.relay as relay
2626
import tvm.relay.op as op
2727
from tvm.relay import Prelude
28+
from tvm.testing import enabled_targets
2829

2930
from . import mlp
3031
from . import resnet
@@ -41,7 +42,6 @@
4142
from . import temp_op_attr
4243
from . import synthetic
4344

44-
from .config import ctx_list
4545
from .init import create_workload
4646
from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
4747
from .py_converter import to_python, run_as_python
@@ -125,7 +125,7 @@ def check_grad(func,
125125
if test_inputs is None:
126126
test_inputs = inputs
127127

128-
for target, ctx in ctx_list():
128+
for target, ctx in enabled_targets():
129129
intrp = relay.create_executor(ctx=ctx, target=target)
130130

131131
# Get analytic gradients.

0 commit comments

Comments
 (0)