Skip to content

Commit d4b01e6

Browse files
committed
Directly tests export_ckpt function instead of using command_line_tests
Performance Before: 94.81s call tests/test_bundle_ckpt_export.py::TestCKPTExport::test_default_value_1_model 20.95s call tests/test_bundle_ckpt_export.py::TestCKPTExport::test_default_value_0_ 15.26s call tests/test_bundle_ckpt_export.py::TestCKPTExport::test_export_2_model 14.86s call tests/test_bundle_ckpt_export.py::TestCKPTExport::test_default_value_2_model 14.55s call tests/test_bundle_ckpt_export.py::TestCKPTExport::test_export_1_model 14.28s call tests/test_bundle_ckpt_export.py::TestCKPTExport::test_export_0_ Performance after: 1.62s call tests/test_bundle_ckpt_export.py::TestCKPTExport::test_ckpt_export_2_model 1.25s call tests/test_bundle_ckpt_export.py::TestCKPTExport::test_ckpt_export_default_2_model 0.64s call tests/test_bundle_ckpt_export.py::TestCKPTExport::test_ckpt_export_0_ 0.57s call tests/test_bundle_ckpt_export.py::TestCKPTExport::test_ckpt_export_1_model 0.57s call tests/test_bundle_ckpt_export.py::TestCKPTExport::test_ckpt_export_default_1_model 0.55s call tests/test_bundle_ckpt_export.py::TestCKPTExport::test_ckpt_export_default_0_ 0.01s setup tests/test_bundle_ckpt_export.py::TestCKPTExport::test_ckpt_export_0_
1 parent e8d3e19 commit d4b01e6

File tree

1 file changed

+37
-27
lines changed

1 file changed

+37
-27
lines changed

tests/test_bundle_ckpt_export.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919
from parameterized import parameterized
2020

2121
from monai.bundle import ConfigParser
22+
from monai.bundle.scripts import ckpt_export
2223
from monai.data import load_net_with_metadata
2324
from monai.networks import save_state
24-
from tests.utils import command_line_tests, skip_if_windows
25+
from tests.utils import skip_if_windows
2526

2627
TEST_CASE_1 = ["", ""]
2728

@@ -51,8 +52,6 @@ def setUp(self):
5152
self.parser.export_config_file(config=self.def_args, filepath=self.def_args_file)
5253
self.parser.read_config(self.config_file)
5354
self.net = self.parser.get_parsed_content("network_def")
54-
self.cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", self.ts_file]
55-
self.cmd += ["--meta_file", self.meta_file, "--config_file", f"['{self.config_file}','{self.def_args_file}']", "--ckpt_file"]
5655

5756
def tearDown(self):
5857
if self.device is not None:
@@ -62,36 +61,47 @@ def tearDown(self):
6261
self.tempdir_obj.cleanup()
6362

6463
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
65-
def test_export(self, key_in_ckpt, use_trace):
66-
save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=self.ckpt_file)
67-
full_cmd = self.cmd + [self.ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", self.def_args_file]
68-
if use_trace == "True":
69-
full_cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
70-
command_line_tests(full_cmd)
71-
self.assertTrue(os.path.exists(self.ts_file))
72-
73-
_, metadata, extra_files = load_net_with_metadata(
74-
self.ts_file, more_extra_files=["inference.json", "def_args.json"]
75-
)
76-
self.assertIn("schema", metadata)
77-
self.assertIn("meta_file", json.loads(extra_files["def_args.json"]))
78-
self.assertIn("network_def", json.loads(extra_files["inference.json"]))
79-
80-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
81-
def test_default_value(self, key_in_ckpt, use_trace):
64+
def test_ckpt_export_default(self, key_in_ckpt, use_trace):
8265
ckpt_file = os.path.join(self.tempdir_obj.name, "models/model.pt")
8366
ts_file = os.path.join(self.tempdir_obj.name, "models/model.ts")
8467

8568
save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=ckpt_file)
86-
87-
# check with default value
88-
cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "--key_in_ckpt", key_in_ckpt]
89-
cmd += ["--config_file", self.config_file, "--bundle_root", self.tempdir_obj.name]
90-
if use_trace == "True":
91-
cmd += ["--use_trace", use_trace, "--input_shape", "[1, 1, 96, 96, 96]"]
92-
command_line_tests(cmd)
69+
ckpt_export(
70+
net_id="network_def",
71+
filepath=ts_file,
72+
meta_file=self.meta_file,
73+
config_file=self.config_file,
74+
ckpt_file=ckpt_file,
75+
key_in_ckpt=key_in_ckpt,
76+
args_file=self.def_args_file,
77+
use_trace=use_trace,
78+
input_shape=[1, 1, 96, 96, 96] if use_trace == "True" else None,
79+
)
9380
self.assertTrue(os.path.exists(ts_file))
9481

82+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
83+
def test_ckpt_export(self, key_in_ckpt, use_trace):
84+
save_state(src=self.net if key_in_ckpt == "" else {key_in_ckpt: self.net}, path=self.ckpt_file)
85+
ckpt_export(
86+
net_id="network_def",
87+
filepath=self.ts_file,
88+
meta_file=self.meta_file,
89+
config_file=[self.config_file, self.def_args_file],
90+
ckpt_file=self.ckpt_file,
91+
key_in_ckpt=key_in_ckpt,
92+
args_file=self.def_args_file,
93+
use_trace=use_trace,
94+
input_shape=[1, 1, 96, 96, 96] if use_trace == "True" else None,
95+
)
96+
self.assertTrue(os.path.exists(self.ts_file))
97+
98+
_, metadata, extra_files = load_net_with_metadata(
99+
self.ts_file, more_extra_files=["inference.json", "def_args.json"]
100+
)
101+
self.assertIn("schema", metadata)
102+
self.assertIn("meta_file", json.loads(extra_files["def_args.json"]))
103+
self.assertIn("network_def", json.loads(extra_files["inference.json"]))
104+
95105

96106
if __name__ == "__main__":
97107
unittest.main()

0 commit comments

Comments
 (0)