Skip to content

Commit

Permalink
5821 Enhance bundle CLI entry for different bundle workflows (#6181)
Browse files Browse the repository at this point in the history
part of #5821 .

### Description

This PR enhanced the bundle CLI entry to support different customized
bundle workflows.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Nic Ma <nma@nvidia.com>
  • Loading branch information
Nic-Ma authored Apr 3, 2023
1 parent 9b4c235 commit 6aa4f90
Show file tree
Hide file tree
Showing 7 changed files with 209 additions and 117 deletions.
1 change: 1 addition & 0 deletions monai/bundle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
init_bundle,
load,
run,
run_workflow,
trt_export,
verify_metadata,
verify_net_in_out,
Expand Down
11 changes: 10 additions & 1 deletion monai/bundle/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@

from __future__ import annotations

from monai.bundle.scripts import ckpt_export, download, init_bundle, run, trt_export, verify_metadata, verify_net_in_out
from monai.bundle.scripts import (
ckpt_export,
download,
init_bundle,
run,
run_workflow,
trt_export,
verify_metadata,
verify_net_in_out,
)

if __name__ == "__main__":
from monai.utils import optional_import
Expand Down
52 changes: 50 additions & 2 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import warnings
from collections.abc import Mapping, Sequence
from pathlib import Path
from pydoc import locate
from shutil import copyfile
from textwrap import dedent
from typing import Any, Callable
Expand All @@ -30,7 +31,7 @@
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
from monai.bundle.workflows import ConfigWorkflow
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
from monai.config import IgniteInfo, PathLike
from monai.data import load_net_with_metadata, save_net_with_metadata
from monai.networks import convert_to_torchscript, convert_to_trt, copy_model_state, get_state_dict, save_state
Expand Down Expand Up @@ -635,7 +636,7 @@ def run(
args_file: a JSON or YAML file to provide default values for `runner_id`, `meta_file`,
`config_file`, `logging`, and override pairs. so that the command line inputs can be simplified.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--net#input_chns 42``.
e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``.
"""

Expand Down Expand Up @@ -678,6 +679,53 @@ def run(
workflow.finalize()


def run_workflow(workflow: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any) -> None:
"""
Specify `bundle workflow` to run monai bundle components and workflows.
The workflow should be suclass of `BundleWorkflow` and be available to import.
It can be MONAI existing bundle workflows or user customized workflows.
Typical usage examples:
.. code-block:: bash
# Execute this module as a CLI entry with default ConfigWorkflow:
python -m monai.bundle run_workflow --meta_file <meta path> --config_file <config path>
# Set the workflow to other customized BundleWorkflow subclass:
python -m monai.bundle run_workflow --workflow CustomizedWorkflow ...
Args:
workflow: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
args_file: a JSON or YAML file to provide default values for this API.
so that the command line inputs can be simplified.
kwargs: arguments to instantiate the workflow class.
"""

_args = _update_args(args=args_file, workflow=workflow, **kwargs)
_log_input_summary(tag="run", args=_args)
(workflow_name,) = _pop_args(_args, workflow=ConfigWorkflow) # the default workflow name is "ConfigWorkflow"
if isinstance(workflow_name, str):
workflow_class, has_built_in = optional_import("monai.bundle", name=str(workflow_name)) # search built-in
if not has_built_in:
workflow_class = locate(str(workflow_name)) # search dotted path
if workflow_class is None:
raise ValueError(f"cannot locate specified workflow class: {workflow_name}.")
elif issubclass(workflow_name, BundleWorkflow):
workflow_class = workflow_name
else:
raise ValueError(
"Argument `workflow` must be a bundle workflow class name"
f"or subclass of BundleWorkflow, got: {workflow_name}."
)

workflow_ = workflow_class(**_args)
workflow_.initialize()
workflow_.run()
workflow_.finalize()


def verify_metadata(
meta_file: str | Sequence[str] | None = None,
filepath: PathLike | None = None,
Expand Down
2 changes: 1 addition & 1 deletion monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class ConfigWorkflow(BundleWorkflow):
other unsupported string will raise a ValueError.
default to `None` for common workflow.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--net#input_chns 42``.
e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``
"""

Expand Down
127 changes: 127 additions & 0 deletions tests/nonconfig_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch

from monai.bundle import BundleWorkflow
from monai.data import DataLoader, Dataset
from monai.engines import SupervisedEvaluator
from monai.inferers import SlidingWindowInferer
from monai.networks.nets import UNet
from monai.transforms import (
Activationsd,
AsDiscreted,
Compose,
EnsureChannelFirstd,
LoadImaged,
SaveImaged,
ScaleIntensityd,
)
from monai.utils import BundleProperty, set_determinism


class NonConfigWorkflow(BundleWorkflow):
"""
Test class simulates the bundle workflow defined by Python script directly.
"""

def __init__(self, filename, output_dir):
super().__init__(workflow="inference")
self.filename = filename
self.output_dir = output_dir
self._bundle_root = "will override"
self._device = torch.device("cpu")
self._network_def = None
self._inferer = None
self._preprocessing = None
self._postprocessing = None
self._evaluator = None

def initialize(self):
set_determinism(0)
if self._preprocessing is None:
self._preprocessing = Compose(
[LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), ScaleIntensityd(keys="image")]
)
dataset = Dataset(data=[{"image": self.filename}], transform=self._preprocessing)
dataloader = DataLoader(dataset, batch_size=1, num_workers=4)

if self._network_def is None:
self._network_def = UNet(
spatial_dims=3,
in_channels=1,
out_channels=2,
channels=[2, 2, 4, 8, 4],
strides=[2, 2, 2, 2],
num_res_units=2,
norm="batch",
)
if self._inferer is None:
self._inferer = SlidingWindowInferer(roi_size=(64, 64, 32), sw_batch_size=4, overlap=0.25)

if self._postprocessing is None:
self._postprocessing = Compose(
[
Activationsd(keys="pred", softmax=True),
AsDiscreted(keys="pred", argmax=True),
SaveImaged(keys="pred", output_dir=self.output_dir, output_postfix="seg"),
]
)

self._evaluator = SupervisedEvaluator(
device=self._device,
val_data_loader=dataloader,
network=self._network_def.to(self._device),
inferer=self._inferer,
postprocessing=self._postprocessing,
amp=False,
)

def run(self):
self._evaluator.run()

def finalize(self):
return True

def _get_property(self, name, property):
if name == "bundle_root":
return self._bundle_root
if name == "device":
return self._device
if name == "network_def":
return self._network_def
if name == "inferer":
return self._inferer
if name == "preprocessing":
return self._preprocessing
if name == "postprocessing":
return self._postprocessing
if property[BundleProperty.REQUIRED]:
raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")

def _set_property(self, name, property, value):
if name == "bundle_root":
self._bundle_root = value
elif name == "device":
self._device = value
elif name == "network_def":
self._network_def = value
elif name == "inferer":
self._inferer = value
elif name == "preprocessing":
self._preprocessing = value
elif name == "postprocessing":
self._postprocessing = value
elif property[BundleProperty.REQUIRED]:
raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")
115 changes: 4 additions & 111 deletions tests/test_bundle_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,12 @@
import torch
from parameterized import parameterized

from monai.bundle import BundleWorkflow, ConfigWorkflow
from monai.data import DataLoader, Dataset
from monai.engines import SupervisedEvaluator
from monai.bundle import ConfigWorkflow
from monai.data import Dataset
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.networks.nets import UNet
from monai.transforms import (
Activationsd,
AsDiscreted,
Compose,
EnsureChannelFirstd,
LoadImage,
LoadImaged,
SaveImaged,
ScaleIntensityd,
)
from monai.utils import BundleProperty, set_determinism
from monai.transforms import Compose, LoadImage
from tests.nonconfig_workflow import NonConfigWorkflow

TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")]

Expand All @@ -46,103 +36,6 @@
TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")]


class NonConfigWorkflow(BundleWorkflow):
"""
Test class simulates the bundle workflow defined by Python script directly.
"""

def __init__(self, filename, output_dir):
super().__init__(workflow="inference")
self.filename = filename
self.output_dir = output_dir
self._bundle_root = "will override"
self._device = torch.device("cpu")
self._network_def = None
self._inferer = None
self._preprocessing = None
self._postprocessing = None
self._evaluator = None

def initialize(self):
set_determinism(0)
if self._preprocessing is None:
self._preprocessing = Compose(
[LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), ScaleIntensityd(keys="image")]
)
dataset = Dataset(data=[{"image": self.filename}], transform=self._preprocessing)
dataloader = DataLoader(dataset, batch_size=1, num_workers=4)

if self._network_def is None:
self._network_def = UNet(
spatial_dims=3,
in_channels=1,
out_channels=2,
channels=[2, 2, 4, 8, 4],
strides=[2, 2, 2, 2],
num_res_units=2,
norm="batch",
)
if self._inferer is None:
self._inferer = SlidingWindowInferer(roi_size=(64, 64, 32), sw_batch_size=4, overlap=0.25)

if self._postprocessing is None:
self._postprocessing = Compose(
[
Activationsd(keys="pred", softmax=True),
AsDiscreted(keys="pred", argmax=True),
SaveImaged(keys="pred", output_dir=self.output_dir, output_postfix="seg"),
]
)

self._evaluator = SupervisedEvaluator(
device=self._device,
val_data_loader=dataloader,
network=self._network_def.to(self._device),
inferer=self._inferer,
postprocessing=self._postprocessing,
amp=False,
)

def run(self):
self._evaluator.run()

def finalize(self):
return True

def _get_property(self, name, property):
if name == "bundle_root":
return self._bundle_root
if name == "device":
return self._device
if name == "network_def":
return self._network_def
if name == "inferer":
return self._inferer
if name == "preprocessing":
return self._preprocessing
if name == "postprocessing":
return self._postprocessing
if property[BundleProperty.REQUIRED]:
raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")

def _set_property(self, name, property, value):
if name == "bundle_root":
self._bundle_root = value
elif name == "device":
self._device = value
elif name == "network_def":
self._network_def = value
elif name == "inferer":
self._inferer = value
elif name == "preprocessing":
self._preprocessing = value
elif name == "postprocessing":
self._postprocessing = value
elif property[BundleProperty.REQUIRED]:
raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")


class TestBundleWorkflow(unittest.TestCase):
def setUp(self):
self.data_dir = tempfile.mkdtemp()
Expand Down
Loading

0 comments on commit 6aa4f90

Please sign in to comment.