Commit 872585d

Authored by heyufan1995, pre-commit-ci[bot], yiheng-wang-nv, and KumoLiu
Add vista3d inferers (#8021)
Fixes # .

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: heyufan1995 <heyufan1995@gmail.com>
Signed-off-by: Yiheng Wang <vennw@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yiheng Wang <vennw@nvidia.com>
Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent a5fbe71 commit 872585d

File tree

13 files changed: +988 −17 lines

docs/source/apps.rst

Lines changed: 16 additions & 0 deletions
@@ -248,6 +248,22 @@ FastMRIReader
 ~~~~~~~~~~~~~
 .. autofunction:: monai.apps.reconstruction.complex_utils.complex_conj

+`Vista3d`
+---------
+.. automodule:: monai.apps.vista3d.inferer
+.. autofunction:: point_based_window_inferer
+
+.. automodule:: monai.apps.vista3d.transforms
+.. autoclass:: VistaPreTransformd
+    :members:
+.. autoclass:: VistaPostTransformd
+    :members:
+.. autoclass:: Relabeld
+    :members:
+
+.. automodule:: monai.apps.vista3d.sampler
+.. autofunction:: sample_prompt_pairs
+
 `Auto3DSeg`
 -----------
 .. automodule:: monai.apps.auto3dseg
File renamed without changes.

monai/apps/vista3d/inferer.py

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import copy
from collections.abc import Sequence
from typing import Any

import torch

from monai.data.meta_tensor import MetaTensor
from monai.utils import optional_import

tqdm, _ = optional_import("tqdm", name="tqdm")

__all__ = ["point_based_window_inferer"]


def point_based_window_inferer(
    inputs: torch.Tensor | MetaTensor,
    roi_size: Sequence[int],
    predictor: torch.nn.Module,
    point_coords: torch.Tensor,
    point_labels: torch.Tensor,
    class_vector: torch.Tensor | None = None,
    prompt_class: torch.Tensor | None = None,
    prev_mask: torch.Tensor | MetaTensor | None = None,
    point_start: int = 0,
    center_only: bool = True,
    margin: int = 5,
    **kwargs: Any,
) -> torch.Tensor:
    """
    Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image.
    The inferer crops the input image into patches centered at the given points, runs inference on each patch,
    stitches the patch outputs together by averaging the overlapping regions, and returns the segmentation mask.

    Args:
        inputs: [1CHWD], input image to be processed.
        roi_size: the spatial window size for inferences.
            When a component is None or non-positive, the corresponding dimension of `inputs` is used instead.
            For example, `roi_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension
            size of the image is `64`.
        predictor: the model. For VISTA3D, the output is [B, 1, H, W, D], which needs to be transposed to
            [1, B, H, W, D]. Add `transpose=True` to kwargs for VISTA3D.
        point_coords: [B, N, 3]. Point coordinates for B foreground objects, each with N points.
        point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or
            zero-shot classes. 2/3 means negative/positive points for special supported classes
            (e.g. tumor, vessel).
        class_vector: [B]. Used for class-head automatic segmentation. Can be None.
        prompt_class: [B]. The same as class_vector, representing the point class and informing the point head
            whether the class is supported or zero-shot. Not used for automatic segmentation. If None, the
            point head defaults to supported-class segmentation.
        prev_mask: [1, B, H, W, D]. An optional tensor of previously segmented masks. The value is before sigmoid.
        point_start: only use points starting from this index. All points before this index are used to generate
            prev_mask. This avoids re-calculating the points of previous iterations when prev_mask is given.
        center_only: for each point, only crop the patch centered at that point. If False, crop 3 patches for
            each point.
        margin: if center_only is False, this value is the distance between the point and the patch boundary.

    Returns:
        stitched_output: [1, B, H, W, D]. The value is before sigmoid.

    Notice: The function only supports SINGLE OBJECT INFERENCE with B=1.
    """
    if not point_coords.shape[0] == 1:
        raise ValueError("Only supports single object point click.")
    if not len(inputs.shape) == 5:
        raise ValueError("Input image should be 5D.")
    image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size)
    # shift the point coordinates to account for the padding (pad is in last-dim-first order)
    point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device)
    prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None
    stitched_output = None
    for p in point_coords[0][point_start:]:
        lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=margin)
        ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=margin)
        lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=margin)
        for i in range(len(lx_)):
            for j in range(len(ly_)):
                for k in range(len(lz_)):
                    lx, rx, ly, ry, lz, rz = (lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k])
                    unravel_slice = [
                        slice(None),
                        slice(None),
                        slice(int(lx), int(rx)),
                        slice(int(ly), int(ry)),
                        slice(int(lz), int(rz)),
                    ]
                    batch_image = image[unravel_slice]
                    output = predictor(
                        batch_image,
                        point_coords=point_coords,
                        point_labels=point_labels,
                        class_vector=class_vector,
                        prompt_class=prompt_class,
                        patch_coords=unravel_slice,
                        prev_mask=prev_mask,
                        **kwargs,
                    )
                    if stitched_output is None:
                        stitched_output = torch.zeros(
                            [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
                        )
                        stitched_mask = torch.zeros(
                            [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
                        )
                    stitched_output[unravel_slice] += output.to("cpu")
                    stitched_mask[unravel_slice] = 1
    # average the accumulated logits; where stitched_mask is 0 the division yields NaN
    stitched_output = stitched_output / stitched_mask
    # revert padding
    stitched_output = stitched_output[
        :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
    ]
    stitched_mask = stitched_mask[
        :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
    ]
    if prev_mask is not None:
        prev_mask = prev_mask[
            :,
            :,
            pad[4] : image.shape[-3] - pad[5],
            pad[2] : image.shape[-2] - pad[3],
            pad[0] : image.shape[-1] - pad[1],
        ]
        prev_mask = prev_mask.to("cpu")  # type: ignore
        # for voxels not covered by any patch, fall back to the previous mask
        stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1]
    if isinstance(inputs, torch.Tensor):
        inputs = MetaTensor(inputs)
    if not hasattr(stitched_output, "meta"):
        stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta)
    return stitched_output


def _get_window_idx_c(p: int, roi: int, s: int) -> tuple[int, int]:
    """Helper function to get the window index."""
    if p - roi // 2 < 0:
        left, right = 0, roi
    elif p + roi // 2 > s:
        left, right = s - roi, s
    else:
        left, right = int(p) - roi // 2, int(p) + roi // 2
    return left, right


def _get_window_idx(p: int, roi: int, s: int, center_only: bool = True, margin: int = 5) -> tuple[list[int], list[int]]:
    """Get the window index."""
    left, right = _get_window_idx_c(p, roi, s)
    if center_only:
        return [left], [right]
    left_most = max(0, p - roi + margin)
    right_most = min(s, p + roi - margin)
    left_list = [left_most, right_most - roi, left]
    right_list = [left_most + roi, right_most, right]
    return left_list, right_list


def _pad_previous_mask(
    inputs: torch.Tensor | MetaTensor, roi_size: Sequence[int], padvalue: int = 0
) -> tuple[torch.Tensor | MetaTensor, list[int]]:
    """Helper function to pad inputs."""
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    if any(pad_size):
        inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue)  # type: ignore
    return inputs, pad_size
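
Below is a minimal usage sketch of the new inferer. `_StubPredictor` is a hypothetical stand-in, not part of this PR; a real VISTA3D model accepting the same keyword arguments would be used instead. Shapes follow the docstring above.

import torch

from monai.apps.vista3d.inferer import point_based_window_inferer


class _StubPredictor(torch.nn.Module):
    # hypothetical stand-in for a VISTA3D-style model; returns zero logits of the patch shape
    def forward(self, x, point_coords=None, point_labels=None, class_vector=None,
                prompt_class=None, patch_coords=None, prev_mask=None, **kwargs):
        return torch.zeros(1, 1, *x.shape[2:])  # pre-sigmoid logits, [1, B, H', W', D'] with B=1


image = torch.rand(1, 1, 96, 96, 96)  # [1, C, H, W, D]
points = torch.tensor([[[48, 48, 48]]], dtype=torch.float)  # [B=1, N=1, 3]
point_labels = torch.tensor([[1]])  # a single positive click

logits = point_based_window_inferer(
    inputs=image,
    roi_size=(64, 64, 64),
    predictor=_StubPredictor(),
    point_coords=points,
    point_labels=point_labels,
)
# outputs are pre-sigmoid; voxels never covered by a patch stay NaN unless prev_mask is given
mask = logits.sigmoid() > 0.5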

monai/apps/vista3d/sampler.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import copy
import random
from collections.abc import Callable, Sequence
from typing import Any

import numpy as np
import torch
from torch import Tensor

__all__ = ["sample_prompt_pairs"]

ENABLE_SPECIAL = True
SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
MERGE_LIST = {
    1: [25, 26],  # hepatic tumor and vessel merge into liver
    4: [24],  # pancreatic tumor merge into pancreas
    132: [57],  # overlap with trachea merge into airway
}


def _get_point_label(id: int) -> tuple[int, int]:
    if id in SPECIAL_INDEX and ENABLE_SPECIAL:
        return 2, 3
    else:
        return 0, 1


def sample_prompt_pairs(
    labels: Tensor,
    label_set: Sequence[int],
    max_prompt: int | None = None,
    max_foreprompt: int | None = None,
    max_backprompt: int = 1,
    max_point: int = 20,
    include_background: bool = False,
    drop_label_prob: float = 0.2,
    drop_point_prob: float = 0.2,
    point_sampler: Callable | None = None,
    **point_sampler_kwargs: Any,
) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
    """
    Sample training pairs for VISTA3D training.

    Args:
        labels: [1, 1, H, W, D], ground truth labels.
        label_set: the label list for the specific dataset. Note that if 0 is included in label_set,
            it will be added into automatic branch training. It is recommended to remove 0 from label_set
            for multi-partially-labeled-dataset training, and to add 0 when finetuning on a specific dataset.
            The reason is that a region labeled 0 in one partially labeled dataset may contain foreground
            in another dataset.
        max_prompt: max number of total prompts, including foreground and background.
        max_foreprompt: max number of prompts from foreground.
        max_backprompt: max number of prompts from background.
        max_point: maximum number of points for each object.
        include_background: whether to include 0 in the training prompts. If included, background 0 is
            treated the same as foreground. Should always be False for multi-partially-labeled-dataset
            training; it can be True when finetuning on a specific dataset.
        drop_label_prob: probability of dropping the label prompt.
        drop_point_prob: probability of dropping the point prompt.
        point_sampler: sampler to augment masks with supervoxels.
        point_sampler_kwargs: arguments for point_sampler.

    Returns:
        label_prompt: [B, 1]. The classes used for training automatic segmentation.
        point: [B, N, 3]. The corresponding points for each class.
            Note that background label prompts require matching points as well ([0, 0, 0] is used).
        point_label: [B, N]. The corresponding point labels for each point (negative or positive).
            -1 is used for padding the background label prompt and will be ignored.
        prompt_class: [B, 1], exactly the same as label_prompt, used for label indexing in the training loss.
            label_prompt can be None, in which case prompt_class is used to identify point classes.
    """
    # class label number
    if not labels.shape[0] == 1:
        raise ValueError("only support batch size 1")
    labels = labels[0, 0]
    device = labels.device
    unique_labels = labels.unique().cpu().numpy().tolist()
    if include_background:
        unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)))
    else:
        unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)) - {0})
    background_labels = list(set(label_set) - set(unique_labels))
    # during training, balance background and foreground prompts
    if max_backprompt is not None:
        if len(background_labels) > max_backprompt:
            random.shuffle(background_labels)
            background_labels = background_labels[:max_backprompt]

    if max_foreprompt is not None:
        if len(unique_labels) > max_foreprompt:
            random.shuffle(unique_labels)
            unique_labels = unique_labels[:max_foreprompt]

    if max_prompt is not None:
        if len(unique_labels) + len(background_labels) > max_prompt:
            if len(unique_labels) > max_prompt:
                unique_labels = random.sample(unique_labels, max_prompt)
                background_labels = []
            else:
                background_labels = random.sample(background_labels, max_prompt - len(unique_labels))
    _point = []
    _point_label = []
    # if use regular sampling
    if point_sampler is None:
        num_p = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1)
        num_n = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))))
        for id in unique_labels:
            neg_id, pos_id = _get_point_label(id)
            plabels = labels == int(id)
            nlabels = ~plabels
            plabelpoints = torch.nonzero(plabels)
            nlabelpoints = torch.nonzero(nlabels)
            # final sampled positive points
            num_pa = min(len(plabelpoints), num_p)
            # final sampled negative points
            num_na = min(len(nlabelpoints), num_n)
            _point.append(
                torch.stack(
                    random.choices(plabelpoints, k=num_pa)
                    + random.choices(nlabelpoints, k=num_na)
                    + [torch.tensor([0, 0, 0], device=device)] * (num_p + num_n - num_pa - num_na)
                )
            )
            _point_label.append(
                torch.tensor([pos_id] * num_pa + [neg_id] * num_na + [-1] * (num_p + num_n - num_pa - num_na)).to(
                    device
                )
            )
        for _ in background_labels:
            # pad the background labels
            _point.append(torch.zeros(num_p + num_n, 3).to(device))  # all 0
            _point_label.append(torch.zeros(num_p + num_n).to(device) - 1)  # -1 not a point
    else:
        _point, _point_label = point_sampler(unique_labels, **point_sampler_kwargs)
        for _ in background_labels:
            # pad the background labels
            _point.append(torch.zeros(len(_point_label[0]), 3).to(device))  # all 0
            _point_label.append(torch.zeros(len(_point_label[0])).to(device) - 1)  # -1 not a point
    if len(unique_labels) == 0 and len(background_labels) == 0:
        # if max_backprompt is 0 and len(unique_labels) is 0, there is no effective prompt and the iteration
        # must be skipped. Handle this in the trainer.
        label_prompt, point, point_label, prompt_class = None, None, None, None
    else:
        label_prompt = torch.tensor(unique_labels + background_labels).unsqueeze(-1).to(device).long()
        point = torch.stack(_point)
        point_label = torch.stack(_point_label)
        prompt_class = copy.deepcopy(label_prompt)
        if random.uniform(0, 1) < drop_label_prob and len(unique_labels) > 0:
            label_prompt = None
            # If the label prompt is dropped, there is no need to pad with points with label -1.
            pad = len(background_labels)
            point = point[: len(point) - pad]  # type: ignore
            point_label = point_label[: len(point_label) - pad]
            prompt_class = prompt_class[: len(prompt_class) - pad]
        else:
            if random.uniform(0, 1) < drop_point_prob:
                point = None
                point_label = None
    return label_prompt, point, point_label, prompt_class
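
A small sketch of how the sampler might be driven from a training loop; the toy label volume and `label_set` below are arbitrary placeholders rather than the VISTA3D label schema:

import torch

from monai.apps.vista3d.sampler import sample_prompt_pairs

# toy ground truth: batch of 1, single channel, two foreground classes (1 and 2)
labels = torch.zeros(1, 1, 64, 64, 64, dtype=torch.long)
labels[0, 0, 10:20, 10:20, 10:20] = 1
labels[0, 0, 30:40, 30:40, 30:40] = 2

label_prompt, point, point_label, prompt_class = sample_prompt_pairs(
    labels,
    label_set=[1, 2],  # 0 deliberately excluded, per the docstring recommendation
    max_prompt=5,
    max_point=3,
)
if label_prompt is None and point is None:
    pass  # no effective prompt was sampled; the trainer should skip this iteration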
