Skip to content

Commit

Permalink
Merge remote-tracking branch 'yliu/bundle' into bundle
Browse files Browse the repository at this point in the history
  • Loading branch information
KumoLiu committed Mar 26, 2024
2 parents 5aa4f38 + cca300e commit e941dde
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 6 deletions.
10 changes: 6 additions & 4 deletions monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from pathlib import Path
from typing import Any, Sequence

from monai.config import PathLike
from monai.apps.utils import get_logger
from monai.bundle.config_parser import ConfigParser
from monai.bundle.properties import InferProperties, MetaProperties, TrainProperties
from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY
from monai.config import PathLike
from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg, deprecated_arg_default, ensure_tuple

__all__ = ["BundleWorkflow", "ConfigWorkflow"]
Expand Down Expand Up @@ -62,7 +62,9 @@ class BundleWorkflow(ABC):
new_name="workflow_type",
msg_suffix="please use `workflow_type` instead.",
)
def __init__(self, workflow_type: str | None = None, workflow: str | None = None, properties_path: PathLike | None = None):
def __init__(
self, workflow_type: str | None = None, workflow: str | None = None, properties_path: PathLike | None = None
):
workflow_type = workflow if workflow is not None else workflow_type
if workflow_type is None and properties_path is None:
self.properties = copy(MetaProperties)
Expand All @@ -76,10 +78,10 @@ def __init__(self, workflow_type: str | None = None, workflow: str | None = None
self.properties = json.load(json_file)
self.workflow_type = None
return
if workflow_type.lower() in self.supported_train_type:
if workflow_type.lower() in self.supported_train_type: # type: ignore[union-attr]
self.properties = {**TrainProperties, **MetaProperties}
self.workflow_type = "train"
elif workflow_type.lower() in self.supported_infer_type:
elif workflow_type.lower() in self.supported_infer_type: # type: ignore[union-attr]
self.properties = {**InferProperties, **MetaProperties}
self.workflow_type = "infer"
else:
Expand Down
10 changes: 8 additions & 2 deletions tests/test_bundle_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@

TEST_CASE_2 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.yaml")]

TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")]
TEST_CASE_3 = [
os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json"),
os.path.join(os.path.dirname(__file__), "testing_data", "fl_train_properties.json"),
]


class TestBundleWorkflow(unittest.TestCase):
Expand Down Expand Up @@ -101,10 +104,11 @@ def test_inference_config(self, config_file):
logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"),
**override,
)
inferer.add_property(name="inferer", required=True, config_id="inferer")
self._test_inferer(inferer)

@parameterized.expand([TEST_CASE_3])
def test_train_config(self, config_file):
def test_train_config(self, config_file, properties_path):
# test standard MONAI model-zoo config workflow
trainer = ConfigWorkflow(
workflow_type="train",
Expand All @@ -113,6 +117,7 @@ def test_train_config(self, config_file):
init_id="initialize",
run_id="run",
final_id="finalize",
properties_path=properties_path,
)
# should initialize before parsing any bundle content
trainer.initialize()
Expand Down Expand Up @@ -144,6 +149,7 @@ def test_train_config(self, config_file):
def test_non_config(self):
# test user defined python style workflow
inferer = NonConfigWorkflow(self.filename, self.data_dir)
inferer.add_property(name="inferer", required=True)
self._test_inferer(inferer)


Expand Down
126 changes: 126 additions & 0 deletions tests/testing_data/fl_train_properties.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
{
"bundle_root": {
"description": "root path of the bundle.",
"required": true,
"id": "bundle_root"
},
"device": {
"description": "target device to execute the bundle workflow.",
"required": true,
"id": "device"
},
"dataset_dir": {
"description": "directory path of the dataset.",
"required": true,
"id": "dataset_dir"
},
"trainer": {
"description": "training workflow engine.",
"required": true,
"id": "train::trainer"
},
"network_def": {
"description": "network module for the training.",
"required": false,
"id": "network_def"
},
"max_epochs": {
"description": "max number of epochs to execute the training.",
"required": true,
"id": "train::trainer::max_epochs"
},
"train_dataset": {
"description": "PyTorch dataset object for the training logic.",
"required": true,
"id": "train::dataset"
},
"train_inferer": {
"description": "MONAI Inferer object to execute the model computation in training.",
"required": true,
"id": "train::inferer"
},
"train_dataset_data": {
"description": "data source for the training dataset.",
"required": false,
"id": "train::dataset::data",
"refer_id": null
},
"train_handlers": {
"description": "event-handlers for the training logic.",
"required": false,
"id": "train::handlers",
"refer_id": "train::trainer::train_handlers"
},
"train_preprocessing": {
"description": "preprocessing for the training input data.",
"required": false,
"id": "train::preprocessing",
"refer_id": "train::dataset::transform"
},
"train_postprocessing": {
"description": "postprocessing for the training model output data.",
"required": false,
"id": "train::postprocessing",
"refer_id": "train::trainer::postprocessing"
},
"train_key_metric": {
"description": "key metric to compute on the training data.",
"required": false,
"id": "train::key_metric",
"refer_id": "train::trainer::key_train_metric"
},
"evaluator": {
"description": "validation workflow engine.",
"required": false,
"id": "validate::evaluator",
"refer_id": "validator"
},
"val_interval": {
"description": "validation interval during the training.",
"required": false,
"id": "val_interval",
"refer_id": "interval"
},
"val_handlers": {
"description": "event-handlers for the validation logic.",
"required": false,
"id": "validate::handlers",
"refer_id": "validate::evaluator::val_handlers"
},
"val_dataset": {
"description": "PyTorch dataset object for the validation logic.",
"required": false,
"id": "validate::dataset",
"refer_id": "validate::dataloader::dataset"
},
"val_dataset_data": {
"description": "data source for the validation dataset.",
"required": false,
"id": "validate::dataset::data",
"refer_id": null
},
"val_inferer": {
"description": "MONAI Inferer object to execute the model computation in validation.",
"required": false,
"id": "validate::inferer",
"refer_id": "validate::evaluator::inferer"
},
"val_preprocessing": {
"description": "preprocessing for the validation input data.",
"required": false,
"id": "validate::preprocessing",
"refer_id": "validate::dataset::transform"
},
"val_postprocessing": {
"description": "postprocessing for the validation model output data.",
"required": false,
"id": "validate::postprocessing",
"refer_id": "validate::evaluator::postprocessing"
},
"val_key_metric": {
"description": "key metric to compute on the validation data.",
"required": false,
"id": "validate::key_metric",
"refer_id": "validate::evaluator::key_val_metric"
}
}

0 comments on commit e941dde

Please sign in to comment.