Skip to content

Commit

Permalink
Extend PrepareBatchExtraInput to work with HoVerNet (#5448)
Browse files Browse the repository at this point in the history
Signed-off-by: KumoLiu <yunl@nvidia.com>

Fixes #5028.

### Description
Since the output for HoVerNet is a dictionary, we already have
`PrepareBatchExtraInput` to support extra input data for the network,
but it still can't meet the requirements. This PR is to extend
`PrepareBatchExtraInput` to make the label can be a dictionary.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: KumoLiu <yunl@nvidia.com>
Co-authored-by: Nic Ma <nma@nvidia.com>
  • Loading branch information
KumoLiu and Nic-Ma authored Nov 3, 2022
1 parent b05ef9c commit c38d503
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 0 deletions.
12 changes: 12 additions & 0 deletions monai/apps/pathology/engines/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .utils import PrepareBatchHoVerNet
54 changes: 54 additions & 0 deletions monai/apps/pathology/engines/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Sequence, Union

import torch

from monai.engines import PrepareBatch, PrepareBatchExtraInput
from monai.utils import ensure_tuple
from monai.utils.enums import HoVerNetBranch

__all__ = ["PrepareBatchHoVerNet"]


class PrepareBatchHoVerNet(PrepareBatch):
"""
Customized prepare batch callable for trainers or evaluators which support label to be a dictionary.
Extra items are specified by the `extra_keys` parameter and are extracted from the input dictionary (ie. the batch).
This assumes label is a dictionary.
Args:
extra_keys: If a sequence of strings is provided, values from the input dictionary are extracted from
those keys and passed to the nework as extra positional arguments.
"""

def __init__(self, extra_keys: Sequence[str]) -> None:
if len(ensure_tuple(extra_keys)) != 2:
raise ValueError(f"length of `extra_keys` should be 2, get {len(ensure_tuple(extra_keys))}")
self.prepare_batch = PrepareBatchExtraInput(extra_keys)

def __call__(
self,
batchdata: Dict[str, torch.Tensor],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
**kwargs,
):
"""
Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
`kwargs` supports other args for `Tensor.to()` API.
"""
image, _label, extra_label, _ = self.prepare_batch(batchdata, device, non_blocking, **kwargs)
label = {HoVerNetBranch.NP: _label, HoVerNetBranch.NC: extra_label[0], HoVerNetBranch.HV: extra_label[1]}

return image, label
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def run_testsuit():
"test_png_saver",
"test_prepare_batch_default",
"test_prepare_batch_extra_input",
"test_prepare_batch_hovernet",
"test_rand_grid_patch",
"test_rand_rotate",
"test_rand_rotated",
Expand Down
66 changes: 66 additions & 0 deletions tests/test_prepare_batch_hovernet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.apps.pathology.engines import PrepareBatchHoVerNet
from monai.engines import SupervisedEvaluator
from monai.utils.enums import HoVerNetBranch
from tests.utils import assert_allclose

TEST_CASE_0 = [
{"extra_keys": ["extra_label1", "extra_label2"]},
{HoVerNetBranch.NP: torch.tensor([1, 2]), HoVerNetBranch.NC: torch.tensor([4, 4]), HoVerNetBranch.HV: 16},
]


class TestNet(torch.nn.Module):
def forward(self, x: torch.Tensor):
return {HoVerNetBranch.NP: torch.tensor([1, 2]), HoVerNetBranch.NC: torch.tensor([4, 4]), HoVerNetBranch.HV: 16}


class TestPrepareBatchHoVerNet(unittest.TestCase):
@parameterized.expand([TEST_CASE_0])
def test_content(self, input_args, expected_value):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = [
{
"image": torch.tensor([1, 2]),
"label": torch.tensor([1, 2]),
"extra_label1": torch.tensor([3, 4]),
"extra_label2": 16,
}
]
# set up engine
evaluator = SupervisedEvaluator(
device=device,
val_data_loader=dataloader,
epoch_length=1,
network=TestNet(),
non_blocking=True,
prepare_batch=PrepareBatchHoVerNet(**input_args),
decollate=False,
)
evaluator.run()
output = evaluator.state.output
assert_allclose(output["image"], torch.tensor([1, 2], device=device))
for k, v in output["pred"].items():
if isinstance(v, torch.Tensor):
assert_allclose(v, expected_value[k].to(device))
else:
self.assertEqual(v, expected_value[k])


if __name__ == "__main__":
unittest.main()

0 comments on commit c38d503

Please sign in to comment.