Skip to content

Commit fa7411a

Browse files
authored
add device in HoVerNetNuclearTypePostProcessing and HoVerNetInstanceMapPostProcessing (#6333)
Fixes # . ### Description Since some operations in post-processing of HoVerNet will convert data to numpy. And most of the time we need to calculate the metric from the model output and label which should both be on CUDA. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <yunl@nvidia.com>
1 parent d14e8f0 commit fa7411a

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

monai/apps/pathology/transforms/post/array.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique
3232
from monai.utils import TransformBackends, convert_to_numpy, optional_import
3333
from monai.utils.misc import ensure_tuple_rep
34-
from monai.utils.type_conversion import convert_to_dst_type
34+
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
3535

3636
label, _ = optional_import("scipy.ndimage.measurements", name="label")
3737
disk, _ = optional_import("skimage.morphology", name="disk")
@@ -671,6 +671,7 @@ class HoVerNetInstanceMapPostProcessing(Transform):
671671
min_num_points: minimum number of points to be considered as a contour. Defaults to 3.
672672
contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array.
673673
If not provided, the level is set to `(max(image) + min(image)) / 2`.
674+
device: target device to put the output Tensor data.
674675
"""
675676

676677
def __init__(
@@ -686,9 +687,10 @@ def __init__(
686687
watershed_connectivity: int | None = 1,
687688
min_num_points: int = 3,
688689
contour_level: float | None = None,
690+
device: str | torch.device | None = None,
689691
) -> None:
690692
super().__init__()
691-
693+
self.device = device
692694
self.generate_watershed_mask = GenerateWatershedMask(
693695
activation=activation, threshold=mask_threshold, min_object_size=min_object_size
694696
)
@@ -742,7 +744,7 @@ def __call__( # type: ignore
742744
"centroid": instance_centroid,
743745
"contour": instance_contour,
744746
}
745-
747+
instance_map = convert_to_tensor(instance_map, device=self.device)
746748
return instance_info, instance_map
747749

748750

@@ -758,13 +760,19 @@ class HoVerNetNuclearTypePostProcessing(Transform):
758760
threshold: an optional float value to threshold to binarize probability map.
759761
If not provided, defaults to 0.5 when activation is not "softmax", otherwise None.
760762
return_type_map: whether to calculate and return pixel-level type map.
763+
device: target device to put the output Tensor data.
761764
762765
"""
763766

764767
def __init__(
765-
self, activation: str | Callable = "softmax", threshold: float | None = None, return_type_map: bool = True
768+
self,
769+
activation: str | Callable = "softmax",
770+
threshold: float | None = None,
771+
return_type_map: bool = True,
772+
device: str | torch.device | None = None,
766773
) -> None:
767774
super().__init__()
775+
self.device = device
768776
self.return_type_map = return_type_map
769777
self.generate_instance_type = GenerateInstanceType()
770778

@@ -824,5 +832,6 @@ def __call__( # type: ignore
824832
# update instance type map
825833
if type_map is not None:
826834
type_map[instance_map == inst_id] = instance_type
835+
type_map = convert_to_tensor(type_map, device=self.device)
827836

828837
return instance_info, type_map

monai/apps/pathology/transforms/post/dictionary.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from collections.abc import Callable, Hashable, Mapping
1515

1616
import numpy as np
17+
import torch
1718

1819
from monai.apps.pathology.transforms.post.array import (
1920
GenerateDistanceMap,
@@ -488,6 +489,7 @@ class HoVerNetInstanceMapPostProcessingd(Transform):
488489
min_num_points: minimum number of points to be considered as a contour. Defaults to 3.
489490
contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array.
490491
If not provided, the level is set to `(max(image) + min(image)) / 2`.
492+
device: target device to put the output Tensor data.
491493
"""
492494

493495
def __init__(
@@ -507,6 +509,7 @@ def __init__(
507509
watershed_connectivity: int | None = 1,
508510
min_num_points: int = 3,
509511
contour_level: float | None = None,
512+
device: str | torch.device | None = None,
510513
) -> None:
511514
super().__init__()
512515
self.instance_map_post_process = HoVerNetInstanceMapPostProcessing(
@@ -521,6 +524,7 @@ def __init__(
521524
watershed_connectivity=watershed_connectivity,
522525
min_num_points=min_num_points,
523526
contour_level=contour_level,
527+
device=device,
524528
)
525529
self.nuclear_prediction_key = nuclear_prediction_key
526530
self.hover_map_key = hover_map_key
@@ -553,7 +557,7 @@ class HoVerNetNuclearTypePostProcessingd(Transform):
553557
Defaults to `"instance_info"`.
554558
instance_map_key: the key where instance map is stored. Defaults to `"instance_map"`.
555559
type_map_key: the output key where type map is written. Defaults to `"type_map"`.
556-
560+
device: target device to put the output Tensor data.
557561
558562
"""
559563

@@ -566,10 +570,11 @@ def __init__(
566570
activation: str | Callable = "softmax",
567571
threshold: float | None = None,
568572
return_type_map: bool = True,
573+
device: str | torch.device | None = None,
569574
) -> None:
570575
super().__init__()
571576
self.type_post_process = HoVerNetNuclearTypePostProcessing(
572-
activation=activation, threshold=threshold, return_type_map=return_type_map
577+
activation=activation, threshold=threshold, return_type_map=return_type_map, device=device
573578
)
574579
self.type_prediction_key = type_prediction_key
575580
self.instance_info_key = instance_info_key

0 commit comments

Comments
 (0)