Skip to content

Commit adc6f1e

Browse files
mingxin-zhengKumoLiu
authored andcommitted
Add checks for num_fold and fail early if wrong (Project-MONAI#7634)
Fixes Project-MONAI#7628 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mingxin <mingxinz@nvidia.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent f225764 commit adc6f1e

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

monai/apps/auto3dseg/auto_runner.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,13 @@ def __init__(
298298
pass
299299

300300
# inspect and update folds
301-
num_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename)
301+
self.max_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename)
302302
if "num_fold" in self.data_src_cfg:
303303
num_fold = int(self.data_src_cfg["num_fold"]) # override from config
304+
logger.info(f"Setting num_fold {num_fold} based on the input config.")
305+
else:
306+
num_fold = self.max_fold
307+
logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.")
304308

305309
self.data_src_cfg["datalist"] = datalist_filename # update path to a version in work_dir and save user input
306310
ConfigParser.export_config_file(
@@ -398,7 +402,10 @@ def inspect_datalist_folds(self, datalist_filename: str) -> int:
398402

399403
if len(fold_list) > 0:
400404
num_fold = max(fold_list) + 1
401-
logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.")
405+
logger.info(f"Found num_fold {num_fold} based on the input datalist {datalist_filename}.")
406+
# check if every fold is present
407+
if len(set(fold_list)) != num_fold:
408+
raise ValueError(f"Fold numbers are not continuous from 0 to {num_fold - 1}")
402409
elif "validation" in datalist and len(datalist["validation"]) > 0:
403410
logger.info("No fold numbers provided, attempting to use a single fold based on the validation key")
404411
# update the datalist file
@@ -492,6 +499,11 @@ def set_num_fold(self, num_fold: int = 5) -> AutoRunner:
492499

493500
if num_fold <= 0:
494501
raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}")
502+
if num_fold > self.max_fold + 1:
503+
# Auto3DSeg allows no validation set, so the maximum fold number is max_fold + 1
504+
raise ValueError(
505+
f"num_fold is greater than the maximum fold number {self.max_fold} in {self.datalist_filename}."
506+
)
495507
self.num_fold = num_fold
496508

497509
return self

0 commit comments

Comments
 (0)