|
22 | 22 |
|
23 | 23 | import numpy as np |
24 | 24 | import torch |
| 25 | +from torch import Tensor |
25 | 26 |
|
26 | 27 | import monai |
27 | 28 | from monai.config import DtypeLike, IndexSelection |
|
30 | 31 | from monai.networks.utils import meshgrid_ij |
31 | 32 | from monai.transforms.compose import Compose |
32 | 33 | from monai.transforms.transform import MapTransform, Transform, apply_transform |
| 34 | +from monai.transforms.utils_morphological_ops import erode |
33 | 35 | from monai.transforms.utils_pytorch_numpy_unification import ( |
34 | 36 | any_np_pt, |
35 | 37 | ascontiguousarray, |
|
65 | 67 | min_version, |
66 | 68 | optional_import, |
67 | 69 | pytorch_after, |
| 70 | + unsqueeze_left, |
| 71 | + unsqueeze_right, |
68 | 72 | ) |
69 | 73 | from monai.utils.enums import TransformBackends |
70 | 74 | from monai.utils.type_conversion import ( |
|
103 | 107 | "generate_spatial_bounding_box", |
104 | 108 | "get_extreme_points", |
105 | 109 | "get_largest_connected_component_mask", |
| 110 | + "get_largest_connected_component_mask_point", |
| 111 | + "convert_points_to_disc", |
106 | 112 | "remove_small_objects", |
107 | 113 | "img_bounds", |
108 | 114 | "in_bounds", |
@@ -1172,6 +1178,183 @@ def get_largest_connected_component_mask( |
1172 | 1178 | return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0] |
1173 | 1179 |
|
1174 | 1180 |
|
| 1181 | +def get_largest_connected_component_mask_point( |
| 1182 | + img_pos: NdarrayTensor, |
| 1183 | + img_neg: NdarrayTensor, |
| 1184 | + point_coords: NdarrayTensor, |
| 1185 | + point_labels: NdarrayTensor, |
| 1186 | + pos_val: Sequence[int] = (1, 3), |
| 1187 | + neg_val: Sequence[int] = (0, 2), |
| 1188 | + margins: int = 3, |
| 1189 | +) -> NdarrayTensor: |
| 1190 | + """ |
| 1191 | + Gets the connected components of img_pos and img_neg that contain the positive points and |
| 1192 | + negative points, respectively. The function is used to combine automatic results with interactive |
| 1193 | + point-based results in VISTA3D. |
| 1194 | +
|
| 1195 | + Args: |
| 1196 | + img_pos: bool tensor of shape [B, 1, H, W, D], where B is the number of foreground masks from a single 3D image. |
| 1197 | + img_neg: same format as img_pos, but corresponds to negative points. |
| 1198 | + point_coords: the coordinates of each point, shape [B, N, 3], where N is the number of points. |
| 1199 | + point_labels: the label of each point, shape [B, N]. |
| 1200 | + pos_val: point label values treated as positive. |
| 1201 | + neg_val: point label values treated as negative. |
| 1202 | + """ |
| 1203 | + |
| 1204 | + cucim_skimage, has_cucim = optional_import("cucim.skimage") |
| 1205 | + |
| 1206 | + use_cp = has_cp and has_cucim and isinstance(img_pos, torch.Tensor) and img_pos.device != torch.device("cpu") |
| 1207 | + if use_cp: |
| 1208 | + img_pos_ = convert_to_cupy(img_pos.short()) # type: ignore |
| 1209 | + img_neg_ = convert_to_cupy(img_neg.short()) # type: ignore |
| 1210 | + label = cucim_skimage.measure.label |
| 1211 | + lib = cp |
| 1212 | + else: |
| 1213 | + if not has_measure: |
| 1214 | + raise RuntimeError("skimage.measure required.") |
| 1215 | + img_pos_, *_ = convert_data_type(img_pos, np.ndarray) |
| 1216 | + img_neg_, *_ = convert_data_type(img_neg, np.ndarray) |
| 1217 | + # for skimage.measure.label, the input must be bool type |
| 1218 | + if img_pos_.dtype != bool or img_neg_.dtype != bool: |
| 1219 | + raise ValueError("img_pos and img_neg must be bool type.") |
| 1220 | + label = measure.label |
| 1221 | + lib = np |
| 1222 | + |
| 1223 | + features_pos, _ = label(img_pos_, connectivity=3, return_num=True) |
| 1224 | + features_neg, _ = label(img_neg_, connectivity=3, return_num=True) |
| 1225 | + |
| 1226 | + outs = np.zeros_like(img_pos_) |
| 1227 | + for bs in range(point_coords.shape[0]): |
| 1228 | + for i, p in enumerate(point_coords[bs]): |
| 1229 | + if point_labels[bs, i] in pos_val: |
| 1230 | + features = features_pos |
| 1231 | + elif point_labels[bs, i] in neg_val: |
| 1232 | + features = features_neg |
| 1233 | + else: |
| 1234 | + # if -1 padding point, skip |
| 1235 | + continue |
| 1236 | + for margin in range(margins): |
| 1237 | + if isinstance(p, np.ndarray): |
| 1238 | + x, y, z = np.round(p).astype(int).tolist() |
| 1239 | + else: |
| 1240 | + x, y, z = p.float().round().int().tolist() |
| 1241 | + l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3]) |
| 1242 | + t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2]) |
| 1243 | + f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1]) |
| 1244 | + if (features[bs, 0, l:r, t:d, f:b] > 0).any(): |
| 1245 | + index = features[bs, 0, l:r, t:d, f:b].max() |
| 1246 | + outs[[bs]] += lib.isin(features[[bs]], index) |
| 1247 | + break |
| 1248 | + outs[outs > 1] = 1 |
| 1249 | + return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] |
| 1250 | + |
| 1251 | + |
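As a quick sanity check, a minimal usage sketch for the new helper (the shapes and click coordinates below are illustrative assumptions, not taken from the PR's tests; it assumes the diff is applied to monai.transforms.utils and that scikit-image is installed for the CPU path):

    import torch
    from monai.transforms.utils import get_largest_connected_component_mask_point

    img_pos = torch.zeros(1, 1, 64, 64, 64, dtype=torch.bool)
    img_pos[0, 0, 20:40, 20:40, 20:40] = True   # one candidate foreground component
    img_neg = torch.zeros_like(img_pos)         # no negative component in this sketch
    point_coords = torch.tensor([[[30.0, 30.0, 30.0]]])  # a single click, shape [B, N, 3]
    point_labels = torch.tensor([[1]])                    # label 1 -> positive point, shape [B, N]
    out = get_largest_connected_component_mask_point(img_pos, img_neg, point_coords, point_labels)
    # out keeps only the component of img_pos that contains the click, shape [1, 1, 64, 64, 64]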
| 1252 | +def convert_points_to_disc( |
| 1253 | + image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False |
| 1254 | +): |
| 1255 | + """ |
| 1256 | + Convert 3D point coordinates into an image mask. The returned mask has the same spatial |
| 1257 | + size as `image_size`, while its batch dimension matches the batch dimension of `point`. |
| 1258 | + Each point is converted to a ball with radius defined by `radius`. The output |
| 1259 | + contains two channels: the first for negative points and the second for positive points. |
| 1260 | +
|
| 1261 | + Args: |
| 1262 | + image_size: The output size of the converted mask. It should be a 3D tuple. |
| 1263 | + point: [B, N, 3], 3D point coordinates. |
| 1264 | + point_label: [B, N], 0 or 2 means negative points, 1 or 3 means positive points. |
| 1265 | + radius: disc ball radius size. |
| 1266 | + disc: If True, use a hard binary disc; otherwise, use a Gaussian falloff. |
| 1267 | + """ |
| 1268 | + masks = torch.zeros([point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device) |
| 1269 | + _array = [ |
| 1270 | + torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3) |
| 1271 | + ] |
| 1272 | + coord_rows, coord_cols, coord_z = meshgrid_ij(_array[0], _array[1], _array[2]) |
| 1273 | + # [1, 1, 3, h, w, d] -> [b, 2, 3, h, w, d] |
| 1274 | + coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) |
| 1275 | + coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1) |
| 1276 | + for b, n in np.ndindex(*point.shape[:2]): |
| 1277 | + point_bn = unsqueeze_right(point[b, n], 4)  # [3] -> [3, 1, 1, 1] |
| 1278 | + if point_label[b, n] > -1: |
| 1279 | + channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 |
| 1280 | + pow_diff = torch.pow(coords[b, channel] - point_bn, 2) |
| 1281 | + if disc: |
| 1282 | + masks[b, channel] += pow_diff.sum(0) < radius**2 |
| 1283 | + else: |
| 1284 | + masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2)) |
| 1285 | + return masks |
| 1286 | + |
| 1287 | + |
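A small sketch of the expected call pattern (the coordinates and volume size are made-up values for illustration; the import path assumes this diff lands in monai.transforms.utils):

    import torch
    from monai.transforms.utils import convert_points_to_disc

    points = torch.tensor([[[16.0, 16.0, 16.0], [5.0, 5.0, 5.0]]])  # [B=1, N=2, 3]
    labels = torch.tensor([[1, 0]])                                  # one positive, one negative point
    mask = convert_points_to_disc(image_size=(32, 32, 32), point=points, point_label=labels, radius=3, disc=True)
    # mask: [1, 2, 32, 32, 32]; channel 0 collects negative discs, channel 1 positive discs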
| 1288 | +def sample_points_from_label( |
| 1289 | + labels: Tensor, |
| 1290 | + label_set: Sequence[int], |
| 1291 | + max_ppoint: int = 1, |
| 1292 | + max_npoint: int = 0, |
| 1293 | + device: torch.device | str | None = "cpu", |
| 1294 | + use_center: bool = False, |
| 1295 | +): |
| 1296 | + """Sample points from labels. |
| 1297 | +
|
| 1298 | + Args: |
| 1299 | + labels: label tensor of shape [1, 1, H, W, D]. |
| 1300 | + label_set: local label indices to sample from; values must match those in `labels`. |
| 1301 | + max_ppoint: maximum number of positive point samples per label. |
| 1302 | + max_npoint: maximum number of negative point samples per label. |
| 1303 | + device: device of the returned tensors. |
| 1304 | + use_center: whether to sample positive points closest to the component center. |
| 1305 | +
|
| 1306 | + Returns: |
| 1307 | + point: point coordinates of shape [B, N, 3], where B equals the length of label_set and N = max_ppoint + max_npoint. |
| 1308 | + point_label: [B, N]; 0 for negative points, 1 for positive points, -1 for padding. |
| 1309 | + """ |
| 1310 | + if not labels.shape[0] == 1: |
| 1311 | + raise ValueError("labels must have batch size 1.") |
| 1312 | + |
| 1313 | + if device is None: |
| 1314 | + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 1315 | + |
| 1316 | + labels = labels[0, 0] |
| 1317 | + unique_labels = labels.unique().cpu().numpy().tolist() |
| 1318 | + _point = [] |
| 1319 | + _point_label = [] |
| 1320 | + for id in label_set: |
| 1321 | + if id in unique_labels: |
| 1322 | + plabels = labels == int(id) |
| 1323 | + nlabels = ~plabels |
| 1324 | + _plabels = get_largest_connected_component_mask(erode(plabels.unsqueeze(0).unsqueeze(0))[0, 0]) |
| 1325 | + plabelpoints = torch.nonzero(_plabels).to(device) |
| 1326 | + if len(plabelpoints) == 0: |
| 1327 | + plabelpoints = torch.nonzero(plabels).to(device) |
| 1328 | + nlabelpoints = torch.nonzero(nlabels).to(device) |
| 1329 | + num_p = min(len(plabelpoints), max_ppoint) |
| 1330 | + num_n = min(len(nlabelpoints), max_npoint) |
| 1331 | + pad = max_ppoint + max_npoint - num_p - num_n |
| 1332 | + if use_center: |
| 1333 | + pmean = plabelpoints.float().mean(0) |
| 1334 | + pdis = ((plabelpoints - pmean) ** 2).sum(-1) |
| 1335 | + _, sorted_indices_tensor = torch.sort(pdis) |
| 1336 | + sorted_indices = sorted_indices_tensor.cpu().tolist() |
| 1337 | + else: |
| 1338 | + sorted_indices = list(range(len(plabelpoints))) |
| 1339 | + random.shuffle(sorted_indices) |
| 1340 | + _point.append( |
| 1341 | + torch.stack( |
| 1342 | + [plabelpoints[sorted_indices[i]] for i in range(num_p)] |
| 1343 | + + random.choices(nlabelpoints, k=num_n) |
| 1344 | + + [torch.tensor([0, 0, 0], device=device)] * pad |
| 1345 | + ) |
| 1346 | + ) |
| 1347 | + _point_label.append(torch.tensor([1] * num_p + [0] * num_n + [-1] * pad).to(device)) |
| 1348 | + else: |
| 1349 | + # label not present in the image: pad with dummy points (label -1) |
| 1350 | + _point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device)) |
| 1351 | + _point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1) |
| 1352 | + point = torch.stack(_point) |
| 1353 | + point_label = torch.stack(_point_label) |
| 1354 | + |
| 1355 | + return point, point_label |
| 1356 | + |
| 1357 | + |
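A hedged end-to-end sketch for the sampler (a single synthetic object with label value 2 is an assumption for illustration; the CPU path relies on scikit-image for the connected-component step):

    import torch
    from monai.transforms.utils import sample_points_from_label

    labels = torch.zeros(1, 1, 32, 32, 32, dtype=torch.long)
    labels[0, 0, 10:20, 10:20, 10:20] = 2      # one foreground object with label value 2
    point, point_label = sample_points_from_label(labels, label_set=[2], max_ppoint=1, max_npoint=1)
    # point: [1, 2, 3] coordinates; point_label: [1, 2] == [[1, 0]] (positive sample, then negative sample)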
1175 | 1358 | def remove_small_objects( |
1176 | 1359 | img: NdarrayTensor, |
1177 | 1360 | min_size: int = 64, |
|