Skip to content

Commit

Permalink
Add "properties_path" in BundleWorkflow (#7542)
Browse files Browse the repository at this point in the history
Fixes #7541


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
KumoLiu and pre-commit-ci[bot] authored Apr 1, 2024
1 parent c885100 commit 264b9e4
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 4 deletions.
23 changes: 19 additions & 4 deletions monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

import json
import os
import sys
import time
Expand All @@ -24,6 +25,7 @@
from monai.bundle.config_parser import ConfigParser
from monai.bundle.properties import InferProperties, MetaProperties, TrainProperties
from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY
from monai.config import PathLike
from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg, deprecated_arg_default, ensure_tuple

__all__ = ["BundleWorkflow", "ConfigWorkflow"]
Expand All @@ -46,6 +48,7 @@ class BundleWorkflow(ABC):
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
properties_path: the path to the JSON file of properties.
meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.
logging_file: config file for `logging` module in the program. for more details:
https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
Expand All @@ -66,6 +69,7 @@ def __init__(
self,
workflow_type: str | None = None,
workflow: str | None = None,
properties_path: PathLike | None = None,
meta_file: str | Sequence[str] | None = None,
logging_file: str | None = None,
):
Expand All @@ -92,15 +96,24 @@ def __init__(
meta_file = None

workflow_type = workflow if workflow is not None else workflow_type
if workflow_type is None:
if workflow_type is None and properties_path is None:
self.properties = copy(MetaProperties)
self.workflow_type = None
self.meta_file = meta_file
return
if workflow_type.lower() in self.supported_train_type:
if properties_path is not None:
properties_path = Path(properties_path)
if not properties_path.is_file():
raise ValueError(f"Property file {properties_path} does not exist.")
with open(properties_path) as json_file:
self.properties = json.load(json_file)
self.workflow_type = None
self.meta_file = meta_file
return
if workflow_type.lower() in self.supported_train_type: # type: ignore[union-attr]
self.properties = {**TrainProperties, **MetaProperties}
self.workflow_type = "train"
elif workflow_type.lower() in self.supported_infer_type:
elif workflow_type.lower() in self.supported_infer_type: # type: ignore[union-attr]
self.properties = {**InferProperties, **MetaProperties}
self.workflow_type = "infer"
else:
Expand Down Expand Up @@ -247,6 +260,7 @@ class ConfigWorkflow(BundleWorkflow):
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
properties_path: the path to the JSON file of properties.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``
Expand All @@ -271,6 +285,7 @@ def __init__(
tracking: str | dict | None = None,
workflow_type: str | None = None,
workflow: str | None = None,
properties_path: PathLike | None = None,
**override: Any,
) -> None:
workflow_type = workflow if workflow is not None else workflow_type
Expand All @@ -289,7 +304,7 @@ def __init__(
else:
config_root_path = Path("configs")
meta_file = str(config_root_path / "metadata.json") if meta_file is None else meta_file
super().__init__(workflow_type=workflow_type, meta_file=meta_file)
super().__init__(workflow_type=workflow_type, meta_file=meta_file, properties_path=properties_path)
self.config_root_path = config_root_path
logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file
if logging_file is not None:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_bundle_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ def test_inference_config(self, config_file):
)
self._test_inferer(inferer)

# test property path
inferer = ConfigWorkflow(
config_file=config_file,
properties_path=os.path.join(os.path.dirname(__file__), "testing_data", "fl_infer_properties.json"),
logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"),
**override,
)
self._test_inferer(inferer)
self.assertEqual(inferer.workflow_type, None)

@parameterized.expand([TEST_CASE_3])
def test_train_config(self, config_file):
# test standard MONAI model-zoo config workflow
Expand Down
16 changes: 16 additions & 0 deletions tests/test_regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@
import torch

from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd
from monai.utils import set_determinism


class TestMixup(unittest.TestCase):
def setUp(self) -> None:
set_determinism(seed=0)

def tearDown(self) -> None:
set_determinism(None)

def test_mixup(self):
for dims in [2, 3]:
Expand Down Expand Up @@ -53,6 +59,11 @@ def test_mixupd(self):


class TestCutMix(unittest.TestCase):
def setUp(self) -> None:
set_determinism(seed=0)

def tearDown(self) -> None:
set_determinism(None)

def test_cutmix(self):
for dims in [2, 3]:
Expand All @@ -78,6 +89,11 @@ def test_cutmixd(self):


class TestCutOut(unittest.TestCase):
def setUp(self) -> None:
set_determinism(seed=0)

def tearDown(self) -> None:
set_determinism(None)

def test_cutout(self):
for dims in [2, 3]:
Expand Down
67 changes: 67 additions & 0 deletions tests/testing_data/fl_infer_properties.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
{
"bundle_root": {
"description": "root path of the bundle.",
"required": true,
"id": "bundle_root"
},
"device": {
"description": "target device to execute the bundle workflow.",
"required": true,
"id": "device"
},
"dataset_dir": {
"description": "directory path of the dataset.",
"required": true,
"id": "dataset_dir"
},
"dataset": {
"description": "PyTorch dataset object for the inference / evaluation logic.",
"required": true,
"id": "dataset"
},
"evaluator": {
"description": "inference / evaluation workflow engine.",
"required": true,
"id": "evaluator"
},
"network_def": {
"description": "network module for the inference.",
"required": true,
"id": "network_def"
},
"inferer": {
"description": "MONAI Inferer object to execute the model computation in inference.",
"required": true,
"id": "inferer"
},
"dataset_data": {
"description": "data source for the inference / evaluation dataset.",
"required": false,
"id": "dataset::data",
"refer_id": null
},
"handlers": {
"description": "event-handlers for the inference / evaluation logic.",
"required": false,
"id": "handlers",
"refer_id": "evaluator::val_handlers"
},
"preprocessing": {
"description": "preprocessing for the input data.",
"required": false,
"id": "preprocessing",
"refer_id": "dataset::transform"
},
"postprocessing": {
"description": "postprocessing for the model output data.",
"required": false,
"id": "postprocessing",
"refer_id": "evaluator::postprocessing"
},
"key_metric": {
"description": "the key metric during evaluation.",
"required": false,
"id": "key_metric",
"refer_id": "evaluator::key_val_metric"
}
}

0 comments on commit 264b9e4

Please sign in to comment.