Skip to content

Commit

Permalink
5821 6303 Optimize MonaiAlgo FL based on BundleWorkflow (#6158)
Browse files Browse the repository at this point in the history
part of #5821 
Fixes #6303 

### Description

This PR simplified the MONAI FL `MonaiAlgo` module to leverage
`BundleWorkflow`.
The main point is to decouple the bundle read / write related logic with
FL module and use predefined required-properties.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Nic Ma <nma@nvidia.com>
Signed-off-by: monai-bot <monai.miccai2019@gmail.com>
Co-authored-by: Holger Roth <hroth@nvidia.com>
Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
  • Loading branch information
3 people authored Apr 19, 2023
1 parent c30a5b9 commit d8eb68a
Show file tree
Hide file tree
Showing 15 changed files with 361 additions and 424 deletions.
5 changes: 5 additions & 0 deletions monai/bundle/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@
BundleProperty.REQUIRED: True,
BundlePropertyConfig.ID: "device",
},
"evaluator": {
BundleProperty.DESC: "inference / evaluation workflow engine.",
BundleProperty.REQUIRED: True,
BundlePropertyConfig.ID: "evaluator",
},
"network_def": {
BundleProperty.DESC: "network module for the inference.",
BundleProperty.REQUIRED: True,
Expand Down
12 changes: 9 additions & 3 deletions monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,18 @@ class BundleWorkflow(ABC):
"""

supported_train_type: tuple = ("train", "training")
supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation")

def __init__(self, workflow: str | None = None):
if workflow is None:
self.properties = None
self.workflow = None
return
if workflow.lower() in ("train", "training"):
if workflow.lower() in self.supported_train_type:
self.properties = TrainProperties
self.workflow = "train"
elif workflow.lower() in ("infer", "inference", "eval", "evaluation"):
elif workflow.lower() in self.supported_infer_type:
self.properties = InferProperties
self.workflow = "infer"
else:
Expand Down Expand Up @@ -215,6 +218,7 @@ def __init__(
else:
settings_ = ConfigParser.load_config_files(tracking)
self.patch_bundle_tracking(parser=self.parser, settings=settings_)
self._is_initialized: bool = False

def initialize(self) -> Any:
"""
Expand All @@ -223,6 +227,7 @@ def initialize(self) -> Any:
"""
# reset the "reference_resolver" buffer at initialization stage
self.parser.parse(reset=True)
self._is_initialized = True
return self._run_expr(id=self.init_id)

def run(self) -> Any:
Expand Down Expand Up @@ -284,7 +289,7 @@ def _get_property(self, name: str, property: dict) -> Any:
property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
"""
if not self.parser.ref_resolver.is_resolved():
if not self._is_initialized:
raise RuntimeError("Please execute 'initialize' before getting any parsed content.")
prop_id = self._get_prop_id(name, property)
return self.parser.get_parsed_content(id=prop_id) if prop_id is not None else None
Expand All @@ -303,6 +308,7 @@ def _set_property(self, name: str, property: dict, value: Any) -> None:
if prop_id is not None:
self.parser[prop_id] = value
# must parse the config again after changing the content
self._is_initialized = False
self.parser.ref_resolver.reset()

def _check_optional_id(self, name: str, property: dict) -> bool:
Expand Down
335 changes: 151 additions & 184 deletions monai/fl/client/monai_algo.py

Large diffs are not rendered by default.

14 changes: 0 additions & 14 deletions monai/fl/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,6 @@ class FlStatistics(StrEnum):
FEATURE_NAMES = "feature_names"


class RequiredBundleKeys(StrEnum):
BUNDLE_ROOT = "bundle_root"


class BundleKeys(StrEnum):
TRAINER = "train#trainer"
EVALUATOR = "validate#evaluator"
TRAIN_TRAINER_MAX_EPOCHS = "train#trainer#max_epochs"
VALIDATE_HANDLERS = "validate#handlers"
DATASET_DIR = "dataset_dir"
TRAIN_DATA = "train#dataset#data"
VALID_DATA = "validate#dataset#data"


class FiltersType(StrEnum):
PRE_FILTERS = "pre_filters"
POST_WEIGHT_FILTERS = "post_weight_filters"
Expand Down
2 changes: 1 addition & 1 deletion monai/fl/utils/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __call__(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeO

class SummaryFilter(Filter):
"""
Summary filter to content of ExchangeObject.
Summary filter to show content of ExchangeObject.
"""

def __call__(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeObject:
Expand Down
4 changes: 4 additions & 0 deletions tests/nonconfig_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def _get_property(self, name, property):
return self._bundle_root
if name == "device":
return self._device
if name == "evaluator":
return self._evaluator
if name == "network_def":
return self._network_def
if name == "inferer":
Expand All @@ -115,6 +117,8 @@ def _set_property(self, name, property, value):
self._bundle_root = value
elif name == "device":
self._device = value
elif name == "evaluator":
self._evaluator = value
elif name == "network_def":
self._network_def = value
elif name == "inferer":
Expand Down
155 changes: 75 additions & 80 deletions tests/test_fl_monai_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@

import os
import shutil
import tempfile
import unittest
from copy import deepcopy
from os.path import join as pathjoin

from parameterized import parameterized

from monai.bundle import ConfigParser
from monai.bundle import ConfigParser, ConfigWorkflow
from monai.bundle.utils import DEFAULT_HANDLERS_ID
from monai.fl.client.monai_algo import MonaiAlgo
from monai.fl.utils.constants import ExtraItems
Expand All @@ -28,11 +29,14 @@

_root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__)))
_data_dir = os.path.join(_root_dir, "testing_data")
_logging_file = pathjoin(_data_dir, "logging.conf")

TEST_TRAIN_1 = [
{
"bundle_root": _data_dir,
"config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
"train_workflow": ConfigWorkflow(
config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
),
"config_evaluate_filename": None,
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
Expand All @@ -48,68 +52,92 @@
TEST_TRAIN_3 = [
{
"bundle_root": _data_dir,
"config_train_filename": [
os.path.join(_data_dir, "config_fl_train.json"),
os.path.join(_data_dir, "config_fl_train.json"),
],
"train_workflow": ConfigWorkflow(
config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
),
"config_evaluate_filename": None,
"config_filters_filename": [
os.path.join(_data_dir, "config_fl_filters.json"),
os.path.join(_data_dir, "config_fl_filters.json"),
],
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]

TEST_TRAIN_4 = [
{
"bundle_root": _data_dir,
"train_workflow": ConfigWorkflow(
config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
),
"config_evaluate_filename": None,
"tracking": {
"handlers_id": DEFAULT_HANDLERS_ID,
"configs": {
"execute_config": f"{_data_dir}/config_executed.json",
"trainer": {
"_target_": "MLFlowHandler",
"tracking_uri": path_to_uri(_data_dir) + "/mlflow_override",
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)",
"close_on_complete": True,
},
},
},
"config_filters_filename": None,
}
]

TEST_EVALUATE_1 = [
{
"bundle_root": _data_dir,
"config_train_filename": None,
"config_evaluate_filename": os.path.join(_data_dir, "config_fl_evaluate.json"),
"eval_workflow": ConfigWorkflow(
config_file=[
os.path.join(_data_dir, "config_fl_train.json"),
os.path.join(_data_dir, "config_fl_evaluate.json"),
],
workflow="train",
logging_file=_logging_file,
),
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]
TEST_EVALUATE_2 = [
{
"bundle_root": _data_dir,
"config_train_filename": None,
"config_evaluate_filename": os.path.join(_data_dir, "config_fl_evaluate.json"),
"config_evaluate_filename": [
os.path.join(_data_dir, "config_fl_train.json"),
os.path.join(_data_dir, "config_fl_evaluate.json"),
],
"eval_workflow_name": "training",
"config_filters_filename": None,
}
]
TEST_EVALUATE_3 = [
{
"bundle_root": _data_dir,
"config_train_filename": None,
"config_evaluate_filename": [
os.path.join(_data_dir, "config_fl_evaluate.json"),
os.path.join(_data_dir, "config_fl_evaluate.json"),
],
"config_filters_filename": [
os.path.join(_data_dir, "config_fl_filters.json"),
os.path.join(_data_dir, "config_fl_filters.json"),
],
"eval_workflow": ConfigWorkflow(
config_file=[
os.path.join(_data_dir, "config_fl_train.json"),
os.path.join(_data_dir, "config_fl_evaluate.json"),
],
workflow="train",
logging_file=_logging_file,
),
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]

TEST_GET_WEIGHTS_1 = [
{
"bundle_root": _data_dir,
"config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
"train_workflow": ConfigWorkflow(
config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
),
"config_evaluate_filename": None,
"send_weight_diff": False,
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]
TEST_GET_WEIGHTS_2 = [
{
"bundle_root": _data_dir,
"config_train_filename": None,
"config_evaluate_filename": None,
"send_weight_diff": False,
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]
TEST_GET_WEIGHTS_3 = [
{
"bundle_root": _data_dir,
"config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
Expand All @@ -118,59 +146,31 @@
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]
TEST_GET_WEIGHTS_4 = [
TEST_GET_WEIGHTS_3 = [
{
"bundle_root": _data_dir,
"config_train_filename": [
os.path.join(_data_dir, "config_fl_train.json"),
os.path.join(_data_dir, "config_fl_train.json"),
],
"train_workflow": ConfigWorkflow(
config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
),
"config_evaluate_filename": None,
"send_weight_diff": True,
"config_filters_filename": [
os.path.join(_data_dir, "config_fl_filters.json"),
os.path.join(_data_dir, "config_fl_filters.json"),
],
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]


@SkipIfNoModule("ignite")
@SkipIfNoModule("mlflow")
class TestFLMonaiAlgo(unittest.TestCase):
@parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3])
@parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3, TEST_TRAIN_4])
def test_train(self, input_params):
# get testing data dir and update train config; using the first to define data dir
if isinstance(input_params["config_train_filename"], list):
config_train_filename = [
os.path.join(input_params["bundle_root"], x) for x in input_params["config_train_filename"]
]
else:
config_train_filename = os.path.join(input_params["bundle_root"], input_params["config_train_filename"])

data_dir = tempfile.mkdtemp()
# test experiment management
input_params["tracking"] = {
"handlers_id": DEFAULT_HANDLERS_ID,
"configs": {
"execute_config": f"{data_dir}/config_executed.json",
"trainer": {
"_target_": "MLFlowHandler",
"tracking_uri": path_to_uri(data_dir) + "/mlflow_override",
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)",
"close_on_complete": True,
},
},
}

# initialize algo
algo = MonaiAlgo(**input_params)
algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})
algo.abort()

# initialize model
parser = ConfigParser()
parser.read_config(config_train_filename)
parser = ConfigParser(config=deepcopy(algo.train_workflow.parser.get()))
parser.parse()
network = parser.get_parsed_content("network")

Expand All @@ -179,27 +179,22 @@ def test_train(self, input_params):
# test train
algo.train(data=data, extra={})
algo.finalize()
self.assertTrue(os.path.exists(f"{data_dir}/mlflow_override"))
self.assertTrue(os.path.exists(f"{data_dir}/config_executed.json"))
shutil.rmtree(data_dir)

# test experiment management
if "execute_config" in algo.train_workflow.parser:
self.assertTrue(os.path.exists(f"{_data_dir}/mlflow_override"))
shutil.rmtree(f"{_data_dir}/mlflow_override")
self.assertTrue(os.path.exists(f"{_data_dir}/config_executed.json"))
os.remove(f"{_data_dir}/config_executed.json")

@parameterized.expand([TEST_EVALUATE_1, TEST_EVALUATE_2, TEST_EVALUATE_3])
def test_evaluate(self, input_params):
# get testing data dir and update train config; using the first to define data dir
if isinstance(input_params["config_evaluate_filename"], list):
config_eval_filename = [
os.path.join(input_params["bundle_root"], x) for x in input_params["config_evaluate_filename"]
]
else:
config_eval_filename = os.path.join(input_params["bundle_root"], input_params["config_evaluate_filename"])

# initialize algo
algo = MonaiAlgo(**input_params)
algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})

# initialize model
parser = ConfigParser()
parser.read_config(config_eval_filename)
parser = ConfigParser(config=deepcopy(algo.eval_workflow.parser.get()))
parser.parse()
network = parser.get_parsed_content("network")

Expand All @@ -208,7 +203,7 @@ def test_evaluate(self, input_params):
# test evaluate
algo.evaluate(data=data, extra={})

@parameterized.expand([TEST_GET_WEIGHTS_1, TEST_GET_WEIGHTS_2, TEST_GET_WEIGHTS_3, TEST_GET_WEIGHTS_4])
@parameterized.expand([TEST_GET_WEIGHTS_1, TEST_GET_WEIGHTS_2, TEST_GET_WEIGHTS_3])
def test_get_weights(self, input_params):
# initialize algo
algo = MonaiAlgo(**input_params)
Expand Down
Loading

0 comments on commit d8eb68a

Please sign in to comment.