Skip to content

Commit f848002

Browse files
yiheng-wang-nvKumoLiuericspod
authored
Add utils for vista3d (#7999)
This PR is a part of #7987 ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang <vennw@nvidia.com> Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 6be7b13 commit f848002

File tree

8 files changed

+324
-9
lines changed

8 files changed

+324
-9
lines changed

docs/source/apps.rst

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,3 @@ FastMRIReader
261261

262262
.. autoclass:: monai.apps.nnunet.nnUNetV2Runner
263263
:members:
264-
265-
`Generative AI`
266-
---------------
267-
268-
`MAISI Utilities`
269-
~~~~~~~~~~~~~~~~~
270-
.. automodule:: monai.apps.generation.maisi.utils.morphological_ops
271-
:members:

docs/source/transforms.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2310,6 +2310,9 @@ Utilities
23102310
.. automodule:: monai.transforms.utils_pytorch_numpy_unification
23112311
:members:
23122312

2313+
.. automodule:: monai.transforms.utils_morphological_ops
2314+
:members:
2315+
23132316
By Categories
23142317
-------------
23152318
.. toctree::

monai/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,7 @@
688688
weighted_patch_samples,
689689
zero_margins,
690690
)
691+
from .utils_morphological_ops import dilate, erode
691692
from .utils_pytorch_numpy_unification import (
692693
allclose,
693694
any_np_pt,

monai/transforms/utils.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import numpy as np
2424
import torch
25+
from torch import Tensor
2526

2627
import monai
2728
from monai.config import DtypeLike, IndexSelection
@@ -30,6 +31,7 @@
3031
from monai.networks.utils import meshgrid_ij
3132
from monai.transforms.compose import Compose
3233
from monai.transforms.transform import MapTransform, Transform, apply_transform
34+
from monai.transforms.utils_morphological_ops import erode
3335
from monai.transforms.utils_pytorch_numpy_unification import (
3436
any_np_pt,
3537
ascontiguousarray,
@@ -65,6 +67,8 @@
6567
min_version,
6668
optional_import,
6769
pytorch_after,
70+
unsqueeze_left,
71+
unsqueeze_right,
6872
)
6973
from monai.utils.enums import TransformBackends
7074
from monai.utils.type_conversion import (
@@ -103,6 +107,8 @@
103107
"generate_spatial_bounding_box",
104108
"get_extreme_points",
105109
"get_largest_connected_component_mask",
110+
"get_largest_connected_component_mask_point",
111+
"convert_points_to_disc",
106112
"remove_small_objects",
107113
"img_bounds",
108114
"in_bounds",
@@ -1172,6 +1178,183 @@ def get_largest_connected_component_mask(
11721178
return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0]
11731179

11741180

1181+
def get_largest_connected_component_mask_point(
    img_pos: NdarrayTensor,
    img_neg: NdarrayTensor,
    point_coords: NdarrayTensor,
    point_labels: NdarrayTensor,
    pos_val: Sequence[int] = (1, 3),
    neg_val: Sequence[int] = (0, 2),
    margins: int = 3,
) -> NdarrayTensor:
    """
    Gets the connected component of img_pos and img_neg that include the positive points and
    negative points separately. The function is used for combining automatic results with interactive
    results in VISTA3D.

    For every point, a cubic window centred on the (rounded) point is grown from half-width 0 up to
    ``margins - 1``. The first window that contains a labelled component selects that component
    (the one with the largest component id inside the window) and it is added to the output mask.

    Args:
        img_pos: bool type tensor, shape [B, 1, H, W, D], where B means the foreground masks from a single 3D image.
        img_neg: same format as img_pos but corresponds to negative points.
        point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points.
        point_labels: the label of each point, shape [B, N].
        pos_val: positive point label values.
        neg_val: negative point label values.
        margins: number of window half-widths (0, 1, ..., margins - 1) to try around each point
            before giving up on finding a component for it.

    Returns:
        The union of the selected components, converted to the same container type as ``img_pos``.

    Raises:
        RuntimeError: when the CPU path is taken and ``skimage.measure`` is not available.
        ValueError: when the CPU path is taken and ``img_pos`` or ``img_neg`` is not of bool type.
    """

    cucim_skimage, has_cucim = optional_import("cucim.skimage")

    use_cp = has_cp and has_cucim and isinstance(img_pos, torch.Tensor) and img_pos.device != torch.device("cpu")
    if use_cp:
        img_pos_ = convert_to_cupy(img_pos.short())  # type: ignore
        img_neg_ = convert_to_cupy(img_neg.short())  # type: ignore
        label = cucim_skimage.measure.label
        lib = cp
    else:
        if not has_measure:
            raise RuntimeError("skimage.measure required.")
        img_pos_, *_ = convert_data_type(img_pos, np.ndarray)
        img_neg_, *_ = convert_data_type(img_neg, np.ndarray)
        # for skimage.measure.label, the input must be bool type
        if img_pos_.dtype != bool or img_neg_.dtype != bool:
            raise ValueError("img_pos and img_neg must be bool type.")
        label = measure.label
        lib = np

    features_pos, _ = label(img_pos_, connectivity=3, return_num=True)
    features_neg, _ = label(img_neg_, connectivity=3, return_num=True)

    # allocate with the selected array backend so the accumulator matches `features_*`
    outs = lib.zeros_like(img_pos_)
    for bs in range(point_coords.shape[0]):
        for i, p in enumerate(point_coords[bs]):
            if point_labels[bs, i] in pos_val:
                features = features_pos
            elif point_labels[bs, i] in neg_val:
                features = features_neg
            else:
                # if -1 padding point, skip
                continue

            # the rounded voxel position does not depend on the margin, compute it once
            if isinstance(p, np.ndarray):
                x, y, z = np.round(p).astype(int).tolist()
            else:
                x, y, z = p.float().round().int().tolist()

            for margin in range(margins):
                # Clamp both ends into [0, size]. Without the lower clamp on the upper bound, a point
                # lying outside the volume on the negative side would yield a slice such as `0:-k`,
                # which selects an unrelated region instead of an empty window.
                x_lo, x_hi = max(x - margin, 0), max(min(x + margin + 1, features.shape[-3]), 0)
                y_lo, y_hi = max(y - margin, 0), max(min(y + margin + 1, features.shape[-2]), 0)
                z_lo, z_hi = max(z - margin, 0), max(min(z + margin + 1, features.shape[-1]), 0)
                window = features[bs, 0, x_lo:x_hi, y_lo:y_hi, z_lo:z_hi]
                if (window > 0).any():
                    index = window.max()
                    outs[[bs]] += lib.isin(features[[bs]], index)
                    break
    # several points may select the same component; keep the result binary
    outs[outs > 1] = 1
    return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0]
1250+
1251+
1252+
def convert_points_to_disc(
    image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False
):
    """
    Convert 3D point coordinates into an image mask. The returned mask has the same spatial
    size as `image_size` while the batch dimension is the same as the `point` batch dimension.
    Each point is converted to a ball in the mask with size defined by `radius`. The output
    contains two channels, the first for negative points and the second for positive points.

    Args:
        image_size: The output size of the converted mask. It should be a 3D tuple.
        point: [B, N, 3], 3D point coordinates.
        point_label: [B, N], 0 or 2 means negative points, 1 or 3 means positive points.
            Points whose label is not greater than -1 (e.g. -1 padding points) are ignored.
        radius: disc ball radius size.
        disc: If true, use a regular (binary) disc, otherwise use a gaussian.

    Returns:
        A tensor of shape [B, 2, image_size[0], image_size[1], image_size[2]] on the device of `point`.
        Contributions of several points falling into the same channel are summed.
    """
    masks = torch.zeros([point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device)
    axes = [
        torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3)
    ]
    for b, n in np.ndindex(*point.shape[:2]):
        if point_label[b, n] > -1:
            channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1
            # Squared distance of every voxel to the point, assembled from the three per-axis squared
            # offsets by broadcasting to [H, W, D]. This keeps the axis order aligned with `masks`
            # for non-cubic sizes and avoids materialising a [B, 2, 3, H, W, D] coordinate volume.
            sq_x = (axes[0] - point[b, n, 0]) ** 2
            sq_y = (axes[1] - point[b, n, 1]) ** 2
            sq_z = (axes[2] - point[b, n, 2]) ** 2
            sq_dist = sq_x[:, None, None] + sq_y[None, :, None] + sq_z[None, None, :]
            if disc:
                masks[b, channel] += (sq_dist < radius**2).to(masks.dtype)
            else:
                masks[b, channel] += torch.exp(-sq_dist / (2 * radius**2))
    return masks
1286+
1287+
1288+
def sample_points_from_label(
    labels: Tensor,
    label_set: Sequence[int],
    max_ppoint: int = 1,
    max_npoint: int = 0,
    device: torch.device | str | None = "cpu",
    use_center: bool = False,
):
    """Sample positive and negative points for each requested label.

    For a label that is present, positive candidates are taken from the largest connected
    component of the eroded label mask (falling back to the raw mask when that is empty),
    and negative candidates from everywhere else. Missing slots are padded with the
    coordinate (0, 0, 0) and point label -1. A label that is absent yields only padding.

    Args:
        labels: [1, 1, H, W, D]
        label_set: local index, must match values in labels.
        max_ppoint: maximum positive point samples.
        max_npoint: maximum negative point samples.
        device: returned tensor device. When None, "cuda" is used if available, otherwise "cpu".
        use_center: whether to sample points from center. When true, positive points are taken in
            order of increasing distance to the mean positive coordinate; otherwise in random order.

    Returns:
        point: point coordinates of [B, N, 3]. B equals to the length of label_set.
        point_label: [B, N], always 0 for negative, 1 for positive.

    Raises:
        ValueError: when the first dimension of `labels` is not 1.
    """
    if labels.shape[0] != 1:
        raise ValueError("labels must have batch size 1.")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    volume = labels[0, 0]
    present_values = volume.unique().cpu().numpy().tolist()
    slots = max_ppoint + max_npoint
    sampled_points = []
    sampled_labels = []
    for label_id in label_set:
        if label_id not in present_values:
            # pad the background labels
            sampled_points.append(torch.zeros(slots, 3).to(device))
            sampled_labels.append(torch.zeros(slots).to(device) - 1)
            continue

        foreground = volume == int(label_id)
        background = ~foreground
        eroded = erode(foreground.unsqueeze(0).unsqueeze(0))[0, 0]
        core = get_largest_connected_component_mask(eroded)
        pos_candidates = torch.nonzero(core).to(device)
        if len(pos_candidates) == 0:
            # nothing survived erosion / component selection: use the raw mask
            pos_candidates = torch.nonzero(foreground).to(device)
        neg_candidates = torch.nonzero(background).to(device)

        num_p = min(len(pos_candidates), max_ppoint)
        num_n = min(len(neg_candidates), max_npoint)
        pad = slots - num_p - num_n

        if use_center:
            centroid = pos_candidates.float().mean(0)
            sq_dist = ((pos_candidates - centroid) ** 2).sum(-1)
            order = torch.sort(sq_dist)[1].cpu().tolist()
        else:
            order = list(range(len(pos_candidates)))
            random.shuffle(order)

        chosen = [pos_candidates[k] for k in order[:num_p]]
        chosen += random.choices(neg_candidates, k=num_n)
        chosen += [torch.tensor([0, 0, 0], device=device)] * pad
        sampled_points.append(torch.stack(chosen))
        sampled_labels.append(torch.tensor([1] * num_p + [0] * num_n + [-1] * pad).to(device))

    point = torch.stack(sampled_points)
    point_label = torch.stack(sampled_labels)

    return point, point_label
1356+
1357+
11751358
def remove_small_objects(
11761359
img: NdarrayTensor,
11771360
min_size: int = 64,

monai/apps/generation/maisi/utils/morphological_ops.py renamed to monai/transforms/utils_morphological_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from monai.config import NdarrayOrTensor
2121
from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep
2222

23+
__all__ = ["erode", "dilate"]
24+
2325

2426
def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor:
2527
"""

tests/min_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def run_testsuit():
209209
"test_zarr_avg_merger",
210210
"test_perceptual_loss",
211211
"test_ultrasound_confidence_map_transform",
212+
"test_vista3d_utils",
212213
]
213214
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
214215

tests/test_morphological_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch
1717
from parameterized import parameterized
1818

19-
from monai.apps.generation.maisi.utils.morphological_ops import dilate, erode, get_morphological_filter_result_t
19+
from monai.transforms.utils_morphological_ops import dilate, erode, get_morphological_filter_result_t
2020
from tests.utils import TEST_NDARRAYS, assert_allclose
2121

2222
TESTS_SHAPE = []

0 commit comments

Comments
 (0)