diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index fa139bb601..b550dc8c93 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -24,6 +24,7 @@ init_bundle, load, run, + run_workflow, trt_export, verify_metadata, verify_net_in_out, diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py index aa0ed20ef5..e143a5c7ed 100644 --- a/monai/bundle/__main__.py +++ b/monai/bundle/__main__.py @@ -11,7 +11,16 @@ from __future__ import annotations -from monai.bundle.scripts import ckpt_export, download, init_bundle, run, trt_export, verify_metadata, verify_net_in_out +from monai.bundle.scripts import ( + ckpt_export, + download, + init_bundle, + run, + run_workflow, + trt_export, + verify_metadata, + verify_net_in_out, +) if __name__ == "__main__": from monai.utils import optional_import diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 63898f86e5..29adad970c 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -18,6 +18,7 @@ import warnings from collections.abc import Mapping, Sequence from pathlib import Path +from pydoc import locate from shutil import copyfile from textwrap import dedent from typing import Any, Callable @@ -30,7 +31,7 @@ from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA -from monai.bundle.workflows import ConfigWorkflow +from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow from monai.config import IgniteInfo, PathLike from monai.data import load_net_with_metadata, save_net_with_metadata from monai.networks import convert_to_torchscript, convert_to_trt, copy_model_state, get_state_dict, save_state @@ -635,7 +636,7 @@ def run( args_file: a JSON or YAML file to provide default values for `runner_id`, `meta_file`, `config_file`, `logging`, and override pairs. so that the command line inputs can be simplified. override: id-value pairs to override or add the corresponding config content. - e.g. ``--net#input_chns 42``. + e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``. """ @@ -678,6 +679,53 @@ def run( workflow.finalize() +def run_workflow(workflow: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any) -> None: + """ + Specify `bundle workflow` to run monai bundle components and workflows. + The workflow should be suclass of `BundleWorkflow` and be available to import. + It can be MONAI existing bundle workflows or user customized workflows. + + Typical usage examples: + + .. code-block:: bash + + # Execute this module as a CLI entry with default ConfigWorkflow: + python -m monai.bundle run_workflow --meta_file --config_file + + # Set the workflow to other customized BundleWorkflow subclass: + python -m monai.bundle run_workflow --workflow CustomizedWorkflow ... + + Args: + workflow: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow". + args_file: a JSON or YAML file to provide default values for this API. + so that the command line inputs can be simplified. + kwargs: arguments to instantiate the workflow class. + + """ + + _args = _update_args(args=args_file, workflow=workflow, **kwargs) + _log_input_summary(tag="run", args=_args) + (workflow_name,) = _pop_args(_args, workflow=ConfigWorkflow) # the default workflow name is "ConfigWorkflow" + if isinstance(workflow_name, str): + workflow_class, has_built_in = optional_import("monai.bundle", name=str(workflow_name)) # search built-in + if not has_built_in: + workflow_class = locate(str(workflow_name)) # search dotted path + if workflow_class is None: + raise ValueError(f"cannot locate specified workflow class: {workflow_name}.") + elif issubclass(workflow_name, BundleWorkflow): + workflow_class = workflow_name + else: + raise ValueError( + "Argument `workflow` must be a bundle workflow class name" + f"or subclass of BundleWorkflow, got: {workflow_name}." + ) + + workflow_ = workflow_class(**_args) + workflow_.initialize() + workflow_.run() + workflow_.finalize() + + def verify_metadata( meta_file: str | Sequence[str] | None = None, filepath: PathLike | None = None, diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index ace08b3ec8..82bab73fe2 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -165,7 +165,7 @@ class ConfigWorkflow(BundleWorkflow): other unsupported string will raise a ValueError. default to `None` for common workflow. override: id-value pairs to override or add the corresponding config content. - e.g. ``--net#input_chns 42``. + e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg`` """ diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py new file mode 100644 index 0000000000..63562af868 --- /dev/null +++ b/tests/nonconfig_workflow.py @@ -0,0 +1,127 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch + +from monai.bundle import BundleWorkflow +from monai.data import DataLoader, Dataset +from monai.engines import SupervisedEvaluator +from monai.inferers import SlidingWindowInferer +from monai.networks.nets import UNet +from monai.transforms import ( + Activationsd, + AsDiscreted, + Compose, + EnsureChannelFirstd, + LoadImaged, + SaveImaged, + ScaleIntensityd, +) +from monai.utils import BundleProperty, set_determinism + + +class NonConfigWorkflow(BundleWorkflow): + """ + Test class simulates the bundle workflow defined by Python script directly. + + """ + + def __init__(self, filename, output_dir): + super().__init__(workflow="inference") + self.filename = filename + self.output_dir = output_dir + self._bundle_root = "will override" + self._device = torch.device("cpu") + self._network_def = None + self._inferer = None + self._preprocessing = None + self._postprocessing = None + self._evaluator = None + + def initialize(self): + set_determinism(0) + if self._preprocessing is None: + self._preprocessing = Compose( + [LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), ScaleIntensityd(keys="image")] + ) + dataset = Dataset(data=[{"image": self.filename}], transform=self._preprocessing) + dataloader = DataLoader(dataset, batch_size=1, num_workers=4) + + if self._network_def is None: + self._network_def = UNet( + spatial_dims=3, + in_channels=1, + out_channels=2, + channels=[2, 2, 4, 8, 4], + strides=[2, 2, 2, 2], + num_res_units=2, + norm="batch", + ) + if self._inferer is None: + self._inferer = SlidingWindowInferer(roi_size=(64, 64, 32), sw_batch_size=4, overlap=0.25) + + if self._postprocessing is None: + self._postprocessing = Compose( + [ + Activationsd(keys="pred", softmax=True), + AsDiscreted(keys="pred", argmax=True), + SaveImaged(keys="pred", output_dir=self.output_dir, output_postfix="seg"), + ] + ) + + self._evaluator = SupervisedEvaluator( + device=self._device, + val_data_loader=dataloader, + network=self._network_def.to(self._device), + inferer=self._inferer, + postprocessing=self._postprocessing, + amp=False, + ) + + def run(self): + self._evaluator.run() + + def finalize(self): + return True + + def _get_property(self, name, property): + if name == "bundle_root": + return self._bundle_root + if name == "device": + return self._device + if name == "network_def": + return self._network_def + if name == "inferer": + return self._inferer + if name == "preprocessing": + return self._preprocessing + if name == "postprocessing": + return self._postprocessing + if property[BundleProperty.REQUIRED]: + raise ValueError(f"unsupported property '{name}' is required in the bundle properties.") + + def _set_property(self, name, property, value): + if name == "bundle_root": + self._bundle_root = value + elif name == "device": + self._device = value + elif name == "network_def": + self._network_def = value + elif name == "inferer": + self._inferer = value + elif name == "preprocessing": + self._preprocessing = value + elif name == "postprocessing": + self._postprocessing = value + elif property[BundleProperty.REQUIRED]: + raise ValueError(f"unsupported property '{name}' is required in the bundle properties.") diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py index 948c351a1c..33a3f1d959 100644 --- a/tests/test_bundle_workflow.py +++ b/tests/test_bundle_workflow.py @@ -22,22 +22,12 @@ import torch from parameterized import parameterized -from monai.bundle import BundleWorkflow, ConfigWorkflow -from monai.data import DataLoader, Dataset -from monai.engines import SupervisedEvaluator +from monai.bundle import ConfigWorkflow +from monai.data import Dataset from monai.inferers import SimpleInferer, SlidingWindowInferer from monai.networks.nets import UNet -from monai.transforms import ( - Activationsd, - AsDiscreted, - Compose, - EnsureChannelFirstd, - LoadImage, - LoadImaged, - SaveImaged, - ScaleIntensityd, -) -from monai.utils import BundleProperty, set_determinism +from monai.transforms import Compose, LoadImage +from tests.nonconfig_workflow import NonConfigWorkflow TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")] @@ -46,103 +36,6 @@ TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")] -class NonConfigWorkflow(BundleWorkflow): - """ - Test class simulates the bundle workflow defined by Python script directly. - - """ - - def __init__(self, filename, output_dir): - super().__init__(workflow="inference") - self.filename = filename - self.output_dir = output_dir - self._bundle_root = "will override" - self._device = torch.device("cpu") - self._network_def = None - self._inferer = None - self._preprocessing = None - self._postprocessing = None - self._evaluator = None - - def initialize(self): - set_determinism(0) - if self._preprocessing is None: - self._preprocessing = Compose( - [LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), ScaleIntensityd(keys="image")] - ) - dataset = Dataset(data=[{"image": self.filename}], transform=self._preprocessing) - dataloader = DataLoader(dataset, batch_size=1, num_workers=4) - - if self._network_def is None: - self._network_def = UNet( - spatial_dims=3, - in_channels=1, - out_channels=2, - channels=[2, 2, 4, 8, 4], - strides=[2, 2, 2, 2], - num_res_units=2, - norm="batch", - ) - if self._inferer is None: - self._inferer = SlidingWindowInferer(roi_size=(64, 64, 32), sw_batch_size=4, overlap=0.25) - - if self._postprocessing is None: - self._postprocessing = Compose( - [ - Activationsd(keys="pred", softmax=True), - AsDiscreted(keys="pred", argmax=True), - SaveImaged(keys="pred", output_dir=self.output_dir, output_postfix="seg"), - ] - ) - - self._evaluator = SupervisedEvaluator( - device=self._device, - val_data_loader=dataloader, - network=self._network_def.to(self._device), - inferer=self._inferer, - postprocessing=self._postprocessing, - amp=False, - ) - - def run(self): - self._evaluator.run() - - def finalize(self): - return True - - def _get_property(self, name, property): - if name == "bundle_root": - return self._bundle_root - if name == "device": - return self._device - if name == "network_def": - return self._network_def - if name == "inferer": - return self._inferer - if name == "preprocessing": - return self._preprocessing - if name == "postprocessing": - return self._postprocessing - if property[BundleProperty.REQUIRED]: - raise ValueError(f"unsupported property '{name}' is required in the bundle properties.") - - def _set_property(self, name, property, value): - if name == "bundle_root": - self._bundle_root = value - elif name == "device": - self._device = value - elif name == "network_def": - self._network_def = value - elif name == "inferer": - self._inferer = value - elif name == "preprocessing": - self._preprocessing = value - elif name == "postprocessing": - self._postprocessing = value - elif property[BundleProperty.REQUIRED]: - raise ValueError(f"unsupported property '{name}' is required in the bundle properties.") - - class TestBundleWorkflow(unittest.TestCase): def setUp(self): self.data_dir = tempfile.mkdtemp() diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index 24c0286133..2018ca801e 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -62,8 +62,10 @@ def test_tiny(self): }, f, ) - cmd = ["coverage", "run", "-m", "monai.bundle", "run", "training", "--config_file", config_file] - command_line_tests(cmd) + cmd = ["coverage", "run", "-m", "monai.bundle"] + # test both CLI entry "run" and "run_workflow" + command_line_tests(cmd + ["run", "training", "--config_file", config_file]) + command_line_tests(cmd + ["run_workflow", "--run_id", "training", "--config_file", config_file]) @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_shape(self, config_file, expected_shape): @@ -136,6 +138,18 @@ def test_shape(self, config_file, expected_shape): # test the saved execution configs self.assertTrue(len(glob(f"{tempdir}/config_*.json")), 2) + def test_customized_workflow(self): + expected_shape = (64, 64, 64) + test_image = np.random.rand(*expected_shape) + filename = os.path.join(self.data_dir, "image.nii") + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename) + + cmd = "-m fire monai.bundle.scripts run_workflow --workflow tests.nonconfig_workflow.NonConfigWorkflow" + cmd += f" --filename {filename} --output_dir {self.data_dir}" + command_line_tests(["coverage", "run"] + cmd.split(" ")) + loader = LoadImage(image_only=True) + self.assertTupleEqual(loader(os.path.join(self.data_dir, "image", "image_seg.nii.gz")).shape, expected_shape) + if __name__ == "__main__": unittest.main()