Skip to content

Commit d98f348

Browse files
garciadiasKumoLiu
andauthored
Solves path problem in test_bundle_trt_export.py (#8357)
Fixes #8354 ### Description Fixes path on test that is only run on special conditions. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: R. Garcia-Dias <rafaelagd@gmail.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent af54a17 commit d98f348

File tree

5 files changed

+24
-9
lines changed

5 files changed

+24
-9
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818

1919
MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of the [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
2020
Its ambitions are as follows:
21+
2122
- Developing a community of academic, industrial and clinical researchers collaborating on a common foundation;
2223
- Creating state-of-the-art, end-to-end training workflows for healthcare imaging;
2324
- Providing researchers with the optimized and standardized way to create and evaluate deep learning models.
2425

25-
2626
## Features
27+
2728
> _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New](https://docs.monai.io/en/latest/whatsnew.html) of the milestone releases._
2829
2930
- flexible pre-processing for multi-dimensional medical imaging data;

tests/bundle/test_bundle_trt_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def tearDown(self):
7070
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
7171
@unittest.skipUnless(has_torchtrt and has_tensorrt, "Torch-TensorRT is required for conversion!")
7272
def test_trt_export(self, convert_precision, input_shape, dynamic_batch):
73-
tests_dir = Path(__file__).resolve().parent
73+
tests_dir = Path(__file__).resolve().parents[1]
7474
meta_file = os.path.join(tests_dir, "testing_data", "metadata.json")
7575
config_file = os.path.join(tests_dir, "testing_data", "inference.json")
7676
with tempfile.TemporaryDirectory() as tempdir:

tests/networks/test_convert_to_onnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def test_unet(self, device, use_trace, use_ort):
6464
rtol=rtol,
6565
atol=atol,
6666
)
67-
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))
67+
self.assertTrue(isinstance(onnx_model, onnx.ModelProto))
6868

6969
@parameterized.expand(TESTS_ORT)
7070
@SkipIfBeforePyTorchVersion((1, 12))

tests/test_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@
3030
import warnings
3131
from contextlib import contextmanager
3232
from functools import partial, reduce
33+
from itertools import product
3334
from pathlib import Path
3435
from subprocess import PIPE, Popen
35-
from typing import Callable
36+
from typing import Callable, Literal
3637
from urllib.error import ContentTooShortError, HTTPError
3738

3839
import numpy as np
@@ -862,6 +863,21 @@ def equal_state_dict(st_1, st_2):
862863
if torch.cuda.is_available():
863864
TEST_DEVICES.append([torch.device("cuda")])
864865

866+
867+
def dict_product(trailing=False, format: Literal["list", "dict"] = "dict", **items):
868+
keys = items.keys()
869+
values = items.values()
870+
for pvalues in product(*values):
871+
dict_comb = dict(zip(keys, pvalues))
872+
if format == "dict":
873+
if trailing:
874+
yield [dict_comb] + list(pvalues)
875+
else:
876+
yield dict_comb
877+
else:
878+
yield pvalues
879+
880+
865881
if __name__ == "__main__":
866882
parser = argparse.ArgumentParser(prog="util")
867883
parser.add_argument("-c", "--count", default=2, help="max number of gpus")

tests/transforms/test_gibbs_noise.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,12 @@
2121
from monai.transforms import GibbsNoise
2222
from monai.utils.misc import set_determinism
2323
from monai.utils.module import optional_import
24-
from tests.test_utils import TEST_NDARRAYS, assert_allclose
24+
from tests.test_utils import TEST_NDARRAYS, assert_allclose, dict_product
2525

2626
_, has_torch_fft = optional_import("torch.fft", name="fftshift")
2727

28-
TEST_CASES = []
29-
for shape in ((128, 64), (64, 48, 80)):
30-
for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]:
31-
TEST_CASES.append((shape, input_type))
28+
params = {"shape": ((128, 64), (64, 48, 80)), "input_type": TEST_NDARRAYS if has_torch_fft else [np.array]}
29+
TEST_CASES = list(dict_product(format="list", **params))
3230

3331

3432
class TestGibbsNoise(unittest.TestCase):

0 commit comments

Comments
 (0)