[TVMC] Split common tvmc test file into more specific files (#9206)
The `test_tvmc_common.py` file was becoming a bit of a mixed bag of
tests, and as we now want to extend the `Target` processing logic it made
sense to split each area out into its own file to make it clearer what
each one does.

`test_common.py` has also been renamed before we start using it for all of
the tests instead.
Mousius authored Oct 12, 2021
1 parent f4922bc commit 9f27be6
Showing 7 changed files with 489 additions and 416 deletions.
128 changes: 126 additions & 2 deletions tests/python/driver/tvmc/test_frontends.py
@@ -14,11 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import tarfile

import pytest

import tvm
from tvm.ir.module import IRModule

from tvm.driver import tvmc
@@ -229,3 +228,128 @@ def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant):
model_format="pytorch",
shape_dict={"input": [1, 3, 224, 224]},
)


def test_compile_tflite_module_nhwc_to_nchw(tflite_mobilenet_v1_1_quant):
    # some CI environments won't offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")

tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant)
before = tvmc_model.mod

expected_layout = "NCHW"
after = tvmc.common.convert_graph_layout(before, expected_layout)

layout_transform_calls = []

def _is_layout_transform(node):
if isinstance(node, tvm.relay.expr.Call):
layout_transform_calls.append(
node.op.name == "layout_transform"
and node.attrs.src_layout == "NHWC"
and node.attrs.dst_layout == "NCHW"
)

tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert any(layout_transform_calls), "Expected 'layout_transform NHWC->NCHW' not found"


def test_compile_onnx_module_nchw_to_nhwc(onnx_resnet50):
    # some CI environments won't offer ONNX, so skip in case it is not present
pytest.importorskip("onnx")

tvmc_model = tvmc.frontends.load_model(onnx_resnet50)
before = tvmc_model.mod

expected_layout = "NHWC"
after = tvmc.common.convert_graph_layout(before, expected_layout)

layout_transform_calls = []

def _is_layout_transform(node):
if isinstance(node, tvm.relay.expr.Call):
layout_transform_calls.append(
node.op.name == "layout_transform"
and node.attrs.src_layout == "NCHW"
and node.attrs.dst_layout == "NHWC"
)

tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

    assert any(layout_transform_calls), "Expected 'layout_transform NCHW->NHWC' not found"


def test_compile_paddle_module_nchw_to_nhwc(paddle_resnet50):
    # some CI environments won't offer Paddle, so skip in case it is not present
pytest.importorskip("paddle")

tvmc_model = tvmc.frontends.load_model(paddle_resnet50, "paddle")
before = tvmc_model.mod

expected_layout = "NHWC"
after = tvmc.common.convert_graph_layout(before, expected_layout)

layout_transform_calls = []

def _is_layout_transform(node):
if isinstance(node, tvm.relay.expr.Call):
layout_transform_calls.append(
node.op.name == "layout_transform"
and node.attrs.src_layout == "NCHW"
and node.attrs.dst_layout == "NHWC"
)

tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

    assert any(layout_transform_calls), "Expected 'layout_transform NCHW->NHWC' not found"


def test_compile_tflite_module__same_layout__nhwc_to_nhwc(tflite_mobilenet_v1_1_quant):
    # some CI environments won't offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")

tvmc_model = tvmc.frontends.load_model(tflite_mobilenet_v1_1_quant)
before = tvmc_model.mod

expected_layout = "NHWC"
after = tvmc.common.convert_graph_layout(before, expected_layout)

layout_transform_calls = []

def _is_layout_transform(node):
if isinstance(node, tvm.relay.expr.Call):
layout_transform_calls.append(
node.op.name == "layout_transform"
and node.attrs.src_layout == "NHWC"
and node.attrs.dst_layout == "NHWC"
)

tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert not any(layout_transform_calls), "Unexpected 'layout_transform' call"


def test_compile_onnx_module__same_layout__nchw_to_nchw(onnx_resnet50):
    # some CI environments won't offer ONNX, so skip in case it is not present
pytest.importorskip("onnx")

tvmc_model = tvmc.frontends.load_model(onnx_resnet50)
before = tvmc_model.mod

expected_layout = "NCHW"
after = tvmc.common.convert_graph_layout(before, expected_layout)

layout_transform_calls = []

def _is_layout_transform(node):
if isinstance(node, tvm.relay.expr.Call):
layout_transform_calls.append(
node.op.name == "layout_transform"
and node.attrs.src_layout == "NCHW"
and node.attrs.dst_layout == "NCHW"
)

tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert not any(layout_transform_calls), "Unexpected 'layout_transform' call"
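Note: the layout tests above exercise `tvmc.common.convert_graph_layout` on modules loaded from full model files. As a minimal, self-contained sketch of the same call on a hand-built Relay module (the shapes, layouts and variable names here are illustrative assumptions, not part of this commit):

import tvm
from tvm import relay
from tvm.driver import tvmc

# A tiny NHWC conv2d module; shapes are arbitrary and chosen for illustration only.
data = relay.var("data", shape=(1, 32, 32, 3), dtype="float32")
weight = relay.var("weight", shape=(3, 3, 3, 8), dtype="float32")
conv = relay.nn.conv2d(
    data,
    weight,
    kernel_size=(3, 3),
    channels=8,
    data_layout="NHWC",
    kernel_layout="HWIO",
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# Request NCHW; the conversion is expected to insert layout_transform calls,
# which is what the tests above detect with relay.analysis.post_order_visit.
converted = tvmc.common.convert_graph_layout(mod, "NCHW")
print(converted)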
73 changes: 73 additions & 0 deletions tests/python/driver/tvmc/test_pass_config.py
@@ -0,0 +1,73 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest

from tvm.contrib.target.vitis_ai import vitis_ai_available
from tvm.driver import tvmc

from tvm.driver.tvmc.common import TVMCException


def test_config_invalid_format():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"])


def test_config_missing_from_tvm():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"])


def test_config_unsupported_tvmc_config():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs(["tir.LoopPartition=value"])


def test_config_empty():
with pytest.raises(TVMCException):
_ = tvmc.common.parse_configs([""])


def test_config_valid_config_bool():
configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"])

assert len(configs) == 1
assert "relay.backend.use_auto_scheduler" in configs.keys()
assert configs["relay.backend.use_auto_scheduler"] == True


@pytest.mark.skipif(
not vitis_ai_available(),
reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'",
)
def test_config_valid_multiple_configs():
configs = tvmc.common.parse_configs(
[
"relay.backend.use_auto_scheduler=false",
"tir.detect_global_barrier=10",
"relay.ext.vitis_ai.options.build_dir=mystring",
]
)

assert len(configs) == 3
assert "relay.backend.use_auto_scheduler" in configs.keys()
assert configs["relay.backend.use_auto_scheduler"] == False
assert "tir.detect_global_barrier" in configs.keys()
assert configs["tir.detect_global_barrier"] == 10
assert "relay.ext.vitis_ai.options.build_dir" in configs.keys()
assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring"
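As a rough usage sketch (an assumption for context, not part of this commit), the dictionary returned by `tvmc.common.parse_configs` can be passed as the `config` argument of a `PassContext`, which is how these options ultimately take effect during compilation; the opt_level and option string below are illustrative.

from tvm import transform
from tvm.driver import tvmc

# Parse CLI-style "name=value" strings into PassContext config options.
configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"])

# Apply them while building; relay.build(...) would normally run inside this block.
with transform.PassContext(opt_level=3, config=configs):
    pass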
@@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse
import pytest
from tvm.driver import tvmc


def test_common_parse_pass_list_str():
def test_parse_pass_list_str():
assert [""] == tvmc.common.parse_pass_list_str("")
assert ["FoldScaleAxis", "FuseOps"] == tvmc.common.parse_pass_list_str("FoldScaleAxis,FuseOps")

96 changes: 96 additions & 0 deletions tests/python/driver/tvmc/test_shape_parser.py
@@ -0,0 +1,96 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import argparse

import pytest

from tvm.driver import tvmc


def test_shape_parser():
# Check that a valid input is parsed correctly
shape_string = "input:[10,10,10]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
assert shape_dict == {"input": [10, 10, 10]}


def test_alternate_syntax():
shape_string = "input:0:[10,10,10] input2:[20,20,20,20]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
assert shape_dict == {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]}


@pytest.mark.parametrize(
"shape_string",
[
"input:[10,10,10] input2:[20,20,20,20]",
"input: [10, 10, 10] input2: [20, 20, 20, 20]",
"input:[10,10,10],input2:[20,20,20,20]",
],
)
def test_alternate_syntaxes(shape_string):
shape_dict = tvmc.common.parse_shape_string(shape_string)
assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]}


def test_negative_dimensions():
# Check that negative dimensions parse to Any correctly.
shape_string = "input:[-1,3,224,224]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
# Convert to strings to allow comparison with Any.
assert str(shape_dict) == "{'input': [?, 3, 224, 224]}"


def test_multiple_valid_gpu_inputs():
# Check that multiple valid gpu inputs are parsed correctly.
shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]"
shape_dict = tvmc.common.parse_shape_string(shape_string)
expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}"
assert str(shape_dict) == expected


def test_invalid_pattern():
shape_string = "input:[a,10]"
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)


def test_invalid_separators():
shape_string = "input:5,10 input2:10,10"
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)


def test_invalid_colon():
shape_string = "gpu_0/data_0:5,10 :test:10,10"
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)


@pytest.mark.parametrize(
"shape_string",
[
"gpu_0/data_0:5,10 /:10,10",
"gpu_0/data_0:5,10 data/:10,10",
"gpu_0/data_0:5,10 /data:10,10",
"gpu_0/invalid/data_0:5,10 data_1:10,10",
],
)
def test_invalid_slashes(shape_string):
with pytest.raises(argparse.ArgumentTypeError):
tvmc.common.parse_shape_string(shape_string)
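For context (an assumption rather than part of this change), the dictionary produced by `tvmc.common.parse_shape_string` is the same `shape_dict` accepted by `tvmc.frontends.load_model`; the model path below is a hypothetical placeholder.

from tvm.driver import tvmc

# Convert a CLI-style shape string into a shape dictionary.
shape_dict = tvmc.common.parse_shape_string("input:[1,3,224,224]")

# "model.onnx" is a placeholder path; any model format supported by the frontends would do.
tvmc_model = tvmc.frontends.load_model("model.onnx", shape_dict=shape_dict)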