diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py
index f172f1a824..4b10f71a11 100644
--- a/monai/bundle/__init__.py
+++ b/monai/bundle/__init__.py
@@ -13,6 +13,7 @@
 
 from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable
 from .config_parser import ConfigParser
+from .properties import InferProperties, TrainProperties
 from .reference_resolver import ReferenceResolver
 from .scripts import (
     ckpt_export,
@@ -22,7 +23,6 @@
     get_bundle_versions,
     init_bundle,
     load,
-    patch_bundle_tracking,
     run,
     verify_metadata,
     verify_net_in_out,
@@ -36,3 +36,4 @@
     MACRO_KEY,
     load_bundle_config,
 )
+from .workflows import BundleWorkflow, ConfigWorkflow
diff --git a/monai/bundle/properties.py b/monai/bundle/properties.py
new file mode 100644
index 0000000000..33cfd876eb
--- /dev/null
+++ b/monai/bundle/properties.py
@@ -0,0 +1,192 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+The predefined properties of a bundle workflow; other applications can leverage these properties
+to interact with the bundle workflow.
+Some properties are required and some are optional. An optional property means: if any component of
+the bundle workflow refers to the property, the property must be defined; otherwise, it can be None.
+Every item in the `TrainProperties` or `InferProperties` dictionary is a property;
+the key is the property name and the value includes:
+1. description.
+2. whether it's a required property.
+3. config item ID name (only applicable when the bundle workflow is defined in config).
+4. reference config item ID name (only applicable when the bundle workflow is defined in config).
+
+"""
+
+from __future__ import annotations
+
+from monai.bundle.utils import ID_SEP_KEY
+from monai.utils import BundleProperty, BundlePropertyConfig
+
+TrainProperties = {
+    "bundle_root": {
+        BundleProperty.DESC: "root path of the bundle.",
+        BundleProperty.REQUIRED: True,
+        BundlePropertyConfig.ID: "bundle_root",
+    },
+    "device": {
+        BundleProperty.DESC: "target device to execute the bundle workflow.",
+        BundleProperty.REQUIRED: True,
+        BundlePropertyConfig.ID: "device",
+    },
+    "dataset_dir": {
+        BundleProperty.DESC: "directory path of the dataset.",
+        BundleProperty.REQUIRED: True,
+        BundlePropertyConfig.ID: "dataset_dir",
+    },
+    "trainer": {
+        BundleProperty.DESC: "training workflow engine.",
+        BundleProperty.REQUIRED: True,
+        BundlePropertyConfig.ID: f"train{ID_SEP_KEY}trainer",
+    },
+    "max_epochs": {
+        BundleProperty.DESC: "max number of epochs to execute the training.",
+        BundleProperty.REQUIRED: True,
+        BundlePropertyConfig.ID: f"train{ID_SEP_KEY}trainer{ID_SEP_KEY}max_epochs",
+    },
+    "train_dataset": {
+        BundleProperty.DESC: "PyTorch dataset object for the training logic.",
+        BundleProperty.REQUIRED: True,
+        BundlePropertyConfig.ID: f"train{ID_SEP_KEY}dataset",
+    },
+    "train_dataset_data": {
+        BundleProperty.DESC: "data source for the training dataset.",
+        BundleProperty.REQUIRED: True,
+        BundlePropertyConfig.ID: f"train{ID_SEP_KEY}dataset{ID_SEP_KEY}data",
+    },
+    "train_inferer": {
+        BundleProperty.DESC: "MONAI Inferer object to execute the model computation in training.",
+        BundleProperty.REQUIRED: True,
+        BundlePropertyConfig.ID: f"train{ID_SEP_KEY}inferer",
+    },
+    "train_handlers": {
+        BundleProperty.DESC: "event-handlers for the training logic.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: f"train{ID_SEP_KEY}handlers",
+        BundlePropertyConfig.REF_ID: f"train{ID_SEP_KEY}trainer{ID_SEP_KEY}train_handlers",
+    },
+    "train_preprocessing": {
+        BundleProperty.DESC: "preprocessing for the training input data.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: f"train{ID_SEP_KEY}preprocessing",
+        BundlePropertyConfig.REF_ID: f"train{ID_SEP_KEY}dataset{ID_SEP_KEY}transform",
+    },
+    "train_postprocessing": {
+        BundleProperty.DESC: "postprocessing for the training model output data.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: f"train{ID_SEP_KEY}postprocessing",
+        BundlePropertyConfig.REF_ID: f"train{ID_SEP_KEY}trainer{ID_SEP_KEY}postprocessing",
+    },
+    "train_key_metric": {
+        BundleProperty.DESC: "key metric to compute on the training data.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: f"train{ID_SEP_KEY}key_metric",
+        BundlePropertyConfig.REF_ID: f"train{ID_SEP_KEY}trainer{ID_SEP_KEY}key_train_metric",
+    },
+    "evaluator": {
+        BundleProperty.DESC: "validation workflow engine.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}evaluator",
+        BundlePropertyConfig.REF_ID: "validator",  # this REF_ID is the arg name of `ValidationHandler`
+    },
+    "val_interval": {
+        BundleProperty.DESC: "validation interval during the training.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: "val_interval",
+        BundlePropertyConfig.REF_ID: "interval",  # this REF_ID is the arg name of `ValidationHandler`
+    },
+    "val_handlers": {
+        BundleProperty.DESC: "event-handlers for the validation logic.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}handlers",
+        BundlePropertyConfig.REF_ID: f"validate{ID_SEP_KEY}evaluator{ID_SEP_KEY}val_handlers",
+    },
+    "val_dataset": {
+        BundleProperty.DESC: "PyTorch dataset object for the validation logic.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}dataset",
+        BundlePropertyConfig.REF_ID: f"validate{ID_SEP_KEY}dataloader{ID_SEP_KEY}dataset",
+    },
+    "val_dataset_data": {
+        BundleProperty.DESC: "data source for the validation dataset.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}dataset{ID_SEP_KEY}data",
+        BundlePropertyConfig.REF_ID: None,  # no reference to this ID
+    },
+    "val_inferer": {
+        BundleProperty.DESC: "MONAI Inferer object to execute the model computation in validation.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}inferer",
+        BundlePropertyConfig.REF_ID: f"validate{ID_SEP_KEY}evaluator{ID_SEP_KEY}inferer",
+    },
+    "val_preprocessing": {
+        BundleProperty.DESC: "preprocessing for the validation input data.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}preprocessing",
+        BundlePropertyConfig.REF_ID: f"validate{ID_SEP_KEY}dataset{ID_SEP_KEY}transform",
+    },
+    "val_postprocessing": {
+        BundleProperty.DESC: "postprocessing for the validation model output data.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}postprocessing",
+        BundlePropertyConfig.REF_ID: f"validate{ID_SEP_KEY}evaluator{ID_SEP_KEY}postprocessing",
+    },
+    "val_key_metric": {
+        BundleProperty.DESC: "key metric to compute on the validation data.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: f"validate{ID_SEP_KEY}key_metric",
+        BundlePropertyConfig.REF_ID: f"validate{ID_SEP_KEY}evaluator{ID_SEP_KEY}key_val_metric",
+    },
+}
+
+
+InferProperties = {
+    "bundle_root": {
+        BundleProperty.DESC: "root path of the bundle.",
+        BundleProperty.REQUIRED: True,
+        BundlePropertyConfig.ID: "bundle_root",
+    },
+    "device": {
+        BundleProperty.DESC: "target device to execute the bundle workflow.",
+        BundleProperty.REQUIRED: True,
+        BundlePropertyConfig.ID: "device",
+    },
+    "network_def": {
+        BundleProperty.DESC: "network module for the inference.",
+        BundleProperty.REQUIRED: True,
+        BundlePropertyConfig.ID: "network_def",
+    },
+    "inferer": {
+        BundleProperty.DESC: "MONAI Inferer object to execute the model computation in inference.",
+        BundleProperty.REQUIRED: True,
+        BundlePropertyConfig.ID: "inferer",
+    },
+    "preprocessing": {
+        BundleProperty.DESC: "preprocessing for the input data.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: "preprocessing",
+        BundlePropertyConfig.REF_ID: f"dataset{ID_SEP_KEY}transform",
+    },
+    "postprocessing": {
+        BundleProperty.DESC: "postprocessing for the model output data.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: "postprocessing",
+        BundlePropertyConfig.REF_ID: f"evaluator{ID_SEP_KEY}postprocessing",
+    },
+    "key_metric": {
+        BundleProperty.DESC: "the key metric during evaluation.",
+        BundleProperty.REQUIRED: False,
+        BundlePropertyConfig.ID: "key_metric",
+        BundlePropertyConfig.REF_ID: f"evaluator{ID_SEP_KEY}key_val_metric",
+    },
+}
diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py
index 3d4b7288d4..55182e429c 100644
--- a/monai/bundle/scripts.py
+++ b/monai/bundle/scripts.py
@@ -15,10 +15,8 @@
 import json
 import os
 import re
-import time
 import warnings
 from collections.abc import Mapping, Sequence
-from logging.config import fileConfig
 from pathlib import Path
 from shutil import copyfile
 from textwrap import dedent
@@ -31,12 +29,20 @@
 from monai.apps.utils import _basename, download_url, extractall, get_logger
 from monai.bundle.config_item import ConfigComponent
 from monai.bundle.config_parser import ConfigParser
-from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, DEFAULT_INFERENCE, DEFAULT_METADATA
+from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
+from monai.bundle.workflows import ConfigWorkflow
 from monai.config import IgniteInfo, PathLike
 from monai.data import load_net_with_metadata, save_net_with_metadata
 from monai.networks import convert_to_torchscript, copy_model_state, get_state_dict, save_state
-from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import
-from monai.utils.misc import ensure_tuple, pprint_edges
+from monai.utils import (
+    check_parent_dir,
+    deprecated_arg,
+    ensure_tuple,
+    get_equivalent_dtype,
+    min_version,
+    optional_import,
+    pprint_edges,
+)
 
 validate, _ = optional_import("jsonschema", name="validate")
 ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
@@ -573,49 +579,18 @@
     return bundle_info[version]
 
 
-def patch_bundle_tracking(parser: ConfigParser, settings: dict) -> None:
-    """
-    Patch the loaded bundle config with a new handler logic to enable experiment tracking features.
-
-    Args:
-        parser: loaded config content to patch the handler.
-        settings: settings for the experiment tracking, should follow the pattern of default settings.
-
-    """
-    for k, v in settings["configs"].items():
-        if k in settings["handlers_id"]:
-            engine = parser.get(settings["handlers_id"][k]["id"])
-            if engine is not None:
-                handlers = parser.get(settings["handlers_id"][k]["handlers"])
-                if handlers is None:
-                    engine["train_handlers" if k == "trainer" else "val_handlers"] = [v]
-                else:
-                    handlers.append(v)
-        elif k not in parser:
-            parser[k] = v
-    # save the executed config into file
-    default_name = f"config_{time.strftime('%Y%m%d_%H%M%S')}.json"
-    filepath = parser.get("execute_config", None)
-    if filepath is None:
-        if "output_dir" not in parser:
-            # if no "output_dir" in the bundle config, default to "<bundle root>/eval"
-            parser["output_dir"] = "$@bundle_root + '/eval'"
-        # experiment management tools can refer to this config item to track the config info
-        parser["execute_config"] = parser["output_dir"] + f" + '/{default_name}'"
-        filepath = os.path.join(parser.get_parsed_content("output_dir"), default_name)
-    Path(filepath).parent.mkdir(parents=True, exist_ok=True)
-    parser.export_config_file(parser.get(), filepath)
-
-
+@deprecated_arg("runner_id", since="1.1", removed="1.3", new_name="run_id", msg_suffix="please use `run_id` instead.")
 def run(
-    runner_id: str | Sequence[str] | None = None,
+    run_id: str | None = None,
+    init_id: str | None = None,
+    final_id: str | None = None,
     meta_file: str | Sequence[str] | None = None,
     config_file: str | Sequence[str] | None = None,
     logging_file: str | None = None,
     tracking: str | dict | None = None,
     args_file: str | None = None,
     **override: Any,
-) -> list:
+) -> None:
     """
     Specify `config_file` to run monai bundle components and workflows.
 
@@ -640,7 +615,9 @@
         python -m monai.bundle run --args_file "/workspace/data/args.json" --config_file <config path>
 
     Args:
-        runner_id: ID name of the expected config expression to run, can also be a list of IDs to run in order.
+        run_id: ID name of the expected config expression to run, default to "run".
+        init_id: ID name of the expected config expression to initialize before running, default to "initialize".
+        final_id: ID name of the expected config expression to finalize after running, default to "finalize".
         meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.
            Default to "configs/metadata.json", which is commonly used for bundles in MONAI model zoo.
        config_file: filepath of the config file, if `None`, must be provided in `args_file`.
@@ -648,59 +625,13 @@
         logging_file: config file for `logging` module in the program. for more details:
             https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
             Default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo.
-        tracking: enable the experiment tracking feature at runtime with optionally configurable and extensible.
-            if "mlflow", will add `MLFlowHandler` to the parsed bundle with default logging settings,
-            if other string, treat it as file path to load the logging settings, if `dict`,
-            treat it as logging settings, otherwise, use all the default settings.
+        tracking: if not None, enable experiment tracking at runtime; it is optionally configurable and extensible.
+            if "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings,
+            if another string, treat it as a file path to load the tracking settings.
+            if `dict`, treat it as tracking settings.
             will patch the target config content with `tracking handlers` and the top-level items of `configs`.
-            example of customized settings:
-
-            .. code-block:: python
-
-                tracking = {
-                    "handlers_id": {
-                        "trainer": {"id": "train#trainer", "handlers": "train#handlers"},
-                        "validator": {"id": "evaluate#evaluator", "handlers": "evaluate#handlers"},
-                        "evaluator": {"id": "evaluator", "handlers": "handlers"},
-                    },
-                    "configs": {
-                        "tracking_uri": "<path>",
-                        "experiment_name": "monai_experiment",
-                        "run_name": None,
-                        "is_not_rank0": (
-                            "$torch.distributed.is_available() \
-                                and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0"
-                        ),
-                        "trainer": {
-                            "_target_": "MLFlowHandler",
-                            "_disabled_": "@is_not_rank0",
-                            "tracking_uri": "@tracking_uri",
-                            "experiment_name": "@experiment_name",
-                            "run_name": "@run_name",
-                            "iteration_log": True,
-                            "output_transform": "$monai.handlers.from_engine(['loss'], first=True)",
-                            "close_on_complete": True,
-                        },
-                        "validator": {
-                            "_target_": "MLFlowHandler",
-                            "_disabled_": "@is_not_rank0",
-                            "tracking_uri": "@tracking_uri",
-                            "experiment_name": "@experiment_name",
-                            "run_name": "@run_name",
-                            "iteration_log": False,
-                        },
-                        "evaluator": {
-                            "_target_": "MLFlowHandler",
-                            "_disabled_": "@is_not_rank0",
-                            "tracking_uri": "@tracking_uri",
-                            "experiment_name": "@experiment_name",
-                            "run_name": "@run_name",
-                            "iteration_log": False,
-                            "close_on_complete": True,
-                        },
-                    },
-                },
-
+            for detailed usage examples, please check the tutorial:
+            https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/bundle_integrate_mlflow.ipynb.
         args_file: a JSON or YAML file to provide default values for `runner_id`, `meta_file`,
             `config_file`, `logging`, and override pairs. so that the command line inputs can be simplified.
         override: id-value pairs to override or add the corresponding config content.
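Note: with this patch, `run` (body updated in the next hunk) becomes a thin wrapper that builds a
`ConfigWorkflow` and drives its three-stage lifecycle. A minimal sketch of the equivalent programmatic
call, assuming a model-zoo style bundle layout (the config paths below are placeholders):

    from monai.bundle import ConfigWorkflow

    workflow = ConfigWorkflow(
        config_file="configs/inference.json",  # placeholder path
        meta_file="configs/metadata.json",
        logging_file="configs/logging.conf",
        init_id="initialize",
        run_id="run",
        final_id="finalize",
    )
    workflow.initialize()  # parse the config, then run the "initialize" expression if present
    workflow.run()         # run the "run" expression
    workflow.finalize()    # run the "finalize" expression
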
@@ -710,7 +641,9 @@ def run(
 
     _args = _update_args(
         args=args_file,
-        runner_id=runner_id,
+        run_id=run_id,
+        init_id=init_id,
+        final_id=final_id,
         meta_file=meta_file,
         config_file=config_file,
         logging_file=logging_file,
@@ -720,45 +653,29 @@ def run(
     if "config_file" not in _args:
         warnings.warn("`config_file` not provided for 'monai.bundle run'.")
     _log_input_summary(tag="run", args=_args)
-    config_file_, meta_file_, runner_id_, logging_file_, tracking_ = _pop_args(
+    config_file_, meta_file_, init_id_, run_id_, final_id_, logging_file_, tracking_ = _pop_args(
         _args,
         config_file=None,
         meta_file="configs/metadata.json",
-        runner_id="",
+        init_id="initialize",
+        run_id="run",
+        final_id="finalize",
         logging_file="configs/logging.conf",
         tracking=None,
     )
-    if logging_file_ is not None:
-        if not os.path.exists(logging_file_):
-            if logging_file_ == "configs/logging.conf":
-                warnings.warn("default logging file in 'configs/logging.conf' not exists, skip logging.")
-            else:
-                raise FileNotFoundError(f"can't find the logging config file: {logging_file_}.")
-        else:
-            logger.info(f"set logging properties based on config: {logging_file_}.")
-            fileConfig(logging_file_, disable_existing_loggers=False)
-
-    parser = ConfigParser()
-    parser.read_config(f=config_file_)
-    if meta_file_ is not None:
-        if not os.path.exists(meta_file_):
-            warnings.warn("default meta file in 'configs/metadata.json' not exists.")
-        else:
-            parser.read_meta(f=meta_file_)
-
-    # the rest key-values in the _args are to override config content
-    parser.update(pairs=_args)
-
-    # set tracking configs for experiment management
-    if tracking_ is not None:
-        if isinstance(tracking_, str) and tracking_ in DEFAULT_EXP_MGMT_SETTINGS:
-            settings_ = DEFAULT_EXP_MGMT_SETTINGS[tracking_]
-        else:
-            settings_ = ConfigParser.load_config_files(tracking_)
-        patch_bundle_tracking(parser=parser, settings=settings_)
-
-    # resolve and execute the specified runner expressions in the config, return the results
-    return [parser.get_parsed_content(i, lazy=True, eval_expr=True, instantiate=True) for i in ensure_tuple(runner_id_)]
+    workflow = ConfigWorkflow(
+        config_file=config_file_,
+        meta_file=meta_file_,
+        logging_file=logging_file_,
+        init_id=init_id_,
+        run_id=run_id_,
+        final_id=final_id_,
+        tracking=tracking_,
+        **_args,
+    )
+    workflow.initialize()
+    workflow.run()
+    workflow.finalize()
 
 
 def verify_metadata(
diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py
new file mode 100644
index 0000000000..ace08b3ec8
--- /dev/null
+++ b/monai/bundle/workflows.py
@@ -0,0 +1,368 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+import time
+import warnings
+from abc import ABC, abstractmethod
+from logging.config import fileConfig
+from pathlib import Path
+from typing import Any, Sequence
+
+from monai.apps.utils import get_logger
+from monai.bundle.config_parser import ConfigParser
+from monai.bundle.properties import InferProperties, TrainProperties
+from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY
+from monai.utils import BundleProperty, BundlePropertyConfig
+
+__all__ = ["BundleWorkflow", "ConfigWorkflow"]
+
+logger = get_logger(module_name=__name__)
+
+
+class BundleWorkflow(ABC):
+    """
+    Base class for the workflow specification of a bundle; it can be a training, evaluation or inference workflow.
+    It defines the basic interfaces for the bundle workflow behavior: `initialize`, `run`, `finalize`, etc.
+    It also provides the interface to get / set public properties to interact with a bundle workflow.
+
+    Args:
+        workflow: specifies the workflow type: "train" or "training" for a training workflow,
+            or "infer", "inference", "eval", "evaluation" for an inference workflow,
+            any other unsupported string will raise a ValueError.
+            default to `None` for a common workflow.
+
+    """
+
+    def __init__(self, workflow: str | None = None):
+        if workflow is None:
+            self.properties = None
+            self.workflow = None
+            return
+        if workflow.lower() in ("train", "training"):
+            self.properties = TrainProperties
+            self.workflow = "train"
+        elif workflow.lower() in ("infer", "inference", "eval", "evaluation"):
+            self.properties = InferProperties
+            self.workflow = "infer"
+        else:
+            raise ValueError(f"Unsupported workflow type: '{workflow}'.")
+
+    @abstractmethod
+    def initialize(self, *args: Any, **kwargs: Any) -> Any:
+        """
+        Initialize the bundle workflow before running.
+
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def run(self, *args: Any, **kwargs: Any) -> Any:
+        """
+        Run the bundle workflow; it can be a training, evaluation or inference.
+
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def finalize(self, *args: Any, **kwargs: Any) -> Any:
+        """
+        Finalize step after running the bundle workflow.
+
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def _get_property(self, name: str, property: dict) -> Any:
+        """
+        With specified property name and information, get the expected property value.
+
+        Args:
+            name: the name of target property.
+            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
+
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def _set_property(self, name: str, property: dict, value: Any) -> Any:
+        """
+        With specified property name and information, set value for the expected property.
+
+        Args:
+            name: the name of target property.
+            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
+            value: value to set for the property.
+
+        """
+        raise NotImplementedError()
+
+    def __getattr__(self, name):
+        if self.properties is not None and name in self.properties:
+            return self._get_property(name=name, property=self.properties[name])
+        else:
+            return self.__getattribute__(name)  # getting regular attribute
+
+    def __setattr__(self, name, value):
+        if name != "properties" and self.properties is not None and name in self.properties:
+            self._set_property(name=name, property=self.properties[name], value=value)
+        else:
+            super().__setattr__(name, value)  # setting regular attribute
+
+    def get_workflow_type(self):
+        """
+        Get the workflow type; it can be `None`, "train", or "infer".
+
+        """
+        return self.workflow
+
+    def check_properties(self) -> list[str] | None:
+        """
+        Check whether the required properties exist in the bundle workflow.
+        If no workflow type is specified, return None; otherwise, return a list of required but missing properties.
+
+        """
+        if self.properties is None:
+            return None
+        return [n for n, p in self.properties.items() if p.get(BundleProperty.REQUIRED, False) and not hasattr(self, n)]
+
+
+class ConfigWorkflow(BundleWorkflow):
+    """
+    Specification for the config-based bundle workflow.
+    Standardizes the `initialize`, `run`, `finalize` behavior in a config-based training, evaluation, or inference.
+    For more information: https://docs.monai.io/en/latest/mb_specification.html.
+
+    Args:
+        run_id: ID name of the expected config expression to run, default to "run".
+        init_id: ID name of the expected config expression to initialize before running, default to "initialize".
+        final_id: ID name of the expected config expression to finalize after running, default to "finalize".
+        meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.
+            Default to "configs/metadata.json", which is commonly used for bundles in MONAI model zoo.
+        config_file: filepath of the config file, if it is a list of file paths, the content of them will be merged.
+        logging_file: config file for `logging` module in the program. for more details:
+            https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
+            Default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo.
+        tracking: if not None, enable experiment tracking at runtime; it is optionally configurable and extensible.
+            if "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings,
+            if another string, treat it as a file path to load the tracking settings.
+            if `dict`, treat it as tracking settings.
+            will patch the target config content with `tracking handlers` and the top-level items of `configs`.
+            for detailed usage examples, please check the tutorial:
+            https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/bundle_integrate_mlflow.ipynb.
+        workflow: specifies the workflow type: "train" or "training" for a training workflow,
+            or "infer", "inference", "eval", "evaluation" for an inference workflow,
+            any other unsupported string will raise a ValueError.
+            default to `None` for a common workflow.
+        override: id-value pairs to override or add the corresponding config content.
+            e.g. ``--net#input_chns 42``.
+
+    """
+
+    def __init__(
+        self,
+        config_file: str | Sequence[str],
+        meta_file: str | Sequence[str] | None = "configs/metadata.json",
+        logging_file: str | None = "configs/logging.conf",
+        init_id: str = "initialize",
+        run_id: str = "run",
+        final_id: str = "finalize",
+        tracking: str | dict | None = None,
+        workflow: str | None = None,
+        **override: dict,
+    ) -> None:
+        super().__init__(workflow=workflow)
+        if logging_file is not None:
+            if not os.path.exists(logging_file):
+                if logging_file == "configs/logging.conf":
+                    warnings.warn("Default logging file in 'configs/logging.conf' does not exist, skipping logging.")
+                else:
+                    raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.")
+            else:
+                logger.info(f"Setting logging properties based on config: {logging_file}.")
+                fileConfig(logging_file, disable_existing_loggers=False)
+
+        self.parser = ConfigParser()
+        self.parser.read_config(f=config_file)
+        if meta_file is not None:
+            if isinstance(meta_file, str) and not os.path.exists(meta_file):
+                if meta_file == "configs/metadata.json":
+                    warnings.warn("Default metadata file in 'configs/metadata.json' does not exist, skipping loading.")
+                else:
+                    raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.")
+            else:
+                self.parser.read_meta(f=meta_file)
+
+        # the remaining key-value pairs in `override` are used to override the config content
+        self.parser.update(pairs=override)
+        self.init_id = init_id
+        self.run_id = run_id
+        self.final_id = final_id
+        # set tracking configs for experiment management
+        if tracking is not None:
+            if isinstance(tracking, str) and tracking in DEFAULT_EXP_MGMT_SETTINGS:
+                settings_ = DEFAULT_EXP_MGMT_SETTINGS[tracking]
+            else:
+                settings_ = ConfigParser.load_config_files(tracking)
+            self.patch_bundle_tracking(parser=self.parser, settings=settings_)
+
+    def initialize(self) -> Any:
+        """
+        Initialize the bundle workflow before running.
+
+        """
+        # reset the "reference_resolver" buffer at initialization stage
+        self.parser.parse(reset=True)
+        return self._run_expr(id=self.init_id)
+
+    def run(self) -> Any:
+        """
+        Run the bundle workflow; it can be a training, evaluation or inference.
+
+        """
+        return self._run_expr(id=self.run_id)
+
+    def finalize(self) -> Any:
+        """
+        Finalize step after running the bundle workflow.
+
+        """
+        return self._run_expr(id=self.final_id)
+
+    def check_properties(self) -> list[str] | None:
+        """
+        Check whether the required properties exist in the bundle workflow.
+        If an optional property has a reference in the config, will also check whether the property exists.
+        If no workflow type is specified, return None; otherwise, return a list of required but missing properties.
+
+        """
+        ret = super().check_properties()
+        if self.properties is None:
+            warnings.warn("No available properties have been set, skipping the check.")
+            return None
+        if ret:
+            warnings.warn(f"Loaded bundle does not contain the following required properties: {ret}")
+        # also check whether the optional properties use the correct ID name if existing
+        wrong_props = []
+        for n, p in self.properties.items():
+            if not p.get(BundleProperty.REQUIRED, False) and not self._check_optional_id(name=n, property=p):
+                wrong_props.append(n)
+        if wrong_props:
+            warnings.warn(f"Loaded bundle defines the following optional properties with wrong ID: {wrong_props}")
+        if ret is not None:
+            ret.extend(wrong_props)
+        return ret
+
+    def _run_expr(self, id: str, **kwargs: dict) -> Any:
+        return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None
+
+    def _get_prop_id(self, name: str, property: dict) -> Any:
+        prop_id = property[BundlePropertyConfig.ID]
+        if prop_id not in self.parser:
+            if not property.get(BundleProperty.REQUIRED, False):
+                return None
+            else:
+                raise KeyError(f"Property '{name}' with config ID '{prop_id}' not in the config.")
+        return prop_id
+
+    def _get_property(self, name: str, property: dict) -> Any:
+        """
+        With specified property name and information, get the parsed property value from the config.
+
+        Args:
+            name: the name of target property.
+            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
+
+        """
+        if not self.parser.ref_resolver.is_resolved():
+            raise RuntimeError("Please execute 'initialize' before getting any parsed content.")
+        prop_id = self._get_prop_id(name, property)
+        return self.parser.get_parsed_content(id=prop_id) if prop_id is not None else None
+
+    def _set_property(self, name: str, property: dict, value: Any) -> None:
+        """
+        With specified property name and information, set value for the expected property.
+
+        Args:
+            name: the name of target property.
+            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
+            value: value to set for the property.
+
+        """
+        prop_id = self._get_prop_id(name, property)
+        if prop_id is not None:
+            self.parser[prop_id] = value
+            # must parse the config again after changing the content
+            self.parser.ref_resolver.reset()
+
+    def _check_optional_id(self, name: str, property: dict) -> bool:
+        """
+        If an optional property has a reference in the config, check whether the property exists.
+        If `ValidationHandler` is defined for a training workflow, will check whether the optional properties
+        "evaluator" and "val_interval" exist.
+
+        Args:
+            name: the name of target property.
+            property: other information for the target property, defined in `TrainProperties` or `InferProperties`.
+
+        """
+        id = property.get(BundlePropertyConfig.ID, None)
+        ref_id = property.get(BundlePropertyConfig.REF_ID, None)
+        if ref_id is None:
+            # no ID of reference config item, skipping check for this optional property
+            return True
+        # check validation `validator` and `interval` properties as the handler index of ValidationHandler is unknown
+        if name in ("evaluator", "val_interval"):
+            if f"train{ID_SEP_KEY}handlers" in self.parser:
+                for h in self.parser[f"train{ID_SEP_KEY}handlers"]:
+                    if h["_target_"] == "ValidationHandler":
+                        ref = h.get(ref_id, None)
+        else:
+            ref = self.parser.get(ref_id, None)
+        if ref is not None and ref != ID_REF_KEY + id:
+            return False
+        return True
+
+    @staticmethod
+    def patch_bundle_tracking(parser: ConfigParser, settings: dict) -> None:
+        """
+        Patch the loaded bundle config with a new handler logic to enable experiment tracking features.
+
+        Args:
+            parser: loaded config content to patch the handler.
+            settings: settings for the experiment tracking, should follow the pattern of default settings.
+
+        """
+        for k, v in settings["configs"].items():
+            if k in settings["handlers_id"]:
+                engine = parser.get(settings["handlers_id"][k]["id"])
+                if engine is not None:
+                    handlers = parser.get(settings["handlers_id"][k]["handlers"])
+                    if handlers is None:
+                        engine["train_handlers" if k == "trainer" else "val_handlers"] = [v]
+                    else:
+                        handlers.append(v)
+            elif k not in parser:
+                parser[k] = v
+        # save the executed config into file
+        default_name = f"config_{time.strftime('%Y%m%d_%H%M%S')}.json"
+        filepath = parser.get("execute_config", None)
+        if filepath is None:
+            if "output_dir" not in parser:
+                # if no "output_dir" in the bundle config, default to "<bundle root>/eval"
+                parser["output_dir"] = f"{EXPR_KEY}{ID_REF_KEY}bundle_root + '/eval'"
+            # experiment management tools can refer to this config item to track the config info
+            parser["execute_config"] = parser["output_dir"] + f" + '/{default_name}'"
+            filepath = os.path.join(parser.get_parsed_content("output_dir"), default_name)
+        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
+        parser.export_config_file(parser.get(), filepath)
diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py
index 497ca4c38c..031143c69b 100644
--- a/monai/fl/client/monai_algo.py
+++ b/monai/fl/client/monai_algo.py
@@ -23,7 +23,7 @@
 from monai.apps.auto3dseg.data_analyzer import DataAnalyzer
 from monai.apps.utils import get_logger
 from monai.auto3dseg import SegSummarizer
-from monai.bundle import DEFAULT_EXP_MGMT_SETTINGS, ConfigComponent, ConfigItem, ConfigParser, patch_bundle_tracking
+from monai.bundle import DEFAULT_EXP_MGMT_SETTINGS, ConfigComponent, ConfigItem, ConfigParser, ConfigWorkflow
 from monai.engines import SupervisedTrainer, Trainer
 from monai.fl.client import ClientAlgo, ClientAlgoStats
 from monai.fl.utils.constants import (
@@ -325,6 +325,7 @@ def _add_config_files(self, config_files):
 class MonaiAlgo(ClientAlgo, MonaiAlgoStats):
     """
     Implementation of ``ClientAlgo`` to allow federated learning with MONAI bundle configurations.
+    FIXME: reimplement this class based on the bundle "ConfigWorkflow".
 
     Args:
         bundle_root: path of bundle.
@@ -349,37 +350,13 @@ class MonaiAlgo(ClientAlgo, MonaiAlgoStats):
         multi_gpu: whether to run MonaiAlgo in a multi-GPU setting; defaults to `False`.
         backend: backend to use for torch.distributed; defaults to "nccl".
         init_method: init_method for torch.distributed; defaults to "env://".
-        tracking: enable the experiment tracking feature at runtime with optionally configurable and extensible.
-            if "mlflow", will add `MLFlowHandler` to the parsed bundle with default logging settings,
-            if other string, treat it as file path to load the logging settings, if `dict`,
-            treat it as logging settings, otherwise, use all the default settings.
+        tracking: if not None, enable experiment tracking at runtime; it is optionally configurable and extensible.
+            if "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings,
+            if another string, treat it as a file path to load the tracking settings.
+            if `dict`, treat it as tracking settings.
             will patch the target config content with `tracking handlers` and the top-level items of `configs`.
-            example of customized settings:
-
-            .. code-block:: python
-
-                tracking = {
-                    "handlers_id": {
-                        "trainer": {"id": "train#trainer", "handlers": "train#handlers"},
-                        "validator": {"id": "evaluate#evaluator", "handlers": "evaluate#handlers"},
-                        "evaluator": {"id": "evaluator", "handlers": "handlers"},
-                    },
-                    "configs": {
-                        "tracking_uri": "<path>",
-                        "trainer": {
-                            "_target_": "MLFlowHandler",
-                            "tracking_uri": "@tracking_uri",
-                            "iteration_log": True,
-                            "output_transform": "$monai.handlers.from_engine(['loss'], first=True)",
-                        },
-                        "validator": {
-                            "_target_": "MLFlowHandler", "tracking_uri": "@tracking_uri", "iteration_log": False,
-                        },
-                        "evaluator": {
-                            "_target_": "MLFlowHandler", "tracking_uri": "@tracking_uri", "iteration_log": False,
-                        },
-                    },
-                },
+            for detailed usage examples, please check the tutorial:
+            https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/bundle_integrate_mlflow.ipynb.
 
     """
@@ -513,8 +490,8 @@ def initialize(self, extra=None):
                 settings_ = DEFAULT_EXP_MGMT_SETTINGS[self.tracking]
             else:
                 settings_ = ConfigParser.load_config_files(self.tracking)
-            patch_bundle_tracking(parser=self.train_parser, settings=settings_)
-            patch_bundle_tracking(parser=self.eval_parser, settings=settings_)
+            ConfigWorkflow.patch_bundle_tracking(parser=self.train_parser, settings=settings_)
+            ConfigWorkflow.patch_bundle_tracking(parser=self.eval_parser, settings=settings_)
 
         # Get trainer, evaluator
         self.trainer = self.train_parser.get_parsed_content(
diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py
index af8a02b2d3..8210ec924c 100644
--- a/monai/utils/__init__.py
+++ b/monai/utils/__init__.py
@@ -20,6 +20,8 @@
     Average,
     BlendMode,
     BoxModeName,
+    BundleProperty,
+    BundlePropertyConfig,
     ChannelMatching,
     ColorOrder,
     CommonKeys,
diff --git a/monai/utils/enums.py b/monai/utils/enums.py
index 77529acdef..3c3470b9f7 100644
--- a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -54,6 +54,8 @@
     "HoVerNetMode",
     "HoVerNetBranch",
     "LazyAttr",
+    "BundleProperty",
+    "BundlePropertyConfig",
 ]
@@ -636,3 +638,26 @@ class LazyAttr(StrEnum):
     INTERP_MODE = "lazy_interpolation_mode"
     DTYPE = "lazy_dtype"
     ALIGN_CORNERS = "lazy_align_corners"
+
+
+class BundleProperty(StrEnum):
+    """
+    Bundle property fields:
+    `DESC` is the description of the property.
+    `REQUIRED` is a flag indicating whether the property is required or optional.
+    """
+
+    DESC = "description"
+    REQUIRED = "required"
+
+
+class BundlePropertyConfig(StrEnum):
+    """
+    Additional bundle property fields for a config-based bundle workflow:
+    `ID` is the config item ID of the property.
+    `REF_ID` is the ID of the config item that is supposed to refer to this property.
+    This field is only useful for checking the optional property ID.
+    """
+
+    ID = "id"
+    REF_ID = "refer_id"
diff --git a/tests/min_tests.py b/tests/min_tests.py
index b50c1c5e8c..1b1f4f450a 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -193,6 +193,7 @@ def run_testsuit():
         "test_fastmri_reader",
         "test_metrics_reloaded",
         "test_spatial_combine_transforms",
+        "test_bundle_workflow",
     ]
 
     assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py
new file mode 100644
index 0000000000..948c351a1c
--- /dev/null
+++ b/tests/test_bundle_workflow.py
@@ -0,0 +1,251 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+import shutil
+import tempfile
+import unittest
+from copy import deepcopy
+
+import nibabel as nib
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.bundle import BundleWorkflow, ConfigWorkflow
+from monai.data import DataLoader, Dataset
+from monai.engines import SupervisedEvaluator
+from monai.inferers import SimpleInferer, SlidingWindowInferer
+from monai.networks.nets import UNet
+from monai.transforms import (
+    Activationsd,
+    AsDiscreted,
+    Compose,
+    EnsureChannelFirstd,
+    LoadImage,
+    LoadImaged,
+    SaveImaged,
+    ScaleIntensityd,
+)
+from monai.utils import BundleProperty, set_determinism
+
+TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")]
+
+TEST_CASE_2 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.yaml")]
+
+TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")]
+
+
+class NonConfigWorkflow(BundleWorkflow):
+    """
+    A test class that simulates a bundle workflow defined directly in a Python script.
+
+    """
+
+    def __init__(self, filename, output_dir):
+        super().__init__(workflow="inference")
+        self.filename = filename
+        self.output_dir = output_dir
+        self._bundle_root = "will override"
+        self._device = torch.device("cpu")
+        self._network_def = None
+        self._inferer = None
+        self._preprocessing = None
+        self._postprocessing = None
+        self._evaluator = None
+
+    def initialize(self):
+        set_determinism(0)
+        if self._preprocessing is None:
+            self._preprocessing = Compose(
+                [LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), ScaleIntensityd(keys="image")]
+            )
+        dataset = Dataset(data=[{"image": self.filename}], transform=self._preprocessing)
+        dataloader = DataLoader(dataset, batch_size=1, num_workers=4)
+
+        if self._network_def is None:
+            self._network_def = UNet(
+                spatial_dims=3,
+                in_channels=1,
+                out_channels=2,
+                channels=[2, 2, 4, 8, 4],
+                strides=[2, 2, 2, 2],
+                num_res_units=2,
+                norm="batch",
+            )
+        if self._inferer is None:
+            self._inferer = SlidingWindowInferer(roi_size=(64, 64, 32), sw_batch_size=4, overlap=0.25)
+
+        if self._postprocessing is None:
+            self._postprocessing = Compose(
+                [
+                    Activationsd(keys="pred", softmax=True),
+                    AsDiscreted(keys="pred", argmax=True),
+                    SaveImaged(keys="pred", output_dir=self.output_dir, output_postfix="seg"),
+                ]
+            )
+
+        self._evaluator = SupervisedEvaluator(
+            device=self._device,
+            val_data_loader=dataloader,
+            network=self._network_def.to(self._device),
+            inferer=self._inferer,
+            postprocessing=self._postprocessing,
+            amp=False,
+        )
+
+    def run(self):
+        self._evaluator.run()
+
+    def finalize(self):
+        return True
+
+    def _get_property(self, name, property):
+        if name == "bundle_root":
+            return self._bundle_root
+        if name == "device":
+            return self._device
+        if name == "network_def":
+            return self._network_def
+        if name == "inferer":
+            return self._inferer
+        if name == "preprocessing":
+            return self._preprocessing
+        if name == "postprocessing":
+            return self._postprocessing
+        if property[BundleProperty.REQUIRED]:
+            raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")
+
+    def _set_property(self, name, property, value):
+        if name == "bundle_root":
+            self._bundle_root = value
+        elif name == "device":
+            self._device = value
+        elif name == "network_def":
+            self._network_def = value
+        elif name == "inferer":
+            self._inferer = value
+        elif name == "preprocessing":
+            self._preprocessing = value
+        elif name == "postprocessing":
+            self._postprocessing = value
+        elif property[BundleProperty.REQUIRED]:
+            raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")
+
+
+class TestBundleWorkflow(unittest.TestCase):
+    def setUp(self):
+        self.data_dir = tempfile.mkdtemp()
+        self.expected_shape = (128, 128, 128)
+        test_image = np.random.rand(*self.expected_shape)
+        self.filename = os.path.join(self.data_dir, "image.nii")
+        nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename)
+
+    def tearDown(self):
+        shutil.rmtree(self.data_dir)
+
+    def _test_inferer(self, inferer):
+        # should initialize before parsing any bundle content
+        inferer.initialize()
+        # test required and optional properties
+        self.assertListEqual(inferer.check_properties(), [])
+        # test read / write the properties, note that we don't assume it is a JSON or YAML config here
+        self.assertEqual(inferer.bundle_root, "will override")
+        self.assertEqual(inferer.device, torch.device("cpu"))
+        net = inferer.network_def
+        self.assertTrue(isinstance(net, UNet))
+        sliding_window = inferer.inferer
+        self.assertTrue(isinstance(sliding_window, SlidingWindowInferer))
+        preprocessing = inferer.preprocessing
+        self.assertTrue(isinstance(preprocessing, Compose))
+        postprocessing = inferer.postprocessing
+        self.assertTrue(isinstance(postprocessing, Compose))
+        # test optional properties get
+        self.assertTrue(inferer.key_metric is None)
+        inferer.bundle_root = "/workspace/data/spleen_ct_segmentation"
+        inferer.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        inferer.network_def = deepcopy(net)
+        inferer.inferer = deepcopy(sliding_window)
+        inferer.preprocessing = deepcopy(preprocessing)
+        inferer.postprocessing = deepcopy(postprocessing)
+        # test optional properties set
+        inferer.key_metric = "set optional properties"
+
+        # should initialize and parse again as the bundle content has changed
+        inferer.initialize()
+        inferer.run()
+        inferer.finalize()
+        # verify inference output
+        loader = LoadImage(image_only=True)
+        pred_file = os.path.join(self.data_dir, "image", "image_seg.nii.gz")
+        self.assertTupleEqual(loader(pred_file).shape, self.expected_shape)
+        os.remove(pred_file)
+
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+    def test_inference_config(self, config_file):
+        override = {
+            "network": "$@network_def.to(@device)",
+            "dataset#_target_": "Dataset",
+            "dataset#data": [{"image": self.filename}],
+            "postprocessing#transforms#2#output_postfix": "seg",
+            "output_dir": self.data_dir,
+        }
+        # test standard MONAI model-zoo config workflow
+        inferer = ConfigWorkflow(
+            workflow="infer",
+            config_file=config_file,
+            logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"),
+            **override,
+        )
+        self._test_inferer(inferer)
+
+    @parameterized.expand([TEST_CASE_3])
+    def test_train_config(self, config_file):
+        # test standard MONAI model-zoo config workflow
+        trainer = ConfigWorkflow(
+            workflow="train",
+            config_file=config_file,
+            logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"),
+            init_id="initialize",
+            run_id="run",
+            final_id="finalize",
+        )
+        # should initialize before parsing any bundle content
+        trainer.initialize()
+        # test required and optional properties
+        self.assertListEqual(trainer.check_properties(), [])
+        # test read / write the properties
+        dataset = trainer.train_dataset
+        self.assertTrue(isinstance(dataset, Dataset))
+        inferer = trainer.train_inferer
+        self.assertTrue(isinstance(inferer, SimpleInferer))
+        # test optional properties get
+        self.assertTrue(trainer.train_key_metric is None)
+        trainer.train_dataset = deepcopy(dataset)
+        trainer.train_inferer = deepcopy(inferer)
+        # test optional properties set
+        trainer.train_key_metric = "set optional properties"
+
+        # should initialize and parse again as the bundle content has changed
+        trainer.initialize()
+        trainer.run()
+        trainer.finalize()
+
+    def test_non_config(self):
+        # test user defined python style workflow
+        inferer = NonConfigWorkflow(self.filename, self.data_dir)
+        self._test_inferer(inferer)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py
index e2e313ca91..24c0286133 100644
--- a/tests/test_integration_bundle_run.py
+++ b/tests/test_integration_bundle_run.py
@@ -21,6 +21,7 @@
 
 import nibabel as nib
 import numpy as np
+import torch
 from parameterized import parameterized
 
 from monai.bundle import ConfigParser
@@ -56,6 +57,7 @@ def test_tiny(self):
             json.dump(
                 {
                     "trainer": {"_target_": "tests.test_integration_bundle_run._Runnable42", "val": 42},
+                    # keep this test case to cover the "runner_id" arg
                     "training": "$@trainer.run()",
                 },
                 f,
@@ -111,9 +113,10 @@ def test_shape(self, config_file, expected_shape):
             override = "--network $@network_def.to(@device) --dataset#_target_ Dataset"
         else:
             override = f"--network %{overridefile1}#move_net --dataset#_target_ %{overridefile2}"
+        device = "$torch.device('cuda:0')" if torch.cuda.is_available() else "$torch.device('cpu')"
         # test with `monai.bundle` as CLI entry directly
-        cmd = "-m monai.bundle run evaluating --postprocessing#transforms#2#output_postfix seg"
-        cmd += f" {override} --no_epoch False --output_dir {tempdir}"
+        cmd = "-m monai.bundle run --postprocessing#transforms#2#output_postfix seg"
+        cmd += f" {override} --no_epoch False --output_dir {tempdir} --device {device}"
         la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
         test_env = os.environ.copy()
         print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES"))
@@ -124,8 +127,8 @@ def test_shape(self, config_file, expected_shape):
         tracking_uri = path_to_uri(tempdir) + "/mlflow_override2"
         # test override experiment management configs
         # here test the script with `google fire` tool as CLI
-        cmd = "-m fire monai.bundle.scripts run --runner_id evaluating --tracking mlflow --evaluator#amp False"
-        cmd += f" --tracking_uri {tracking_uri} {override} --output_dir {tempdir}"
+        cmd = "-m fire monai.bundle.scripts run --tracking mlflow --evaluator#amp False"
+        cmd += f" --tracking_uri {tracking_uri} {override} --output_dir {tempdir} --device {device}"
         la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
         command_line_tests(la)
         self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape)
diff --git a/tests/testing_data/config_fl_train.json b/tests/testing_data/config_fl_train.json
index f53a95bc02..bdb9792fce 100644
--- a/tests/testing_data/config_fl_train.json
+++ b/tests/testing_data/config_fl_train.json
@@ -1,6 +1,7 @@
 {
     "bundle_root": "tests/testing_data",
     "dataset_dir": "@bundle_root",
+    "val_interval": 1,
     "imports": [
         "$import os"
     ],
@@ -66,13 +67,6 @@
                 "min_zoom": 0.9,
                 "max_zoom": 1.1,
                 "prob": 0.5
-            },
-            {
-                "_target_": "ToTensord",
-                "keys": [
-                    "image",
-                    "label"
-                ]
             }
         ],
         "preprocessing": {
@@ -104,6 +98,12 @@
             "_target_": "SimpleInferer"
         },
         "handlers": [
+            {
+                "_target_": "ValidationHandler",
+                "validator": "@validate#evaluator",
+                "epoch_level": true,
+                "interval": "@val_interval"
+            },
             {
                 "_target_": "StatsHandler",
                 "tag_name": "train_loss",
@@ -121,5 +121,82 @@
             "inferer": "@train#inferer",
             "train_handlers": "@train#handlers"
         }
-    }
+    },
+    "validate": {
+        "preprocessing": {
+            "_target_": "Compose",
+            "transforms": [
+                {
+                    "_target_": "LoadImaged",
+                    "keys": [
+                        "image"
+                    ],
+                    "image_only": true
+                },
+                {
+                    "_target_": "EnsureChannelFirstD",
+                    "keys": [
+                        "image"
+                    ]
+                },
+                {
+                    "_target_": "ScaleIntensityd",
+                    "keys": [
+                        "image"
+                    ]
+                }
+            ]
+        },
+        "dataset": {
+            "_target_": "Dataset",
+            "data": [
+                {
+                    "image": "$os.path.join(@dataset_dir, 'image0.jpeg')",
+                    "label": 0
+                },
+                {
+                    "image": "$os.path.join(@dataset_dir, 'image1.jpeg')",
+                    "label": 1
+                }
+            ],
+            "transform": "@validate#preprocessing"
+        },
+        "dataloader": {
+            "_target_": "DataLoader",
+            "dataset": "@validate#dataset",
+            "batch_size": 1,
+            "shuffle": true,
+            "num_workers": 2
+        },
+        "inferer": {
+            "_target_": "SimpleInferer"
+        },
+        "postprocessing": {
+            "_target_": "Compose",
+            "transforms": [
+                {
+                    "_target_": "Activationsd",
+                    "keys": "pred",
+                    "softmax": true
+                }
+            ]
+        },
+        "evaluator": {
+            "_target_": "SupervisedEvaluator",
+            "device": "@device",
+            "val_data_loader": "@validate#dataloader",
+            "network": "@network",
+            "inferer": "@validate#inferer",
+            "postprocessing": "@validate#postprocessing"
+        }
+    },
+    "initialize": [
+        "$monai.utils.set_determinism(seed=123)"
+    ],
+    "run": [
+        "$@train#trainer.run()"
+    ],
+    "finalize": [
+        "$monai.utils.set_determinism(seed=None)"
+    ]
 }
diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json
index c222667101..7b1f9c20cf 100644
--- a/tests/testing_data/inference.json
+++ b/tests/testing_data/inference.json
@@ -1,9 +1,10 @@
 {
     "dataset_dir": "/workspace/data/Task09_Spleen",
+    "bundle_root": "will override",
     "output_dir": "need override",
     "prediction_shape": "prediction shape:",
     "import_glob": "$import glob",
-    "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
+    "device": "$torch.device('cpu')",
     "print_test_name": "$print('json_test')",
     "print_glob_file": "$print(glob.__file__)",
     "network_def": {
@@ -112,8 +113,10 @@
         "postprocessing": "@postprocessing",
         "amp": false
     },
-    "evaluating": [
-        "$monai.utils.set_determinism(0)",
+    "initialize": [
+        "$monai.utils.set_determinism(0)"
+    ],
+    "run": [
         "$@evaluator.run()"
     ]
 }
diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml
index a289b549db..0343ea0bae 100644
--- a/tests/testing_data/inference.yaml
+++ b/tests/testing_data/inference.yaml
@@ -1,8 +1,9 @@
 ---
 dataset_dir: "/workspace/data/Task09_Spleen"
+bundle_root: "will override"
 output_dir: "need override"
 prediction_shape: "prediction shape:"
-device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
+device: "$torch.device('cpu')"
 print_test_name: "$print('yaml_test')"
 network_def:
   _target_: UNet
@@ -80,6 +81,9 @@
   inferer: "@inferer"
   postprocessing: "@postprocessing"
   amp: false
-evaluating:
+initialize:
 - "$monai.utils.set_determinism(0)"
+run:
 - "$@evaluator.run()"
+finalize:
+- "$print('test finalize section.')"