Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

5821 Enhance bundle CLI entry for different bundle workflows #6181

Merged
merged 20 commits into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions monai/bundle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
init_bundle,
load,
run,
run_workflow,
trt_export,
verify_metadata,
verify_net_in_out,
Expand Down
11 changes: 10 additions & 1 deletion monai/bundle/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@

from __future__ import annotations

from monai.bundle.scripts import ckpt_export, download, init_bundle, run, trt_export, verify_metadata, verify_net_in_out
from monai.bundle.scripts import (
ckpt_export,
download,
init_bundle,
run,
run_workflow,
trt_export,
verify_metadata,
verify_net_in_out,
)

if __name__ == "__main__":
from monai.utils import optional_import
Expand Down
52 changes: 50 additions & 2 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import warnings
from collections.abc import Mapping, Sequence
from pathlib import Path
from pydoc import locate
from shutil import copyfile
from textwrap import dedent
from typing import Any, Callable
Expand All @@ -30,7 +31,7 @@
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
from monai.bundle.workflows import ConfigWorkflow
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
from monai.config import IgniteInfo, PathLike
from monai.data import load_net_with_metadata, save_net_with_metadata
from monai.networks import convert_to_torchscript, convert_to_trt, copy_model_state, get_state_dict, save_state
Expand Down Expand Up @@ -635,7 +636,7 @@ def run(
args_file: a JSON or YAML file to provide default values for `runner_id`, `meta_file`,
`config_file`, `logging`, and override pairs. so that the command line inputs can be simplified.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--net#input_chns 42``.
e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``.

"""

Expand Down Expand Up @@ -678,6 +679,53 @@ def run(
workflow.finalize()


def run_workflow(workflow: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any) -> None:
"""
Specify `bundle workflow` to run monai bundle components and workflows.
Nic-Ma marked this conversation as resolved.
Show resolved Hide resolved
The workflow should be suclass of `BundleWorkflow` and be available to import.
It can be MONAI existing bundle workflows or user customized workflows.

Typical usage examples:

.. code-block:: bash

# Execute this module as a CLI entry with default ConfigWorkflow:
python -m monai.bundle run_workflow --meta_file <meta path> --config_file <config path>

# Set the workflow to other customized BundleWorkflow subclass:
python -m monai.bundle run_workflow --workflow CustomizedWorkflow ...

Args:
workflow: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
args_file: a JSON or YAML file to provide default values for this API.
so that the command line inputs can be simplified.
kwargs: arguments to instantiate the workflow class.

"""

_args = _update_args(args=args_file, workflow=workflow, **kwargs)
_log_input_summary(tag="run", args=_args)
(workflow_name,) = _pop_args(_args, workflow=ConfigWorkflow) # the default workflow name is "ConfigWorkflow"
Nic-Ma marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(workflow_name, str):
workflow_class, has_built_in = optional_import("monai.bundle", name=str(workflow_name)) # search built-in
if not has_built_in:
workflow_class = locate(str(workflow_name)) # search dotted path
if workflow_class is None:
raise ValueError(f"cannot locate specified workflow class: {workflow_name}.")
elif issubclass(workflow_name, BundleWorkflow):
workflow_class = workflow_name
else:
raise ValueError(
"Argument `workflow` must be a bundle workflow class name"
f"or subclass of BundleWorkflow, got: {workflow_name}."
)

workflow_ = workflow_class(**_args)
workflow_.initialize()
workflow_.run()
workflow_.finalize()


def verify_metadata(
meta_file: str | Sequence[str] | None = None,
filepath: PathLike | None = None,
Expand Down
2 changes: 1 addition & 1 deletion monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class ConfigWorkflow(BundleWorkflow):
other unsupported string will raise a ValueError.
default to `None` for common workflow.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--net#input_chns 42``.
e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``
"""

Expand Down
127 changes: 127 additions & 0 deletions tests/nonconfig_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch

from monai.bundle import BundleWorkflow
from monai.data import DataLoader, Dataset
from monai.engines import SupervisedEvaluator
from monai.inferers import SlidingWindowInferer
from monai.networks.nets import UNet
from monai.transforms import (
Activationsd,
AsDiscreted,
Compose,
EnsureChannelFirstd,
LoadImaged,
SaveImaged,
ScaleIntensityd,
)
from monai.utils import BundleProperty, set_determinism


class NonConfigWorkflow(BundleWorkflow):
"""
Test class simulates the bundle workflow defined by Python script directly.

"""

def __init__(self, filename, output_dir):
super().__init__(workflow="inference")
self.filename = filename
self.output_dir = output_dir
self._bundle_root = "will override"
self._device = torch.device("cpu")
self._network_def = None
self._inferer = None
self._preprocessing = None
self._postprocessing = None
self._evaluator = None

def initialize(self):
set_determinism(0)
if self._preprocessing is None:
self._preprocessing = Compose(
[LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), ScaleIntensityd(keys="image")]
)
dataset = Dataset(data=[{"image": self.filename}], transform=self._preprocessing)
dataloader = DataLoader(dataset, batch_size=1, num_workers=4)

if self._network_def is None:
self._network_def = UNet(
spatial_dims=3,
in_channels=1,
out_channels=2,
channels=[2, 2, 4, 8, 4],
strides=[2, 2, 2, 2],
num_res_units=2,
norm="batch",
)
if self._inferer is None:
self._inferer = SlidingWindowInferer(roi_size=(64, 64, 32), sw_batch_size=4, overlap=0.25)

if self._postprocessing is None:
self._postprocessing = Compose(
[
Activationsd(keys="pred", softmax=True),
AsDiscreted(keys="pred", argmax=True),
SaveImaged(keys="pred", output_dir=self.output_dir, output_postfix="seg"),
]
)

self._evaluator = SupervisedEvaluator(
device=self._device,
val_data_loader=dataloader,
network=self._network_def.to(self._device),
inferer=self._inferer,
postprocessing=self._postprocessing,
amp=False,
)

def run(self):
self._evaluator.run()

def finalize(self):
return True

def _get_property(self, name, property):
if name == "bundle_root":
return self._bundle_root
if name == "device":
return self._device
if name == "network_def":
return self._network_def
if name == "inferer":
return self._inferer
if name == "preprocessing":
return self._preprocessing
if name == "postprocessing":
return self._postprocessing
if property[BundleProperty.REQUIRED]:
raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")

def _set_property(self, name, property, value):
if name == "bundle_root":
self._bundle_root = value
elif name == "device":
self._device = value
elif name == "network_def":
self._network_def = value
elif name == "inferer":
self._inferer = value
elif name == "preprocessing":
self._preprocessing = value
elif name == "postprocessing":
self._postprocessing = value
elif property[BundleProperty.REQUIRED]:
raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")
115 changes: 4 additions & 111 deletions tests/test_bundle_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,12 @@
import torch
from parameterized import parameterized

from monai.bundle import BundleWorkflow, ConfigWorkflow
from monai.data import DataLoader, Dataset
from monai.engines import SupervisedEvaluator
from monai.bundle import ConfigWorkflow
from monai.data import Dataset
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.networks.nets import UNet
from monai.transforms import (
Activationsd,
AsDiscreted,
Compose,
EnsureChannelFirstd,
LoadImage,
LoadImaged,
SaveImaged,
ScaleIntensityd,
)
from monai.utils import BundleProperty, set_determinism
from monai.transforms import Compose, LoadImage
from tests.nonconfig_workflow import NonConfigWorkflow

TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")]

Expand All @@ -46,103 +36,6 @@
TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")]


class NonConfigWorkflow(BundleWorkflow):
"""
Test class simulates the bundle workflow defined by Python script directly.
"""

def __init__(self, filename, output_dir):
super().__init__(workflow="inference")
self.filename = filename
self.output_dir = output_dir
self._bundle_root = "will override"
self._device = torch.device("cpu")
self._network_def = None
self._inferer = None
self._preprocessing = None
self._postprocessing = None
self._evaluator = None

def initialize(self):
set_determinism(0)
if self._preprocessing is None:
self._preprocessing = Compose(
[LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), ScaleIntensityd(keys="image")]
)
dataset = Dataset(data=[{"image": self.filename}], transform=self._preprocessing)
dataloader = DataLoader(dataset, batch_size=1, num_workers=4)

if self._network_def is None:
self._network_def = UNet(
spatial_dims=3,
in_channels=1,
out_channels=2,
channels=[2, 2, 4, 8, 4],
strides=[2, 2, 2, 2],
num_res_units=2,
norm="batch",
)
if self._inferer is None:
self._inferer = SlidingWindowInferer(roi_size=(64, 64, 32), sw_batch_size=4, overlap=0.25)

if self._postprocessing is None:
self._postprocessing = Compose(
[
Activationsd(keys="pred", softmax=True),
AsDiscreted(keys="pred", argmax=True),
SaveImaged(keys="pred", output_dir=self.output_dir, output_postfix="seg"),
]
)

self._evaluator = SupervisedEvaluator(
device=self._device,
val_data_loader=dataloader,
network=self._network_def.to(self._device),
inferer=self._inferer,
postprocessing=self._postprocessing,
amp=False,
)

def run(self):
self._evaluator.run()

def finalize(self):
return True

def _get_property(self, name, property):
if name == "bundle_root":
return self._bundle_root
if name == "device":
return self._device
if name == "network_def":
return self._network_def
if name == "inferer":
return self._inferer
if name == "preprocessing":
return self._preprocessing
if name == "postprocessing":
return self._postprocessing
if property[BundleProperty.REQUIRED]:
raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")

def _set_property(self, name, property, value):
if name == "bundle_root":
self._bundle_root = value
elif name == "device":
self._device = value
elif name == "network_def":
self._network_def = value
elif name == "inferer":
self._inferer = value
elif name == "preprocessing":
self._preprocessing = value
elif name == "postprocessing":
self._postprocessing = value
elif property[BundleProperty.REQUIRED]:
raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")


class TestBundleWorkflow(unittest.TestCase):
def setUp(self):
self.data_dir = tempfile.mkdtemp()
Expand Down
Loading