Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/source/config_syntax.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,10 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k
- `_mode_` specifies the operating mode when the component is instantiated or the callable is called.
it currently supports the following values:
- `"default"` (default) -- return the return value of ``_target_(**kwargs)``
- `"partial"` -- return a partial function of ``functools.partial(_target_, **kwargs)`` (this is often
useful when some portion of the full set of arguments are supplied to the ``_target_``, and the user wants to
call it with additional arguments later).
- `"callable"` -- return a callable, either as ``_target_`` itself or, if ``kwargs`` are provided, as a
partial function of ``functools.partial(_target_, **kwargs)``. Useful for defining a class or function
that will be instantied or called later. User can pre-define some arguments to the ``_target_`` and call
it with additional arguments later.
- `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``,
see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall).

Expand Down
2 changes: 1 addition & 1 deletion monai/bundle/config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class ConfigComponent(ConfigItem, Instantiable):
- ``"_mode_"`` (optional): operating mode for invoking the callable ``component`` defined by ``"_target_"``:

- ``"default"``: returns ``component(**kwargs)``
- ``"partial"``: returns ``functools.partial(component, **kwargs)``
- ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)``
- ``"debug"``: returns ``pdb.runcall(component, **kwargs)``

Other fields in the config content are input arguments to the python module.
Expand Down
2 changes: 1 addition & 1 deletion monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ class CompInitMode(StrEnum):
"""

DEFAULT = "default"
PARTIAL = "partial"
CALLABLE = "callable"
DEBUG = "debug"


Expand Down
11 changes: 7 additions & 4 deletions monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,14 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any:

Args:
__path: if a string is provided, it's interpreted as the full path of the target class or function component.
If a callable is provided, ``__path(**kwargs)`` or ``functools.partial(__path, **kwargs)`` will be returned.
If a callable is provided, ``__path(**kwargs)`` will be invoked and returned for ``__mode="default"``.
For ``__mode="callable"``, the callable will be returned as ``__path`` or, if ``kwargs`` are provided,
as ``functools.partial(__path, **kwargs)`` for future invoking.

__mode: the operating mode for invoking the (callable) ``component`` represented by ``__path``:

- ``"default"``: returns ``component(**kwargs)``
- ``"partial"``: returns ``functools.partial(component, **kwargs)``
- ``"callable"``: returns ``component`` or, if ``kwargs`` are provided, ``functools.partial(component, **kwargs)``
- ``"debug"``: returns ``pdb.runcall(component, **kwargs)``

kwargs: keyword arguments to the callable represented by ``__path``.
Expand All @@ -259,8 +262,8 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any:
return component
if m == CompInitMode.DEFAULT:
return component(**kwargs)
if m == CompInitMode.PARTIAL:
return partial(component, **kwargs)
if m == CompInitMode.CALLABLE:
return partial(component, **kwargs) if kwargs else component
if m == CompInitMode.DEBUG:
warnings.warn(
f"\n\npdb: instantiating component={component}, mode={m}\n"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
TEST_CASE_5 = [{"_target_": "LoadImaged", "_disabled_": "true", "keys": ["image"]}, dict]
# test non-monai modules and excludes
TEST_CASE_6 = [{"_target_": "torch.optim.Adam", "params": torch.nn.PReLU().parameters(), "lr": 1e-4}, torch.optim.Adam]
TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True, "_mode_": "partial"}, partial]
TEST_CASE_7 = [{"_target_": "decollate_batch", "detach": True, "pad": True, "_mode_": "callable"}, partial]
# test args contains "name" field
TEST_CASE_8 = [
{"_target_": "RandTorchVisiond", "keys": "image", "name": "ColorJitter", "brightness": 0.25},
Expand Down
6 changes: 2 additions & 4 deletions tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def case_pdb_inst(sarg=None):


class TestClass:

@staticmethod
def compute(a, b, func=lambda x, y: x + y):
return func(a, b)
Expand Down Expand Up @@ -127,7 +126,6 @@ def __call__(self, a, b):


class TestConfigParser(unittest.TestCase):

def test_config_content(self):
test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}}
parser = ConfigParser(config=test_config)
Expand Down Expand Up @@ -183,7 +181,7 @@ def test_function(self, config):
parser = ConfigParser(config=config, globals={"TestClass": TestClass})
for id in config:
if id in ("compute", "cls_compute"):
parser[f"{id}#_mode_"] = "partial"
parser[f"{id}#_mode_"] = "callable"
func = parser.get_parsed_content(id=id)
self.assertTrue(id in parser.ref_resolver.resolved_content)
if id == "error_func":
Expand Down Expand Up @@ -279,7 +277,7 @@ def test_lambda_reference(self):

def test_non_str_target(self):
configs = {
"fwd": {"_target_": "$@model.forward", "x": "$torch.rand(1, 3, 256, 256)", "_mode_": "partial"},
"fwd": {"_target_": "$@model.forward", "x": "$torch.rand(1, 3, 256, 256)", "_mode_": "callable"},
"model": {"_target_": "monai.networks.nets.resnet.resnet18", "pretrained": False, "spatial_dims": 2},
}
self.assertTrue(callable(ConfigParser(config=configs).fwd))
Expand Down