Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vista3d inferers #8021

Merged
merged 37 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
cae979a
Add vista3d inferers
heyufan1995 Aug 15, 2024
daf7b45
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2024
ee3ceaa
fix format issues
yiheng-wang-nv Aug 16, 2024
4a069d7
add tests
yiheng-wang-nv Aug 16, 2024
519154e
add vista3d transforms
yiheng-wang-nv Aug 16, 2024
4b579e1
update inputs doc string
yiheng-wang-nv Aug 16, 2024
42d29be
Add transforms
heyufan1995 Aug 16, 2024
4144b75
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 16, 2024
656f212
Add test
heyufan1995 Aug 18, 2024
8b647fd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2024
757ea61
Add more test
heyufan1995 Aug 18, 2024
f44e9df
fix issues
yiheng-wang-nv Aug 19, 2024
81a0984
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
255a96e
Change docstring
heyufan1995 Aug 19, 2024
c9979d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2024
72aa299
resolve comments
yiheng-wang-nv Aug 20, 2024
cb4d5c8
Merge branch 'dev' into add-vista3d-other-utils
yiheng-wang-nv Aug 20, 2024
fe29e86
fix doc issue
yiheng-wang-nv Aug 20, 2024
b1e1822
Address docstring issue
heyufan1995 Aug 20, 2024
879f1f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2024
2649eb4
Merge branch 'dev' into add-vista3d-other-utils
KumoLiu Aug 20, 2024
cb7446b
Update docstring
heyufan1995 Aug 20, 2024
6c8bba5
Merge branch 'add-vista3d-other-utils' of github.com:heyufan1995/MONA…
heyufan1995 Aug 20, 2024
144b753
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2024
7127c10
Update docstring
heyufan1995 Aug 20, 2024
d4bb986
Merge branch 'add-vista3d-other-utils' of github.com:heyufan1995/MONA…
heyufan1995 Aug 20, 2024
2e21d89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2024
c0e20b5
fix doc issue
yiheng-wang-nv Aug 21, 2024
c8e1e44
Add generate_prompt_pairs
heyufan1995 Aug 21, 2024
729d235
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2024
f1686bc
fix issues
yiheng-wang-nv Aug 22, 2024
8c1d7b6
update kwargs
yiheng-wang-nv Aug 22, 2024
9363fa3
Fix bug in convert point to disc and add more doc
heyufan1995 Aug 22, 2024
e5c8baa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2024
7b5f5f6
autofix
yiheng-wang-nv Aug 23, 2024
6ee3629
Merge branch 'dev' into add-vista3d-other-utils
KumoLiu Aug 23, 2024
194bd77
resolve issues
yiheng-wang-nv Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/apps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ FastMRIReader
~~~~~~~~~~~~~
.. autofunction:: monai.apps.reconstruction.complex_utils.complex_conj

`Vista3d`
---------
.. autofunction:: monai.apps.vista3d.inferer.point_based_window_inferer

`Auto3DSeg`
-----------
.. automodule:: monai.apps.auto3dseg
Expand Down
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
File renamed without changes.
165 changes: 165 additions & 0 deletions monai/apps/vista3d/inferer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import copy
from collections.abc import Sequence
from typing import Any

import torch

from monai.data.meta_tensor import MetaTensor
from monai.utils import optional_import

tqdm, _ = optional_import("tqdm", name="tqdm")

__all__ = ["point_based_window_inferer"]


def point_based_window_inferer(
inputs: torch.Tensor | MetaTensor,
roi_size: Sequence[int],
predictor: torch.nn.Module,
point_coords: torch.Tensor,
point_labels: torch.Tensor,
class_vector: torch.Tensor | None = None,
prompt_class: torch.Tensor | None = None,
prev_mask: torch.Tensor | MetaTensor | None = None,
point_start: int = 0,
**kwargs: Any,
) -> torch.Tensor:
"""
Point based window inferer, crop a patch centered at the point, and perform inference.
heyufan1995 marked this conversation as resolved.
Show resolved Hide resolved
Different patches are combined with gaussian weighted weights.
mingxin-zheng marked this conversation as resolved.
Show resolved Hide resolved

Args:
inputs: input image to be processed (assuming NCHW[D])
yiheng-wang-nv marked this conversation as resolved.
Show resolved Hide resolved
roi_size: the spatial window size for inferences.
When its components have None or non-positives, the corresponding inputs dimension will be used.
if the components of the `roi_size` are non-positive values, the transform will use the
corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
to `(32, 64)` if the second spatial dimension size of img is `64`.
sw_batch_size: the batch size to run window slices.
predictor: partial(infer_wrapper, model). infer_wrapper transpose the model output.
The model output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D].
point_coords: [B, N, 3]
yiheng-wang-nv marked this conversation as resolved.
Show resolved Hide resolved
heyufan1995 marked this conversation as resolved.
Show resolved Hide resolved
point_labels: [B, N]
heyufan1995 marked this conversation as resolved.
Show resolved Hide resolved
class_vector: [B]
prev_mask: [1, B, H, W, D]. The value is before sigmoid.
heyufan1995 marked this conversation as resolved.
Show resolved Hide resolved
Returns:
stitched_output: [1, B, H, W, D]. The value is before sigmoid.
Notice: The function only supports SINGLE OBJECT INFERENCE with B=1.
"""
if not point_coords.shape[0] == 1:
yiheng-wang-nv marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("Only supports single object point click.")
image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size)
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device)
prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None
stitched_output = None
center_only = True
for p in point_coords[0][point_start:]:
lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=5)
heyufan1995 marked this conversation as resolved.
Show resolved Hide resolved
ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=5)
lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=5)
for i in range(len(lx_)):
for j in range(len(ly_)):
for k in range(len(lz_)):
lx, rx, ly, ry, lz, rz = (lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k])
unravel_slice = [
slice(None),
slice(None),
slice(int(lx), int(rx)),
slice(int(ly), int(ry)),
slice(int(lz), int(rz)),
]
batch_image = image[unravel_slice]
mingxin-zheng marked this conversation as resolved.
Show resolved Hide resolved
output = predictor(
batch_image,
point_coords=point_coords,
point_labels=point_labels,
class_vector=class_vector,
prompt_class=prompt_class,
patch_coords=unravel_slice,
prev_mask=prev_mask,
**kwargs,
)
if stitched_output is None:
stitched_output = torch.zeros(
[1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
)
stitched_mask = torch.zeros(
[1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
)
stitched_output[unravel_slice] += output.to("cpu")
stitched_mask[unravel_slice] = 1
# if stitched_mask is 0, then NaN value
stitched_output = stitched_output / stitched_mask
# revert padding
stitched_output = stitched_output[
:, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
]
stitched_mask = stitched_mask[
:, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
]
if prev_mask is not None:
prev_mask = prev_mask[
:,
:,
pad[4] : image.shape[-3] - pad[5],
pad[2] : image.shape[-2] - pad[3],
pad[0] : image.shape[-1] - pad[1],
]
prev_mask = prev_mask.to("cpu") # type: ignore
# for un-calculated place, use previous mask
stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1]
if isinstance(inputs, torch.Tensor):
inputs = MetaTensor(inputs)
if not hasattr(stitched_output, "meta"):
stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta)
return stitched_output


def _get_window_idx_c(p: int, roi: int, s: int) -> tuple[int, int]:
"""Helper function to get the window index."""
if p - roi // 2 < 0:
left, right = 0, roi
elif p + roi // 2 > s:
left, right = s - roi, s
else:
left, right = int(p) - roi // 2, int(p) + roi // 2
return left, right


def _get_window_idx(p: int, roi: int, s: int, center_only: bool = True, margin: int = 5) -> tuple[list[int], list[int]]:
"""Get the window index."""
left, right = _get_window_idx_c(p, roi, s)
if center_only:
return [left], [right]
left_most = max(0, p - roi + margin)
right_most = min(s, p + roi - margin)
left_list = [left_most, right_most - roi, left]
right_list = [left_most + roi, right_most, right]
return left_list, right_list


def _pad_previous_mask(
inputs: torch.Tensor | MetaTensor, roi_size: Sequence[int], padvalue: int = 0
) -> tuple[torch.Tensor | MetaTensor, list[int]]:
"""Helper function to pad inputs."""
pad_size = []
for k in range(len(inputs.shape) - 1, 1, -1):
diff = max(roi_size[k - 2] - inputs.shape[k], 0)
half = diff // 2
pad_size.extend([half, diff - half])
if any(pad_size):
inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue) # type: ignore
return inputs, pad_size
156 changes: 156 additions & 0 deletions monai/apps/vista3d/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
yiheng-wang-nv marked this conversation as resolved.
Show resolved Hide resolved

from typing import Sequence

import numpy as np

from monai.config import DtypeLike, KeysCollection
from monai.transforms import MapLabelValue
from monai.transforms.transform import MapTransform
from monai.utils import look_up_option

heyufan1995 marked this conversation as resolved.
Show resolved Hide resolved

def _get_name_to_index_mapping(labels_dict: dict | None) -> dict:
"""get the label name to index mapping"""
name_to_index_mapping = {}
if labels_dict is not None:
name_to_index_mapping = {v.lower(): int(k) for k, v in labels_dict.items()}
return name_to_index_mapping


def _convert_name_to_index(name_to_index_mapping: dict, label_prompt: list | None) -> list | None:
"""convert the label name to index"""
if label_prompt is not None and isinstance(label_prompt, list):
converted_label_prompt = []
# for new class, add to the mapping
for l in label_prompt:
if isinstance(l, str) and not l.isdigit():
if l.lower() not in name_to_index_mapping:
name_to_index_mapping[l.lower()] = len(name_to_index_mapping)
for l in label_prompt:
if isinstance(l, (int, str)):
converted_label_prompt.append(
name_to_index_mapping.get(l.lower(), int(l) if l.isdigit() else 0) if isinstance(l, str) else int(l)
)
else:
converted_label_prompt.append(l)
return converted_label_prompt
return label_prompt


class VistaPreTransform(MapTransform):
heyufan1995 marked this conversation as resolved.
Show resolved Hide resolved
yiheng-wang-nv marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
keys: KeysCollection,
allow_missing_keys: bool = False,
special_index: Sequence[int] = (25, 26, 27, 28, 29, 117),
labels_dict: dict | None = None,
subclass: dict | None = None,
) -> None:
"""
Pre-transform for Vista3d.

Args:
keys: keys of the corresponding items to be transformed.
dataset_transforms: a dictionary specifies the transform for corresponding dataset:
key: dataset name, value: list of data transforms.
dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name".
allow_missing_keys: don't raise exception if key is missing.
special_index: the class index that need to be handled differently.
"""
super().__init__(keys, allow_missing_keys)
self.special_index = special_index
self.subclass = subclass
self.name_to_index_mapping = _get_name_to_index_mapping(labels_dict)

def __call__(self, data):
label_prompt = data.get("label_prompt", None)
point_labels = data.get("point_labels", None)
# convert the label name to index if needed
label_prompt = _convert_name_to_index(self.name_to_index_mapping, label_prompt)
try:
# The evaluator will check prompt. The invalid prompt will be skipped here and captured by evaluator.
if self.subclass is not None and label_prompt is not None:
_label_prompt = []
subclass_keys = list(map(int, self.subclass.keys()))
for i in range(len(label_prompt)):
if label_prompt[i] in subclass_keys:
_label_prompt.extend(self.subclass[str(label_prompt[i])])
else:
_label_prompt.append(label_prompt[i])
data["label_prompt"] = _label_prompt

if label_prompt is not None and point_labels is not None:
if label_prompt[0] in self.special_index:
point_labels = np.array(point_labels)
point_labels[point_labels == 0] = 2
point_labels[point_labels == 1] = 3
point_labels = point_labels.tolist()
data["point_labels"] = point_labels
except Exception:
pass
heyufan1995 marked this conversation as resolved.
Show resolved Hide resolved

return data


class RelabelD(MapTransform):
yiheng-wang-nv marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
keys: KeysCollection,
label_mappings: dict[str, list[tuple[int, int]]],
dtype: DtypeLike = np.int16,
dataset_key: str = "dataset_name",
allow_missing_keys: bool = False,
) -> None:
"""
Remap the voxel labels in the input data dictionary based on the specified mapping.

This list of local -> global label mappings will be applied to each input `data[keys]`.
if `data[dataset_key]` is not in `label_mappings`, label_mappings['default']` will be used.
if `label_mappings[data[dataset_key]]` is None, no relabeling will be performed.

Args:
keys: keys of the corresponding items to be transformed.
label_mappings: a dictionary specifies how local dataset class indices are mapped to the
global class indices, format:
key: dataset name.
value: list of (local label, global label) pairs. This list of local -> global label mappings
will be applied to each input `data[keys]`. If `data[dataset_key]` is not in `label_mappings`,
label_mappings['default']` will be used. if `label_mappings[data[dataset_key]]` is None,
no relabeling will be performed.
set `label_mappings={}` to completely skip this transform.
dtype: convert the output data to dtype, default to float32.
dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name".
allow_missing_keys: don't raise exception if key is missing.

"""
super().__init__(keys, allow_missing_keys)
self.mappers = {}
self.dataset_key = dataset_key
for name, mapping in label_mappings.items():
self.mappers[name] = MapLabelValue(
orig_labels=[int(pair[0]) for pair in mapping],
target_labels=[int(pair[1]) for pair in mapping],
dtype=dtype,
)

def __call__(self, data):
d = dict(data)
dataset_name = d.get(self.dataset_key, "default")
_m = look_up_option(dataset_name, self.mappers, default=None)
if _m is None:
return d
for key in self.key_iterator(d):
d[key] = _m(d[key])
return d
1 change: 1 addition & 0 deletions monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def sliding_window_inference(

# remove padding if image_size smaller than roi_size
if any(pad_size):
kwargs.update({"pad_size": pad_size})
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
for ss, output_i in enumerate(output_image_list):
zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)]
final_slicing: list[slice] = []
Expand Down
34 changes: 33 additions & 1 deletion monai/networks/nets/vista3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,35 @@ def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head:
self.NINF_VALUE = -9999
self.PINF_VALUE = 9999

def update_slidingwindow_padding(
self,
pad_size: list | None,
labels: torch.Tensor | None,
prev_mask: torch.Tensor | None,
point_coords: torch.Tensor | None,
):
"""
Image has been padded by sliding window inferer.
The related padding need to be performed outside of slidingwindow inferer.

Args:
pad_size: padding size passed from sliding window inferer.
labels: image label ground truth.
prev_mask: previous segmentation mask.
point_coords: point click coordinates.
"""
if pad_size is None:
return labels, prev_mask, point_coords
if labels is not None:
labels = F.pad(labels, pad=pad_size, mode="constant", value=0)
if prev_mask is not None:
prev_mask = F.pad(prev_mask, pad=pad_size, mode="constant", value=0)
if point_coords is not None:
point_coords = point_coords + torch.tensor(
[pad_size[-2], pad_size[-4], pad_size[-6]], device=point_coords.device
)
return labels, prev_mask, point_coords

def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int:
"""Get number of foreground classes based on class and point prompt."""
if class_vector is None:
Expand Down Expand Up @@ -329,7 +358,7 @@ def forward(
point_coords: [B, N, 3]
point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class.
2/3 means negative/postive ponits for special supported class like tumor.
class_vector: [B, 1], the global class index
class_vector: [B, 1], the global class index.
prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if
the points are for zero-shot or supported class. When class_vector and point_coords are both
provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
Expand All @@ -348,6 +377,9 @@ def forward(
val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation.

"""
labels, prev_mask, point_coords = self.update_slidingwindow_padding(
kwargs.get("pad_size", None), labels, prev_mask, point_coords
)
image_size = input_images.shape[-3:]
device = input_images.device
if point_coords is None and class_vector is None:
Expand Down
Loading
Loading