diff --git a/monai/apps/pathology/engines/__init__.py b/monai/apps/pathology/engines/__init__.py new file mode 100644 index 0000000000..68c084d40d --- /dev/null +++ b/monai/apps/pathology/engines/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .utils import PrepareBatchHoVerNet diff --git a/monai/apps/pathology/engines/utils.py b/monai/apps/pathology/engines/utils.py new file mode 100644 index 0000000000..3a190a146b --- /dev/null +++ b/monai/apps/pathology/engines/utils.py @@ -0,0 +1,54 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Sequence, Union + +import torch + +from monai.engines import PrepareBatch, PrepareBatchExtraInput +from monai.utils import ensure_tuple +from monai.utils.enums import HoVerNetBranch + +__all__ = ["PrepareBatchHoVerNet"] + + +class PrepareBatchHoVerNet(PrepareBatch): + """ + Customized prepare batch callable for trainers or evaluators which support label to be a dictionary. + Extra items are specified by the `extra_keys` parameter and are extracted from the input dictionary (ie. the batch). + This assumes label is a dictionary. + + Args: + extra_keys: If a sequence of strings is provided, values from the input dictionary are extracted from + those keys and passed to the nework as extra positional arguments. + """ + + def __init__(self, extra_keys: Sequence[str]) -> None: + if len(ensure_tuple(extra_keys)) != 2: + raise ValueError(f"length of `extra_keys` should be 2, get {len(ensure_tuple(extra_keys))}") + self.prepare_batch = PrepareBatchExtraInput(extra_keys) + + def __call__( + self, + batchdata: Dict[str, torch.Tensor], + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False, + **kwargs, + ): + """ + Args `batchdata`, `device`, `non_blocking` refer to the ignite API: + https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html. + `kwargs` supports other args for `Tensor.to()` API. + """ + image, _label, extra_label, _ = self.prepare_batch(batchdata, device, non_blocking, **kwargs) + label = {HoVerNetBranch.NP: _label, HoVerNetBranch.NC: extra_label[0], HoVerNetBranch.HV: extra_label[1]} + + return image, label diff --git a/tests/min_tests.py b/tests/min_tests.py index 915e67e120..765cec8adf 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -140,6 +140,7 @@ def run_testsuit(): "test_png_saver", "test_prepare_batch_default", "test_prepare_batch_extra_input", + "test_prepare_batch_hovernet", "test_rand_grid_patch", "test_rand_rotate", "test_rand_rotated", diff --git a/tests/test_prepare_batch_hovernet.py b/tests/test_prepare_batch_hovernet.py new file mode 100644 index 0000000000..9aed8e94c7 --- /dev/null +++ b/tests/test_prepare_batch_hovernet.py @@ -0,0 +1,66 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.apps.pathology.engines import PrepareBatchHoVerNet +from monai.engines import SupervisedEvaluator +from monai.utils.enums import HoVerNetBranch +from tests.utils import assert_allclose + +TEST_CASE_0 = [ + {"extra_keys": ["extra_label1", "extra_label2"]}, + {HoVerNetBranch.NP: torch.tensor([1, 2]), HoVerNetBranch.NC: torch.tensor([4, 4]), HoVerNetBranch.HV: 16}, +] + + +class TestNet(torch.nn.Module): + def forward(self, x: torch.Tensor): + return {HoVerNetBranch.NP: torch.tensor([1, 2]), HoVerNetBranch.NC: torch.tensor([4, 4]), HoVerNetBranch.HV: 16} + + +class TestPrepareBatchHoVerNet(unittest.TestCase): + @parameterized.expand([TEST_CASE_0]) + def test_content(self, input_args, expected_value): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataloader = [ + { + "image": torch.tensor([1, 2]), + "label": torch.tensor([1, 2]), + "extra_label1": torch.tensor([3, 4]), + "extra_label2": 16, + } + ] + # set up engine + evaluator = SupervisedEvaluator( + device=device, + val_data_loader=dataloader, + epoch_length=1, + network=TestNet(), + non_blocking=True, + prepare_batch=PrepareBatchHoVerNet(**input_args), + decollate=False, + ) + evaluator.run() + output = evaluator.state.output + assert_allclose(output["image"], torch.tensor([1, 2], device=device)) + for k, v in output["pred"].items(): + if isinstance(v, torch.Tensor): + assert_allclose(v, expected_value[k].to(device)) + else: + self.assertEqual(v, expected_value[k]) + + +if __name__ == "__main__": + unittest.main()