-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TVMC] Split common tvmc test file into more specific files (#9206)
The `test_tvmc_common.py` file was becoming a bit of a mixed bag of tests and as we now want to extend the `Target` processing logic it made sense to split each out into its own file to make it clearer what each does. `test_common.py` has also been renamed before we start using it for all the tests instead.
- Loading branch information
Showing
7 changed files
with
489 additions
and
416 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import pytest | ||
|
||
from tvm.contrib.target.vitis_ai import vitis_ai_available | ||
from tvm.driver import tvmc | ||
|
||
from tvm.driver.tvmc.common import TVMCException | ||
|
||
|
||
def test_config_invalid_format(): | ||
with pytest.raises(TVMCException): | ||
_ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value"]) | ||
|
||
|
||
def test_config_missing_from_tvm(): | ||
with pytest.raises(TVMCException): | ||
_ = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler.missing.value=1234"]) | ||
|
||
|
||
def test_config_unsupported_tvmc_config(): | ||
with pytest.raises(TVMCException): | ||
_ = tvmc.common.parse_configs(["tir.LoopPartition=value"]) | ||
|
||
|
||
def test_config_empty(): | ||
with pytest.raises(TVMCException): | ||
_ = tvmc.common.parse_configs([""]) | ||
|
||
|
||
def test_config_valid_config_bool(): | ||
configs = tvmc.common.parse_configs(["relay.backend.use_auto_scheduler=true"]) | ||
|
||
assert len(configs) == 1 | ||
assert "relay.backend.use_auto_scheduler" in configs.keys() | ||
assert configs["relay.backend.use_auto_scheduler"] == True | ||
|
||
|
||
@pytest.mark.skipif( | ||
not vitis_ai_available(), | ||
reason="--target vitis-ai is not available. TVM built with 'USE_VITIS_AI OFF'", | ||
) | ||
def test_config_valid_multiple_configs(): | ||
configs = tvmc.common.parse_configs( | ||
[ | ||
"relay.backend.use_auto_scheduler=false", | ||
"tir.detect_global_barrier=10", | ||
"relay.ext.vitis_ai.options.build_dir=mystring", | ||
] | ||
) | ||
|
||
assert len(configs) == 3 | ||
assert "relay.backend.use_auto_scheduler" in configs.keys() | ||
assert configs["relay.backend.use_auto_scheduler"] == False | ||
assert "tir.detect_global_barrier" in configs.keys() | ||
assert configs["tir.detect_global_barrier"] == 10 | ||
assert "relay.ext.vitis_ai.options.build_dir" in configs.keys() | ||
assert configs["relay.ext.vitis_ai.options.build_dir"] == "mystring" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import argparse | ||
|
||
import pytest | ||
|
||
from tvm.driver import tvmc | ||
|
||
|
||
def test_shape_parser(): | ||
# Check that a valid input is parsed correctly | ||
shape_string = "input:[10,10,10]" | ||
shape_dict = tvmc.common.parse_shape_string(shape_string) | ||
assert shape_dict == {"input": [10, 10, 10]} | ||
|
||
|
||
def test_alternate_syntax(): | ||
shape_string = "input:0:[10,10,10] input2:[20,20,20,20]" | ||
shape_dict = tvmc.common.parse_shape_string(shape_string) | ||
assert shape_dict == {"input:0": [10, 10, 10], "input2": [20, 20, 20, 20]} | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"shape_string", | ||
[ | ||
"input:[10,10,10] input2:[20,20,20,20]", | ||
"input: [10, 10, 10] input2: [20, 20, 20, 20]", | ||
"input:[10,10,10],input2:[20,20,20,20]", | ||
], | ||
) | ||
def test_alternate_syntaxes(shape_string): | ||
shape_dict = tvmc.common.parse_shape_string(shape_string) | ||
assert shape_dict == {"input": [10, 10, 10], "input2": [20, 20, 20, 20]} | ||
|
||
|
||
def test_negative_dimensions(): | ||
# Check that negative dimensions parse to Any correctly. | ||
shape_string = "input:[-1,3,224,224]" | ||
shape_dict = tvmc.common.parse_shape_string(shape_string) | ||
# Convert to strings to allow comparison with Any. | ||
assert str(shape_dict) == "{'input': [?, 3, 224, 224]}" | ||
|
||
|
||
def test_multiple_valid_gpu_inputs(): | ||
# Check that multiple valid gpu inputs are parsed correctly. | ||
shape_string = "gpu_0/data_0:[1, -1,224,224] gpu_1/data_1:[7, 7]" | ||
shape_dict = tvmc.common.parse_shape_string(shape_string) | ||
expected = "{'gpu_0/data_0': [1, ?, 224, 224], 'gpu_1/data_1': [7, 7]}" | ||
assert str(shape_dict) == expected | ||
|
||
|
||
def test_invalid_pattern(): | ||
shape_string = "input:[a,10]" | ||
with pytest.raises(argparse.ArgumentTypeError): | ||
tvmc.common.parse_shape_string(shape_string) | ||
|
||
|
||
def test_invalid_separators(): | ||
shape_string = "input:5,10 input2:10,10" | ||
with pytest.raises(argparse.ArgumentTypeError): | ||
tvmc.common.parse_shape_string(shape_string) | ||
|
||
|
||
def test_invalid_colon(): | ||
shape_string = "gpu_0/data_0:5,10 :test:10,10" | ||
with pytest.raises(argparse.ArgumentTypeError): | ||
tvmc.common.parse_shape_string(shape_string) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"shape_string", | ||
[ | ||
"gpu_0/data_0:5,10 /:10,10", | ||
"gpu_0/data_0:5,10 data/:10,10", | ||
"gpu_0/data_0:5,10 /data:10,10", | ||
"gpu_0/invalid/data_0:5,10 data_1:10,10", | ||
], | ||
) | ||
def test_invalid_slashes(shape_string): | ||
with pytest.raises(argparse.ArgumentTypeError): | ||
tvmc.common.parse_shape_string(shape_string) |
Oops, something went wrong.