diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py
index fa139bb601..b550dc8c93 100644
--- a/monai/bundle/__init__.py
+++ b/monai/bundle/__init__.py
@@ -24,6 +24,7 @@
init_bundle,
load,
run,
+ run_workflow,
trt_export,
verify_metadata,
verify_net_in_out,
diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py
index aa0ed20ef5..e143a5c7ed 100644
--- a/monai/bundle/__main__.py
+++ b/monai/bundle/__main__.py
@@ -11,7 +11,16 @@
from __future__ import annotations
-from monai.bundle.scripts import ckpt_export, download, init_bundle, run, trt_export, verify_metadata, verify_net_in_out
+from monai.bundle.scripts import (
+ ckpt_export,
+ download,
+ init_bundle,
+ run,
+ run_workflow,
+ trt_export,
+ verify_metadata,
+ verify_net_in_out,
+)
if __name__ == "__main__":
from monai.utils import optional_import
diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py
index 63898f86e5..29adad970c 100644
--- a/monai/bundle/scripts.py
+++ b/monai/bundle/scripts.py
@@ -18,6 +18,7 @@
import warnings
from collections.abc import Mapping, Sequence
from pathlib import Path
+from pydoc import locate
from shutil import copyfile
from textwrap import dedent
from typing import Any, Callable
@@ -30,7 +31,7 @@
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
-from monai.bundle.workflows import ConfigWorkflow
+from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
from monai.config import IgniteInfo, PathLike
from monai.data import load_net_with_metadata, save_net_with_metadata
from monai.networks import convert_to_torchscript, convert_to_trt, copy_model_state, get_state_dict, save_state
@@ -635,7 +636,7 @@ def run(
args_file: a JSON or YAML file to provide default values for `runner_id`, `meta_file`,
`config_file`, `logging`, and override pairs. so that the command line inputs can be simplified.
override: id-value pairs to override or add the corresponding config content.
- e.g. ``--net#input_chns 42``.
+ e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``.
"""
@@ -678,6 +679,53 @@ def run(
workflow.finalize()
+def run_workflow(workflow: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any) -> None:
+ """
+    Specify the bundle workflow to run monai bundle components and workflows.
+    The workflow should be a subclass of `BundleWorkflow` and be available to import.
+    It can be an existing MONAI bundle workflow or a user-defined custom workflow.
+
+ Typical usage examples:
+
+ .. code-block:: bash
+
+ # Execute this module as a CLI entry with default ConfigWorkflow:
+        python -m monai.bundle run_workflow --meta_file <meta path> --config_file <config path>
+
+ # Set the workflow to other customized BundleWorkflow subclass:
+ python -m monai.bundle run_workflow --workflow CustomizedWorkflow ...
+
+ Args:
+        workflow: specified bundle workflow name or class; a built-in class name, a dotted import
+            path, or a subclass of `BundleWorkflow`. Defaults to `ConfigWorkflow`.
+        args_file: a JSON or YAML file to provide default values for this API,
+            so that the command line inputs can be simplified.
+ kwargs: arguments to instantiate the workflow class.
+
+ """
+
+ _args = _update_args(args=args_file, workflow=workflow, **kwargs)
+ _log_input_summary(tag="run", args=_args)
+    (workflow_name,) = _pop_args(_args, workflow=ConfigWorkflow)  # the default workflow is the ConfigWorkflow class
+ if isinstance(workflow_name, str):
+ workflow_class, has_built_in = optional_import("monai.bundle", name=str(workflow_name)) # search built-in
+ if not has_built_in:
+ workflow_class = locate(str(workflow_name)) # search dotted path
+ if workflow_class is None:
+ raise ValueError(f"cannot locate specified workflow class: {workflow_name}.")
+ elif issubclass(workflow_name, BundleWorkflow):
+ workflow_class = workflow_name
+ else:
+        raise ValueError(
+            "Argument `workflow` must be a bundle workflow class name or dotted path, "
+            f"or a subclass of BundleWorkflow, got: {workflow_name}."
+        )
+
+ workflow_ = workflow_class(**_args)
+ workflow_.initialize()
+ workflow_.run()
+ workflow_.finalize()
+
+
def verify_metadata(
meta_file: str | Sequence[str] | None = None,
filepath: PathLike | None = None,
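For reference, a minimal sketch of driving the new `run_workflow` entry from Python rather than the CLI; the config/metadata paths, the `run_id` value and the output directory below are illustrative placeholders (not files added by this change), while the dotted path points at the test workflow introduced later in this PR:

    from monai.bundle import run_workflow

    # default: ConfigWorkflow is instantiated with the given kwargs, then
    # initialize() -> run() -> finalize() is executed on it
    run_workflow(
        config_file="configs/inference.json",  # placeholder bundle config
        meta_file="configs/metadata.json",
        run_id="evaluating",  # id of the config entry to execute
    )

    # custom: pass a built-in class name, a dotted import path or a BundleWorkflow
    # subclass, plus the constructor arguments of that workflow class
    run_workflow(
        workflow="tests.nonconfig_workflow.NonConfigWorkflow",
        filename="image.nii",
        output_dir="./out",
    )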
diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py
index ace08b3ec8..82bab73fe2 100644
--- a/monai/bundle/workflows.py
+++ b/monai/bundle/workflows.py
@@ -165,7 +165,7 @@ class ConfigWorkflow(BundleWorkflow):
other unsupported string will raise a ValueError.
default to `None` for common workflow.
override: id-value pairs to override or add the corresponding config content.
- e.g. ``--net#input_chns 42``.
+            e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``.
"""
diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py
new file mode 100644
index 0000000000..63562af868
--- /dev/null
+++ b/tests/nonconfig_workflow.py
@@ -0,0 +1,127 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+
+from monai.bundle import BundleWorkflow
+from monai.data import DataLoader, Dataset
+from monai.engines import SupervisedEvaluator
+from monai.inferers import SlidingWindowInferer
+from monai.networks.nets import UNet
+from monai.transforms import (
+ Activationsd,
+ AsDiscreted,
+ Compose,
+ EnsureChannelFirstd,
+ LoadImaged,
+ SaveImaged,
+ ScaleIntensityd,
+)
+from monai.utils import BundleProperty, set_determinism
+
+
+class NonConfigWorkflow(BundleWorkflow):
+ """
+    Test class that simulates a bundle workflow defined directly by Python code rather than by config files.
+
+ """
+
+ def __init__(self, filename, output_dir):
+ super().__init__(workflow="inference")
+ self.filename = filename
+ self.output_dir = output_dir
+ self._bundle_root = "will override"
+ self._device = torch.device("cpu")
+ self._network_def = None
+ self._inferer = None
+ self._preprocessing = None
+ self._postprocessing = None
+ self._evaluator = None
+
+ def initialize(self):
+ set_determinism(0)
+ if self._preprocessing is None:
+ self._preprocessing = Compose(
+ [LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), ScaleIntensityd(keys="image")]
+ )
+ dataset = Dataset(data=[{"image": self.filename}], transform=self._preprocessing)
+ dataloader = DataLoader(dataset, batch_size=1, num_workers=4)
+
+ if self._network_def is None:
+ self._network_def = UNet(
+ spatial_dims=3,
+ in_channels=1,
+ out_channels=2,
+ channels=[2, 2, 4, 8, 4],
+ strides=[2, 2, 2, 2],
+ num_res_units=2,
+ norm="batch",
+ )
+ if self._inferer is None:
+ self._inferer = SlidingWindowInferer(roi_size=(64, 64, 32), sw_batch_size=4, overlap=0.25)
+
+ if self._postprocessing is None:
+ self._postprocessing = Compose(
+ [
+ Activationsd(keys="pred", softmax=True),
+ AsDiscreted(keys="pred", argmax=True),
+ SaveImaged(keys="pred", output_dir=self.output_dir, output_postfix="seg"),
+ ]
+ )
+
+ self._evaluator = SupervisedEvaluator(
+ device=self._device,
+ val_data_loader=dataloader,
+ network=self._network_def.to(self._device),
+ inferer=self._inferer,
+ postprocessing=self._postprocessing,
+ amp=False,
+ )
+
+ def run(self):
+ self._evaluator.run()
+
+ def finalize(self):
+ return True
+
+ def _get_property(self, name, property):
+ if name == "bundle_root":
+ return self._bundle_root
+ if name == "device":
+ return self._device
+ if name == "network_def":
+ return self._network_def
+ if name == "inferer":
+ return self._inferer
+ if name == "preprocessing":
+ return self._preprocessing
+ if name == "postprocessing":
+ return self._postprocessing
+ if property[BundleProperty.REQUIRED]:
+ raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")
+
+ def _set_property(self, name, property, value):
+ if name == "bundle_root":
+ self._bundle_root = value
+ elif name == "device":
+ self._device = value
+ elif name == "network_def":
+ self._network_def = value
+ elif name == "inferer":
+ self._inferer = value
+ elif name == "preprocessing":
+ self._preprocessing = value
+ elif name == "postprocessing":
+ self._postprocessing = value
+ elif property[BundleProperty.REQUIRED]:
+ raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")
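For context, a minimal sketch of how this test workflow is driven directly from Python; the file names are placeholders, and attribute access such as `workflow.device` is routed to `_get_property`/`_set_property` by the `BundleWorkflow` base class:

    import torch

    from tests.nonconfig_workflow import NonConfigWorkflow

    workflow = NonConfigWorkflow(filename="image.nii", output_dir="./out")
    workflow.device = torch.device("cpu")  # overrides the property before the evaluator is built
    workflow.initialize()
    workflow.run()  # executes the SupervisedEvaluator assembled in initialize()
    workflow.finalize()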
diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py
index 948c351a1c..33a3f1d959 100644
--- a/tests/test_bundle_workflow.py
+++ b/tests/test_bundle_workflow.py
@@ -22,22 +22,12 @@
import torch
from parameterized import parameterized
-from monai.bundle import BundleWorkflow, ConfigWorkflow
-from monai.data import DataLoader, Dataset
-from monai.engines import SupervisedEvaluator
+from monai.bundle import ConfigWorkflow
+from monai.data import Dataset
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.networks.nets import UNet
-from monai.transforms import (
- Activationsd,
- AsDiscreted,
- Compose,
- EnsureChannelFirstd,
- LoadImage,
- LoadImaged,
- SaveImaged,
- ScaleIntensityd,
-)
-from monai.utils import BundleProperty, set_determinism
+from monai.transforms import Compose, LoadImage
+from tests.nonconfig_workflow import NonConfigWorkflow
TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")]
@@ -46,103 +36,6 @@
TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")]
-class NonConfigWorkflow(BundleWorkflow):
- """
- Test class simulates the bundle workflow defined by Python script directly.
-
- """
-
- def __init__(self, filename, output_dir):
- super().__init__(workflow="inference")
- self.filename = filename
- self.output_dir = output_dir
- self._bundle_root = "will override"
- self._device = torch.device("cpu")
- self._network_def = None
- self._inferer = None
- self._preprocessing = None
- self._postprocessing = None
- self._evaluator = None
-
- def initialize(self):
- set_determinism(0)
- if self._preprocessing is None:
- self._preprocessing = Compose(
- [LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), ScaleIntensityd(keys="image")]
- )
- dataset = Dataset(data=[{"image": self.filename}], transform=self._preprocessing)
- dataloader = DataLoader(dataset, batch_size=1, num_workers=4)
-
- if self._network_def is None:
- self._network_def = UNet(
- spatial_dims=3,
- in_channels=1,
- out_channels=2,
- channels=[2, 2, 4, 8, 4],
- strides=[2, 2, 2, 2],
- num_res_units=2,
- norm="batch",
- )
- if self._inferer is None:
- self._inferer = SlidingWindowInferer(roi_size=(64, 64, 32), sw_batch_size=4, overlap=0.25)
-
- if self._postprocessing is None:
- self._postprocessing = Compose(
- [
- Activationsd(keys="pred", softmax=True),
- AsDiscreted(keys="pred", argmax=True),
- SaveImaged(keys="pred", output_dir=self.output_dir, output_postfix="seg"),
- ]
- )
-
- self._evaluator = SupervisedEvaluator(
- device=self._device,
- val_data_loader=dataloader,
- network=self._network_def.to(self._device),
- inferer=self._inferer,
- postprocessing=self._postprocessing,
- amp=False,
- )
-
- def run(self):
- self._evaluator.run()
-
- def finalize(self):
- return True
-
- def _get_property(self, name, property):
- if name == "bundle_root":
- return self._bundle_root
- if name == "device":
- return self._device
- if name == "network_def":
- return self._network_def
- if name == "inferer":
- return self._inferer
- if name == "preprocessing":
- return self._preprocessing
- if name == "postprocessing":
- return self._postprocessing
- if property[BundleProperty.REQUIRED]:
- raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")
-
- def _set_property(self, name, property, value):
- if name == "bundle_root":
- self._bundle_root = value
- elif name == "device":
- self._device = value
- elif name == "network_def":
- self._network_def = value
- elif name == "inferer":
- self._inferer = value
- elif name == "preprocessing":
- self._preprocessing = value
- elif name == "postprocessing":
- self._postprocessing = value
- elif property[BundleProperty.REQUIRED]:
- raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")
-
-
class TestBundleWorkflow(unittest.TestCase):
def setUp(self):
self.data_dir = tempfile.mkdtemp()
diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py
index 24c0286133..2018ca801e 100644
--- a/tests/test_integration_bundle_run.py
+++ b/tests/test_integration_bundle_run.py
@@ -62,8 +62,10 @@ def test_tiny(self):
},
f,
)
- cmd = ["coverage", "run", "-m", "monai.bundle", "run", "training", "--config_file", config_file]
- command_line_tests(cmd)
+ cmd = ["coverage", "run", "-m", "monai.bundle"]
+ # test both CLI entry "run" and "run_workflow"
+ command_line_tests(cmd + ["run", "training", "--config_file", config_file])
+ command_line_tests(cmd + ["run_workflow", "--run_id", "training", "--config_file", config_file])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_shape(self, config_file, expected_shape):
@@ -136,6 +138,18 @@ def test_shape(self, config_file, expected_shape):
# test the saved execution configs
self.assertTrue(len(glob(f"{tempdir}/config_*.json")), 2)
+ def test_customized_workflow(self):
+ expected_shape = (64, 64, 64)
+ test_image = np.random.rand(*expected_shape)
+ filename = os.path.join(self.data_dir, "image.nii")
+ nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename)
+
+ cmd = "-m fire monai.bundle.scripts run_workflow --workflow tests.nonconfig_workflow.NonConfigWorkflow"
+ cmd += f" --filename {filename} --output_dir {self.data_dir}"
+ command_line_tests(["coverage", "run"] + cmd.split(" "))
+ loader = LoadImage(image_only=True)
+ self.assertTupleEqual(loader(os.path.join(self.data_dir, "image", "image_seg.nii.gz")).shape, expected_shape)
+
if __name__ == "__main__":
unittest.main()