Skip to content

Commit 32bb745

Browse files
myronYu0610
authored andcommitted
autorunner params from config (Project-MONAI#7175)
allows setting AutoRunner params from config allows specifying number of folds in config --------- Signed-off-by: myron <amyronenko@nvidia.com> Signed-off-by: Yu0610 <612410030@alum.ccu.edu.tw>
1 parent e07b1df commit 32bb745

File tree

2 files changed

+54
-28
lines changed

2 files changed

+54
-28
lines changed

monai/apps/auto3dseg/auto_runner.py

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -214,22 +214,11 @@ def __init__(
214214
mlflow_tracking_uri: str | None = None,
215215
**kwargs: Any,
216216
):
217-
logger.info(f"AutoRunner using work directory {work_dir}")
218-
os.makedirs(work_dir, exist_ok=True)
219-
220-
self.work_dir = os.path.abspath(work_dir)
221-
self.data_src_cfg = dict()
222-
self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml")
223-
self.algos = algos
224-
self.templates_path_or_url = templates_path_or_url
225-
self.allow_skip = allow_skip
226-
self.mlflow_tracking_uri = mlflow_tracking_uri
227-
self.kwargs = deepcopy(kwargs)
228-
229-
if input is None and os.path.isfile(self.data_src_cfg_name):
230-
input = self.data_src_cfg_name
217+
if input is None and os.path.isfile(os.path.join(os.path.abspath(work_dir), "input.yaml")):
218+
input = os.path.join(os.path.abspath(work_dir), "input.yaml")
231219
logger.info(f"Input config is not provided, using the default {input}")
232220

221+
self.data_src_cfg = dict()
233222
if isinstance(input, dict):
234223
self.data_src_cfg = input
235224
elif isinstance(input, str) and os.path.isfile(input):
@@ -238,6 +227,51 @@ def __init__(
238227
else:
239228
raise ValueError(f"{input} is not a valid file or dict")
240229

230+
if "work_dir" in self.data_src_cfg: # override from config
231+
work_dir = self.data_src_cfg["work_dir"]
232+
self.work_dir = os.path.abspath(work_dir)
233+
234+
logger.info(f"AutoRunner using work directory {self.work_dir}")
235+
os.makedirs(self.work_dir, exist_ok=True)
236+
self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml")
237+
238+
self.algos = algos
239+
self.templates_path_or_url = templates_path_or_url
240+
self.allow_skip = allow_skip
241+
242+
# cache.yaml
243+
self.not_use_cache = not_use_cache
244+
self.cache_filename = os.path.join(self.work_dir, "cache.yaml")
245+
self.cache = self.read_cache()
246+
self.export_cache()
247+
248+
# determine if we need to analyze, algo_gen or train from cache, unless manually provided
249+
self.analyze = not self.cache["analyze"] if analyze is None else analyze
250+
self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen
251+
self.train = train
252+
self.ensemble = ensemble # last step, no need to check
253+
self.hpo = hpo and has_nni
254+
self.hpo_backend = hpo_backend
255+
self.mlflow_tracking_uri = mlflow_tracking_uri
256+
self.kwargs = deepcopy(kwargs)
257+
258+
# parse input config for AutoRunner param overrides
259+
for param in [
260+
"analyze",
261+
"algo_gen",
262+
"train",
263+
"hpo",
264+
"ensemble",
265+
"not_use_cache",
266+
"allow_skip",
267+
]: # override from config
268+
if param in self.data_src_cfg and isinstance(self.data_src_cfg[param], bool):
269+
setattr(self, param, self.data_src_cfg[param]) # e.g. self.analyze = self.data_src_cfg["analyze"]
270+
271+
for param in ["algos", "hpo_backend", "templates_path_or_url", "mlflow_tracking_uri"]: # override from config
272+
if param in self.data_src_cfg:
273+
setattr(self, param, self.data_src_cfg[param]) # e.g. self.algos = self.data_src_cfg["algos"]
274+
241275
missing_keys = {"dataroot", "datalist", "modality"}.difference(self.data_src_cfg.keys())
242276
if len(missing_keys) > 0:
243277
raise ValueError(f"Config keys are missing {missing_keys}")
@@ -256,6 +290,8 @@ def __init__(
256290

257291
# inspect and update folds
258292
num_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename)
293+
if "num_fold" in self.data_src_cfg:
294+
num_fold = int(self.data_src_cfg["num_fold"]) # override from config
259295

260296
self.data_src_cfg["datalist"] = datalist_filename # update path to a version in work_dir and save user input
261297
ConfigParser.export_config_file(
@@ -266,17 +302,6 @@ def __init__(
266302
self.datastats_filename = os.path.join(self.work_dir, "datastats.yaml")
267303
self.datalist_filename = datalist_filename
268304

269-
self.not_use_cache = not_use_cache
270-
self.cache_filename = os.path.join(self.work_dir, "cache.yaml")
271-
self.cache = self.read_cache()
272-
self.export_cache()
273-
274-
# determine if we need to analyze, algo_gen or train from cache, unless manually provided
275-
self.analyze = not self.cache["analyze"] if analyze is None else analyze
276-
self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen
277-
self.train = train
278-
self.ensemble = ensemble # last step, no need to check
279-
280305
self.set_training_params()
281306
self.set_device_info()
282307
self.set_prediction_params()
@@ -288,9 +313,9 @@ def __init__(
288313
self.gpu_customization_specs: dict[str, Any] = {}
289314

290315
# hpo
291-
if hpo_backend.lower() != "nni":
316+
if self.hpo_backend.lower() != "nni":
292317
raise NotImplementedError("HPOGen backend only supports NNI")
293-
self.hpo = hpo and has_nni
318+
self.hpo = self.hpo and has_nni
294319
self.set_hpo_params()
295320
self.search_space: dict[str, dict[str, Any]] = {}
296321
self.hpo_tasks = 0

tests/test_vis_gradcam.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from monai.networks.nets import DenseNet, DenseNet121, SEResNet50
2222
from monai.visualize import GradCAM, GradCAMpp
23-
from tests.utils import assert_allclose
23+
from tests.utils import assert_allclose, skip_if_quick
2424

2525

2626
class DenseNetAdjoint(DenseNet121):
@@ -147,6 +147,7 @@ def __call__(self, x, adjoint_info):
147147
TESTS_ILL.append([cam])
148148

149149

150+
@skip_if_quick
150151
class TestGradientClassActivationMap(unittest.TestCase):
151152
@parameterized.expand(TESTS)
152153
def test_shape(self, cam_class, input_data, expected_shape):

0 commit comments

Comments
 (0)