Skip to content

Commit

Permalink
[TVMC] Python Scripting Init Files (#7698)
Browse files Browse the repository at this point in the history
* add to init files for clean tvmc python

* black reformat init.py

* adjust tests to new imports

* black test files

* tell lint ignore defined-builtin error for tvmc compile

* add colon to match lint syntax

* change import so must use tvm.driver.tvmc instead of tvm.tvmc

Co-authored-by: Jocelyn <jocelyn@pop-os.localdomain>
  • Loading branch information
CircleSpin and Jocelyn authored Mar 24, 2021
1 parent 8131364 commit 1fe0abc
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 22 deletions.
4 changes: 4 additions & 0 deletions python/tvm/driver/tvmc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin
"""
TVMC - TVM driver command-line interface
"""

from . import autotuner
from . import compiler
from . import runner
from .frontends import load_model as load
from .compiler import compile_model as compile
from .runner import run_module as run
20 changes: 9 additions & 11 deletions tests/python/driver/tvmc/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_save_dumps(tmpdir_factory):
def verify_compile_tflite_module(model, shape_dict=None):
pytest.importorskip("tflite")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
model, target="llvm", dump_code="ll", alter_layout="NCHW", shape_dict=shape_dict
)

Expand Down Expand Up @@ -74,7 +74,7 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
def test_cross_compile_aarch64_tflite_module(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tflite")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
tflite_mobilenet_v1_1_quant,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'",
dump_code="asm",
Expand All @@ -91,9 +91,7 @@ def test_compile_keras__save_module(keras_resnet50, tmpdir_factory):
# some CI environments wont offer tensorflow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")

graph, lib, params, dumps = tvmc.compiler.compile_model(
keras_resnet50, target="llvm", dump_code="ll"
)
graph, lib, params, dumps = tvmc.compile(keras_resnet50, target="llvm", dump_code="ll")

expected_temp_dir = tmpdir_factory.mktemp("saved_output")
expected_file_name = "saved.tar"
Expand All @@ -111,7 +109,7 @@ def test_cross_compile_aarch64_keras_module(keras_resnet50):
# some CI environments wont offer tensorflow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
keras_resnet50,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr='+neon'",
dump_code="asm",
Expand All @@ -129,7 +127,7 @@ def verify_compile_onnx_module(model, shape_dict=None):
# some CI environments wont offer onnx, so skip in case it is not present
pytest.importorskip("onnx")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
model, target="llvm", dump_code="ll", shape_dict=shape_dict
)

Expand Down Expand Up @@ -158,7 +156,7 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50):
# some CI environments wont offer onnx, so skip in case it is not present
pytest.importorskip("onnx")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
onnx_resnet50,
target="llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
dump_code="asm",
Expand All @@ -176,7 +174,7 @@ def test_cross_compile_aarch64_onnx_module(onnx_resnet50):
def test_compile_opencl(tflite_mobilenet_v1_0_25_128):
pytest.importorskip("tflite")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
tflite_mobilenet_v1_0_25_128,
target="opencl",
target_host="llvm",
Expand All @@ -197,7 +195,7 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128):
def test_compile_tflite_module_with_external_codegen(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tflite")

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
tflite_mobilenet_v1_1_quant, target="ethos-n77, llvm", dump_code="relay"
)

Expand All @@ -221,7 +219,7 @@ def test_compile_check_configs_composite_target(mock_pc, mock_fe, mock_ct, mock_
mock_ct.return_value = mock_codegen
mock_relay.return_value = mock.MagicMock()

graph, lib, params, dumps = tvmc.compiler.compile_model(
graph, lib, params, dumps = tvmc.compile(
"no_file_needed", target="mockcodegen -testopt=value, llvm"
)

Expand Down
18 changes: 8 additions & 10 deletions tests/python/driver/tvmc/test_frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,22 +93,22 @@ def test_load_model__invalid_path__no_language():
pytest.importorskip("tflite")

with pytest.raises(FileNotFoundError):
tvmc.frontends.load_model("not/a/file.tflite")
tvmc.load("not/a/file.tflite")


def test_load_model__invalid_path__with_language():
# some CI environments wont offer onnx, so skip in case it is not present
pytest.importorskip("onnx")

with pytest.raises(FileNotFoundError):
tvmc.frontends.load_model("not/a/file.txt", model_format="onnx")
tvmc.load("not/a/file.txt", model_format="onnx")


def test_load_model__tflite(tflite_mobilenet_v1_1_quant):
# some CI environments wont offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")

mod, params = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant)
mod, params = tvmc.load(tflite_mobilenet_v1_1_quant)
assert type(mod) is IRModule
assert type(params) is dict
# check whether one known value is part of the params dict
Expand Down Expand Up @@ -149,7 +149,7 @@ def test_load_model__pb(pb_mobilenet_v1_1_quant):
# some CI environments wont offer TensorFlow, so skip in case it is not present
pytest.importorskip("tensorflow")

mod, params = tvmc.frontends.load_model(pb_mobilenet_v1_1_quant)
mod, params = tvmc.load(pb_mobilenet_v1_1_quant)
assert type(mod) is IRModule
assert type(params) is dict
# check whether one known value is part of the params dict
Expand All @@ -161,7 +161,7 @@ def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant):
pytest.importorskip("tensorflow")

with pytest.raises(OSError):
tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="keras")
tvmc.load(tflite_mobilenet_v1_1_quant, model_format="keras")


def test_load_model___wrong_language__to_tflite(keras_resnet50):
Expand All @@ -179,7 +179,7 @@ def test_load_model___wrong_language__to_onnx(tflite_mobilenet_v1_1_quant):
from google.protobuf.message import DecodeError

with pytest.raises(DecodeError):
tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant, model_format="onnx")
tvmc.load(tflite_mobilenet_v1_1_quant, model_format="onnx")


@pytest.mark.skip(reason="https://github.com/apache/tvm/issues/7455")
Expand All @@ -188,9 +188,7 @@ def test_load_model__pth(pytorch_resnet18):
pytest.importorskip("torch")
pytest.importorskip("torchvision")

mod, params = tvmc.frontends.load_model(
pytorch_resnet18, shape_dict={"input": [1, 3, 224, 224]}
)
mod, params = tvmc.load(pytorch_resnet18, shape_dict={"input": [1, 3, 224, 224]})
assert type(mod) is IRModule
assert type(params) is dict
# check whether one known value is part of the params dict
Expand All @@ -202,7 +200,7 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant):
pytest.importorskip("torch")

with pytest.raises(RuntimeError) as e:
tvmc.frontends.load_model(
tvmc.load(
tflite_mobilenet_v1_1_quant,
model_format="pytorch",
shape_dict={"input": [1, 3, 224, 224]},
Expand Down
2 changes: 1 addition & 1 deletion tests/python/driver/tvmc/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_run_tflite_module__with_profile__valid_input(
# some CI environments wont offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")

outputs, times = tvmc.runner.run_module(
outputs, times = tvmc.run(
tflite_compiled_module_as_tarfile,
inputs_file=imagenet_cat,
hostname=None,
Expand Down

0 comments on commit 1fe0abc

Please sign in to comment.