1919from parameterized import parameterized
2020
2121from monai .bundle import ConfigParser
22- from monai .bundle .scripts import ckpt_export
2322from monai .data import load_net_with_metadata
2423from monai .networks import save_state
25- from tests .utils import skip_if_windows
24+ from tests .utils import command_line_tests , skip_if_windows
2625
2726TEST_CASE_1 = ["" , "" ]
2827
@@ -52,6 +51,8 @@ def setUp(self):
5251 self .parser .export_config_file (config = self .def_args , filepath = self .def_args_file )
5352 self .parser .read_config (self .config_file )
5453 self .net = self .parser .get_parsed_content ("network_def" )
54+ self .cmd = ["coverage" , "run" , "-m" , "monai.bundle" , "ckpt_export" , "network_def" , "--filepath" , self .ts_file ]
55+ self .cmd += ["--meta_file" , self .meta_file , "--config_file" , f"['{ self .config_file } ','{ self .def_args_file } ']" , "--ckpt_file" ]
5556
5657 def tearDown (self ):
5758 if self .device is not None :
@@ -61,46 +62,35 @@ def tearDown(self):
6162 self .tempdir_obj .cleanup ()
6263
6364 @parameterized .expand ([TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 ])
64- def test_ckpt_export_default (self , key_in_ckpt , use_trace ):
65+ def test_export (self , key_in_ckpt , use_trace ):
66+ save_state (src = self .net if key_in_ckpt == "" else {key_in_ckpt : self .net }, path = self .ckpt_file )
67+ full_cmd = self .cmd + [self .ckpt_file , "--key_in_ckpt" , key_in_ckpt , "--args_file" , self .def_args_file ]
68+ if use_trace == "True" :
69+ full_cmd += ["--use_trace" , use_trace , "--input_shape" , "[1, 1, 96, 96, 96]" ]
70+ command_line_tests (full_cmd )
71+ self .assertTrue (os .path .exists (self .ts_file ))
72+
73+ _ , metadata , extra_files = load_net_with_metadata (
74+ self .ts_file , more_extra_files = ["inference.json" , "def_args.json" ]
75+ )
76+ self .assertIn ("schema" , metadata )
77+ self .assertIn ("meta_file" , json .loads (extra_files ["def_args.json" ]))
78+ self .assertIn ("network_def" , json .loads (extra_files ["inference.json" ]))
79+
80+ @parameterized .expand ([TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 ])
81+ def test_default_value (self , key_in_ckpt , use_trace ):
6582 ckpt_file = os .path .join (self .tempdir_obj .name , "models/model.pt" )
6683 ts_file = os .path .join (self .tempdir_obj .name , "models/model.ts" )
6784
6885 save_state (src = self .net if key_in_ckpt == "" else {key_in_ckpt : self .net }, path = ckpt_file )
69- ckpt_export (
70- net_id = "network_def" ,
71- filepath = ts_file ,
72- meta_file = self .meta_file ,
73- config_file = self .config_file ,
74- ckpt_file = ckpt_file ,
75- key_in_ckpt = key_in_ckpt ,
76- args_file = self .def_args_file ,
77- use_trace = use_trace ,
78- input_shape = [1 , 1 , 96 , 96 , 96 ] if use_trace == "True" else None ,
79- )
80- self .assertTrue (os .path .exists (ts_file ))
8186
82- @parameterized .expand ([TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 ])
83- def test_ckpt_export (self , key_in_ckpt , use_trace ):
84- save_state (src = self .net if key_in_ckpt == "" else {key_in_ckpt : self .net }, path = self .ckpt_file )
85- ckpt_export (
86- net_id = "network_def" ,
87- filepath = self .ts_file ,
88- meta_file = self .meta_file ,
89- config_file = [self .config_file , self .def_args_file ],
90- ckpt_file = self .ckpt_file ,
91- key_in_ckpt = key_in_ckpt ,
92- args_file = self .def_args_file ,
93- use_trace = use_trace ,
94- input_shape = [1 , 1 , 96 , 96 , 96 ] if use_trace == "True" else None ,
95- )
96- self .assertTrue (os .path .exists (self .ts_file ))
97-
98- _ , metadata , extra_files = load_net_with_metadata (
99- self .ts_file , more_extra_files = ["inference.json" , "def_args.json" ]
100- )
101- self .assertIn ("schema" , metadata )
102- self .assertIn ("meta_file" , json .loads (extra_files ["def_args.json" ]))
103- self .assertIn ("network_def" , json .loads (extra_files ["inference.json" ]))
87+ # check with default value
88+ cmd = ["coverage" , "run" , "-m" , "monai.bundle" , "ckpt_export" , "--key_in_ckpt" , key_in_ckpt ]
89+ cmd += ["--config_file" , self .config_file , "--bundle_root" , self .tempdir_obj .name ]
90+ if use_trace == "True" :
91+ cmd += ["--use_trace" , use_trace , "--input_shape" , "[1, 1, 96, 96, 96]" ]
92+ command_line_tests (cmd )
93+ self .assertTrue (os .path .exists (ts_file ))
10494
10595
10696if __name__ == "__main__" :
0 commit comments