1919from parameterized import parameterized
2020
2121from monai .bundle import ConfigParser
22+ from monai .bundle .scripts import ckpt_export
2223from monai .data import load_net_with_metadata
2324from monai .networks import save_state
24- from tests .utils import command_line_tests , skip_if_windows
25+ from tests .utils import skip_if_windows
2526
2627TEST_CASE_1 = ["" , "" ]
2728
@@ -51,8 +52,6 @@ def setUp(self):
5152 self .parser .export_config_file (config = self .def_args , filepath = self .def_args_file )
5253 self .parser .read_config (self .config_file )
5354 self .net = self .parser .get_parsed_content ("network_def" )
54- self .cmd = ["coverage" , "run" , "-m" , "monai.bundle" , "ckpt_export" , "network_def" , "--filepath" , self .ts_file ]
55- self .cmd += ["--meta_file" , self .meta_file , "--config_file" , f"['{ self .config_file } ','{ self .def_args_file } ']" , "--ckpt_file" ]
5655
5756 def tearDown (self ):
5857 if self .device is not None :
@@ -62,36 +61,47 @@ def tearDown(self):
6261 self .tempdir_obj .cleanup ()
6362
6463 @parameterized .expand ([TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 ])
65- def test_export (self , key_in_ckpt , use_trace ):
66- save_state (src = self .net if key_in_ckpt == "" else {key_in_ckpt : self .net }, path = self .ckpt_file )
67- full_cmd = self .cmd + [self .ckpt_file , "--key_in_ckpt" , key_in_ckpt , "--args_file" , self .def_args_file ]
68- if use_trace == "True" :
69- full_cmd += ["--use_trace" , use_trace , "--input_shape" , "[1, 1, 96, 96, 96]" ]
70- command_line_tests (full_cmd )
71- self .assertTrue (os .path .exists (self .ts_file ))
72-
73- _ , metadata , extra_files = load_net_with_metadata (
74- self .ts_file , more_extra_files = ["inference.json" , "def_args.json" ]
75- )
76- self .assertIn ("schema" , metadata )
77- self .assertIn ("meta_file" , json .loads (extra_files ["def_args.json" ]))
78- self .assertIn ("network_def" , json .loads (extra_files ["inference.json" ]))
79-
80- @parameterized .expand ([TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 ])
81- def test_default_value (self , key_in_ckpt , use_trace ):
64+ def test_ckpt_export_default (self , key_in_ckpt , use_trace ):
8265 ckpt_file = os .path .join (self .tempdir_obj .name , "models/model.pt" )
8366 ts_file = os .path .join (self .tempdir_obj .name , "models/model.ts" )
8467
8568 save_state (src = self .net if key_in_ckpt == "" else {key_in_ckpt : self .net }, path = ckpt_file )
86-
87- # check with default value
88- cmd = ["coverage" , "run" , "-m" , "monai.bundle" , "ckpt_export" , "--key_in_ckpt" , key_in_ckpt ]
89- cmd += ["--config_file" , self .config_file , "--bundle_root" , self .tempdir_obj .name ]
90- if use_trace == "True" :
91- cmd += ["--use_trace" , use_trace , "--input_shape" , "[1, 1, 96, 96, 96]" ]
92- command_line_tests (cmd )
69+ ckpt_export (
70+ net_id = "network_def" ,
71+ filepath = ts_file ,
72+ meta_file = self .meta_file ,
73+ config_file = self .config_file ,
74+ ckpt_file = ckpt_file ,
75+ key_in_ckpt = key_in_ckpt ,
76+ args_file = self .def_args_file ,
77+ use_trace = use_trace ,
78+ input_shape = [1 , 1 , 96 , 96 , 96 ] if use_trace == "True" else None ,
79+ )
9380 self .assertTrue (os .path .exists (ts_file ))
9481
82+ @parameterized .expand ([TEST_CASE_1 , TEST_CASE_2 , TEST_CASE_3 ])
83+ def test_ckpt_export (self , key_in_ckpt , use_trace ):
84+ save_state (src = self .net if key_in_ckpt == "" else {key_in_ckpt : self .net }, path = self .ckpt_file )
85+ ckpt_export (
86+ net_id = "network_def" ,
87+ filepath = self .ts_file ,
88+ meta_file = self .meta_file ,
89+ config_file = [self .config_file , self .def_args_file ],
90+ ckpt_file = self .ckpt_file ,
91+ key_in_ckpt = key_in_ckpt ,
92+ args_file = self .def_args_file ,
93+ use_trace = use_trace ,
94+ input_shape = [1 , 1 , 96 , 96 , 96 ] if use_trace == "True" else None ,
95+ )
96+ self .assertTrue (os .path .exists (self .ts_file ))
97+
98+ _ , metadata , extra_files = load_net_with_metadata (
99+ self .ts_file , more_extra_files = ["inference.json" , "def_args.json" ]
100+ )
101+ self .assertIn ("schema" , metadata )
102+ self .assertIn ("meta_file" , json .loads (extra_files ["def_args.json" ]))
103+ self .assertIn ("network_def" , json .loads (extra_files ["inference.json" ]))
104+
95105
96106if __name__ == "__main__" :
97107 unittest .main ()
0 commit comments