From 77304dd114227b8b9b1059665aedba295db0ffc7 Mon Sep 17 00:00:00 2001 From: Yufan He <59374597+heyufan1995@users.noreply.github.com> Date: Thu, 15 Aug 2024 01:50:33 -0500 Subject: [PATCH] Add vista network (#7987) Fixes # . ### Description Add VISTA3D model architecture to MONAI core ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: heyufan1995 Signed-off-by: Yufan He Signed-off-by: Yiheng Wang Signed-off-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Yiheng Wang Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- docs/source/networks.rst | 10 + monai/networks/nets/__init__.py | 3 +- monai/networks/nets/segresnet_ds.py | 128 +++- monai/networks/nets/vista3d.py | 908 ++++++++++++++++++++++++++++ monai/transforms/utils.py | 4 +- tests/test_segresnet_ds.py | 86 ++- tests/test_vista3d.py | 85 +++ 7 files changed, 1189 insertions(+), 35 deletions(-) create mode 100644 monai/networks/nets/vista3d.py create mode 100644 tests/test_vista3d.py diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 249375dfc1..1810fec49b 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -481,6 +481,11 @@ Nets .. autoclass:: SegResNetDS :members: +`SegResNetDS2` +~~~~~~~~~~~~~~ +.. autoclass:: SegResNetDS2 + :members: + `SegResNetVAE` ~~~~~~~~~~~~~~ .. autoclass:: SegResNetVAE @@ -556,6 +561,11 @@ Nets .. autoclass:: UNETR :members: +`VISTA3D` +~~~~~~~~~ +.. autoclass:: VISTA3D + :members: + `SwinUNETR` ~~~~~~~~~~~ .. 
autoclass:: SwinUNETR diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index c777fe6442..0570c9fcc1 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -76,7 +76,7 @@ resnet200, ) from .segresnet import SegResNet, SegResNetVAE -from .segresnet_ds import SegResNetDS +from .segresnet_ds import SegResNetDS, SegResNetDS2 from .senet import ( SENet, SEnet, @@ -118,6 +118,7 @@ from .unet import UNet, Unet from .unetr import UNETR from .varautoencoder import VarAutoEncoder +from .vista3d import VISTA3D, vista3d132 from .vit import ViT from .vitautoenc import ViTAutoEnc from .vnet import VNet diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py index 6430f5fdc9..1ac5a79ee3 100644 --- a/monai/networks/nets/segresnet_ds.py +++ b/monai/networks/nets/segresnet_ds.py @@ -11,6 +11,7 @@ from __future__ import annotations +import copy from collections.abc import Callable from typing import Union @@ -23,7 +24,7 @@ from monai.networks.layers.utils import get_act_layer, get_norm_layer from monai.utils import UpsampleMode, has_option -__all__ = ["SegResNetDS"] +__all__ = ["SegResNetDS", "SegResNetDS2"] def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None): @@ -425,3 +426,128 @@ def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tens def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]: return self._forward(x) + + +class SegResNetDS2(SegResNetDS): + """ + SegResNetDS2 adds an additional decoder branch to SegResNetDS and is the image encoder of `VISTA3D + <https://arxiv.org/abs/2406.05285>`_. + + Args: + spatial_dims: spatial dimension of the input data. Defaults to 3. + init_filters: number of output channels for initial convolution layer. Defaults to 32. + in_channels: number of input channels for the network. Defaults to 1. + out_channels: number of output channels for the network. Defaults to 2. + act: activation type and arguments. Defaults to ``RELU``. + norm: feature normalization type and arguments. Defaults to ``BATCH``. + blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``. + blocks_up: number of upsample blocks (optional). + dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level. + At dsdepth==1, only a single output is returned. + preprocess: optional callable function to apply before the model's forward pass. + upsample_mode: upsampling mode. Defaults to ``"deconv"``. + resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring + image spacing into an approximately isotropic space. + Otherwise, by default, the kernel size and downsampling are always isotropic.
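+ + Example: a minimal usage sketch (the shapes below are illustrative and not part of the original PR; a 64^3 patch is assumed so the spatial size is divisible by the network's downsampling factor): + + >>> import torch + >>> net = SegResNetDS2(spatial_dims=3, in_channels=1, out_channels=2) + >>> point_out, auto_out = net(torch.randn(1, 1, 64, 64, 64)) # one output per decoder branch + >>> point_out.shape # torch.Size([1, 2, 64, 64, 64]); auto_out has the same shape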
+ + """ + + def __init__( + self, + spatial_dims: int = 3, + init_filters: int = 32, + in_channels: int = 1, + out_channels: int = 2, + act: tuple | str = "relu", + norm: tuple | str = "batch", + blocks_down: tuple = (1, 2, 2, 4), + blocks_up: tuple | None = None, + dsdepth: int = 1, + preprocess: nn.Module | Callable | None = None, + upsample_mode: UpsampleMode | str = "deconv", + resolution: tuple | None = None, + ): + super().__init__( + spatial_dims=spatial_dims, + init_filters=init_filters, + in_channels=in_channels, + out_channels=out_channels, + act=act, + norm=norm, + blocks_down=blocks_down, + blocks_up=blocks_up, + dsdepth=dsdepth, + preprocess=preprocess, + upsample_mode=upsample_mode, + resolution=resolution, + ) + + self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers]) + + def forward( # type: ignore + self, x: torch.Tensor, with_point: bool = True, with_label: bool = True + ) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]: + """ + Args: + x: input tensor. + with_point: if true, return the point branch output. + with_label: if true, return the label branch output. + """ + if self.preprocess is not None: + x = self.preprocess(x) + + if not self.is_valid_shape(x): + raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}") + + x_down = self.encoder(x) + + x_down.reverse() + x = x_down.pop(0) + + if len(x_down) == 0: + x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)] + + outputs: list[torch.Tensor] = [] + outputs_auto: list[torch.Tensor] = [] + x_ = x.clone() + if with_point: + i = 0 + for level in self.up_layers: + x = level["upsample"](x) + x = x + x_down[i] + x = level["blocks"](x) + + if len(self.up_layers) - i <= self.dsdepth: + outputs.append(level["head"](x)) + i = i + 1 + + outputs.reverse() + x = x_ + if with_label: + i = 0 + for level in self.up_layers_auto: + x = level["upsample"](x) + x = x + x_down[i] + x = level["blocks"](x) + + if len(self.up_layers) - i <= self.dsdepth: + outputs_auto.append(level["head"](x)) + i = i + 1 + + outputs_auto.reverse() + + return outputs[0] if len(outputs) == 1 else outputs, outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto + + def set_auto_grad(self, auto_freeze=False, point_freeze=False): + """ + Args: + auto_freeze: if true, freeze the image encoder and the auto-branch. + point_freeze: if true, freeze the image encoder and the point-branch. + """ + for param in self.encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + + for param in self.up_layers_auto.parameters(): + param.requires_grad = not auto_freeze + + for param in self.up_layers.parameters(): + param.requires_grad = not point_freeze diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py new file mode 100644 index 0000000000..fe7f93d493 --- /dev/null +++ b/monai/networks/nets/vista3d.py @@ -0,0 +1,908 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + from __future__ import annotations + + import math + from typing import Any, Callable, Optional, Sequence, Tuple + + import numpy as np + import torch + import torch.nn.functional as F + from torch import nn + + import monai + from monai.networks.blocks import MLPBlock, UnetrBasicBlock + from monai.networks.nets import SegResNetDS2 + from monai.transforms.utils import convert_points_to_disc + from monai.transforms.utils import get_largest_connected_component_mask_point as lcc + from monai.transforms.utils import sample_points_from_label + from monai.utils import optional_import, unsqueeze_left, unsqueeze_right + + rearrange, _ = optional_import("einops", name="rearrange") + + __all__ = ["VISTA3D", "vista3d132"] + + + def vista3d132(encoder_embed_dim: int = 48, in_channels: int = 1): + """ + Exact VISTA3D network configuration used in the `VISTA3D paper <https://arxiv.org/abs/2406.05285>`_. + The model treats class indices larger than 132 as zero-shot. + + Args: + encoder_embed_dim: hidden dimension for encoder. + in_channels: input channel number. + """ + segresnet = SegResNetDS2( + in_channels=in_channels, + blocks_down=(1, 2, 2, 4, 4), + norm="instance", + out_channels=encoder_embed_dim, + init_filters=encoder_embed_dim, + dsdepth=1, + ) + point_head = PointMappingSAM(feature_size=encoder_embed_dim, n_classes=512, last_supported=132) + class_head = ClassMappingClassify(n_classes=512, feature_size=encoder_embed_dim, use_mlp=True) + vista = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head) + return vista + + + class VISTA3D(nn.Module): + """ + VISTA3D based on: + `VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography + <https://arxiv.org/abs/2406.05285>`_. + + Args: + image_encoder: image encoder backbone for feature extraction. + class_head: class head used for class index based segmentation. + point_head: point head used for interactive segmentation. + """ + + def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: nn.Module): + super().__init__() + self.image_encoder = image_encoder + self.class_head = class_head + self.point_head = point_head + self.image_embeddings = None + self.auto_freeze = False + self.point_freeze = False + self.NINF_VALUE = -9999 + self.PINF_VALUE = 9999 + + def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int: + """Get the number of foreground classes based on the class and point prompts.""" + if class_vector is None: + if point_coords is None: + raise ValueError("class_vector and point_coords cannot both be None.") + return point_coords.shape[0] + else: + return class_vector.shape[0] + + def convert_point_label( + self, + point_label: torch.Tensor, + label_set: Sequence[int] | None = None, + special_index: Sequence[int] = (23, 24, 25, 26, 27, 57, 128), + ): + """ + Convert point labels based on their class prompts. For special classes defined in special_index, + the positive/negative point label will be converted from 1/0 to 3/2. The purpose is to separate + these special classes from the ambiguous regular classes. + + Args: + point_label: the point label tensor, [B, N]. + label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID, + this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot + evaluation, this label_set should be the original index. + special_index: the special class index that needs to be converted.
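+ + Example (an illustrative sketch, not from the original code; ``vista3d`` is assumed to be a VISTA3D instance, and class 25 is in the default ``special_index``): + + >>> point_label = torch.tensor([[1, 0], [1, -1]]) + >>> vista3d.convert_point_label(point_label, label_set=[2, 25]) + tensor([[ 1, 0], [ 3, -1]])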
+ """ + if label_set is None: + return point_label + if not point_label.shape[0] == len(label_set): + raise ValueError("point_label and label_set must have the same length.") + + for i in range(len(label_set)): + if label_set[i] in special_index: + for j in range(len(point_label[i])): + point_label[i, j] = point_label[i, j] + 2 if point_label[i, j] > -1 else point_label[i, j] + return point_label + + def sample_points_patch_val( + self, + labels: torch.Tensor, + patch_coords: Sequence[slice], + label_set: Sequence[int], + use_center: bool = True, + mapped_label_set: Sequence[int] | None = None, + max_ppoint: int = 1, + max_npoint: int = 0, + ): + """ + Sample points for patch during sliding window validation. Only used for point only validation. + + Args: + labels: shape [1, 1, H, W, D]. + patch_coords: a sequence of sliding window slice objects. + label_set: local index, must match values in labels. + use_center: sample points from the center. + mapped_label_set: global index, it is used to identify special classes and is the global index + for the sampled points. + max_ppoint/max_npoint: positive points and negative points to sample. + """ + point_coords, point_labels = sample_points_from_label( + labels[patch_coords], + label_set, + max_ppoint=max_ppoint, + max_npoint=max_npoint, + device=labels.device, + use_center=use_center, + ) + point_labels = self.convert_point_label(point_labels, mapped_label_set) + return (point_coords, point_labels, torch.tensor(label_set).to(point_coords.device).unsqueeze(-1)) + + def update_point_to_patch( + self, patch_coords: Sequence[slice], point_coords: torch.Tensor, point_labels: torch.Tensor + ): + """ + Update point_coords with respect to patch coords. + If point is outside of the patch, remove the coordinates and set label to -1. + + Args: + patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference. + This value is passed from sliding_window_inferer. + point_coords: point coordinates, [B, N, 3]. + point_labels: point labels, [B, N]. + """ + patch_ends = [patch_coords[-3].stop, patch_coords[-2].stop, patch_coords[-1].stop] + patch_starts = [patch_coords[-3].start, patch_coords[-2].start, patch_coords[-1].start] + # update point coords + patch_starts_tensor = unsqueeze_left(torch.tensor(patch_starts, device=point_coords.device), 2) + patch_ends_tensor = unsqueeze_left(torch.tensor(patch_ends, device=point_coords.device), 2) + # [1 N 1] + indices = torch.logical_and( + ((point_coords - patch_starts_tensor) > 0).all(2), ((patch_ends_tensor - point_coords) > 0).all(2) + ) + # check if it's within patch coords + point_coords = point_coords.clone() - patch_starts_tensor + point_labels = point_labels.clone() + if indices.any(): + point_labels[~indices] = -1 + point_coords[~indices] = 0 + # also remove padded points, mainly used for inference. + not_pad_indices = (point_labels != -1).any(0) + point_coords = point_coords[:, not_pad_indices] + point_labels = point_labels[:, not_pad_indices] + return point_coords, point_labels + return None, None + + def connected_components_combine( + self, + logits: torch.Tensor, + point_logits: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mapping_index: torch.Tensor, + thred: float = 0.5, + ): + """ + Combine auto results with point click response. The auto results have shape [B, 1, H, W, D] which means B foreground masks + from a single image patch. 
Out of those B foreground masks, the user may add points to a subset of B1 foreground masks for editing. + mapping_index represents the correspondence between B and B1. + For mapping_index with point clicks, NaN values in logits will be replaced with point_logits. Meanwhile, the added/removed + region in point clicks must be updated by the lcc function. + Note that if a positive point is within logits/prev_mask, the components containing the positive point will be added. + + Args: + logits: automatic branch results, [B, 1, H, W, D]. + point_logits: point branch results, [B1, 1, H, W, D]. + point_coords: point coordinates, [B1, N, 3]. + point_labels: point labels, [B1, N]. + mapping_index: [B]. + thred: the threshold to convert logits to binary. + """ + logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits + _logits = logits[mapping_index] + inside = [] + for i in range(_logits.shape[0]): + inside.append( + np.any( + [ + _logits[i, 0, p[0], p[1], p[2]].item() > 0 + for p in point_coords[i].cpu().numpy().round().astype(int) + ] + ) + ) + inside_tensor = torch.tensor(inside).to(logits.device) + nan_mask = torch.isnan(_logits) + # _logits are converted to binary [B1, 1, H, W, D] + _logits = torch.nan_to_num(_logits, nan=self.NINF_VALUE).sigmoid() + pos_region = point_logits.sigmoid() > thred + diff_pos = torch.logical_and(torch.logical_or(_logits <= thred, unsqueeze_right(inside_tensor, 5)), pos_region) + diff_neg = torch.logical_and((_logits > thred), ~pos_region) + cc = lcc(diff_pos, diff_neg, point_coords=point_coords, point_labels=point_labels) + # cc is the region that can be updated by point_logits. + cc = cc.to(logits.device) + # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask, + # only remove unconnected positive region. + uc_pos_region = torch.logical_and(pos_region, ~cc) + fill_mask = torch.logical_and(nan_mask, uc_pos_region) + if fill_mask.any(): + # fill in the mean negative value + point_logits[fill_mask] = -1 + # replace logits nan value and cc with point_logits + cc = torch.logical_or(nan_mask, cc).to(logits.dtype) + logits[mapping_index] *= 1 - cc + logits[mapping_index] += cc * point_logits + return logits + + def gaussian_combine( + self, + logits: torch.Tensor, + point_logits: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mapping_index: torch.Tensor, + radius: int | None = None, + ): + """ + Combine point results with auto results using a Gaussian weighting. + + Args: + logits: automatic branch results, [B, 1, H, W, D]. + point_logits: point branch results, [B1, 1, H, W, D]. + point_coords: point coordinates, [B1, N, 3]. + point_labels: point labels, [B1, N]. + mapping_index: [B]. + radius: Gaussian ball radius. + """ + if radius is None: + radius = min(point_logits.shape[-3:]) // 5 # empirical value 5 + weight = 1 - convert_points_to_disc(point_logits.shape[-3:], point_coords, point_labels, radius=radius).sum( + 1, keepdims=True + ) + weight[weight < 0] = 0 + logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits + logits[mapping_index] *= weight + logits[mapping_index] += (1 - weight) * point_logits + return logits + + def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False): + """ + Freeze the auto-branch or the point-branch. + + Args: + auto_freeze: whether to freeze the auto branch. + point_freeze: whether to freeze the point branch.
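+ + Example (a minimal sketch; ``model`` is assumed to be a VISTA3D instance): + + >>> # freeze the shared image encoder and the automatic (class) branch, + >>> # leaving only the point branch trainable for interactive finetuning + >>> model.set_auto_grad(auto_freeze=True, point_freeze=False)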
+ """ + if auto_freeze != self.auto_freeze: + if hasattr(self.image_encoder, "set_auto_grad"): + self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) + else: + for param in self.image_encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + for param in self.class_head.parameters(): + param.requires_grad = not auto_freeze + self.auto_freeze = auto_freeze + + if point_freeze != self.point_freeze: + if hasattr(self.image_encoder, "set_auto_grad"): + self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) + else: + for param in self.image_encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + for param in self.point_head.parameters(): + param.requires_grad = not point_freeze + self.point_freeze = point_freeze + + def forward( + self, + input_images: torch.Tensor, + point_coords: torch.Tensor | None = None, + point_labels: torch.Tensor | None = None, + class_vector: torch.Tensor | None = None, + prompt_class: torch.Tensor | None = None, + patch_coords: Sequence[slice] | None = None, + labels: torch.Tensor | None = None, + label_set: Sequence[int] | None = None, + prev_mask: torch.Tensor | None = None, + radius: int | None = None, + val_point_sampler: Callable | None = None, + **kwargs, + ): + """ + The forward function for VISTA3D. We only support single patch in training and inference. + One exception is allowing sliding window batch size > 1 for automatic segmentation only case. + B represents number of objects, N represents number of points for each objects. + + Args: + input_images: [1, 1, H, W, D] + point_coords: [B, N, 3] + point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class. + 2/3 means negative/postive ponits for special supported class like tumor. + class_vector: [B, 1], the global class index + prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if + the points are for zero-shot or supported class. When class_vector and point_coords are both + provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b] + will be considered novel class. + patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference. + This value is passed from sliding_window_inferer. This is an indicator for training phase or validation phase. + labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation + label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID, + this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot + evaluation, this label_set should be the original index. + prev_mask: [B, N, H_fullsize, W_fullsize, D_fullsize]. + This is the transposed raw output from sliding_window_inferer before any postprocessing. + When user click points to perform auto-results correction, this can be the auto-results. + radius: single float value controling the gaussian blur when combining point and auto results. + The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes. + val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation. 
+ + """ + image_size = input_images.shape[-3:] + device = input_images.device + if point_coords is None and class_vector is None: + return self.NINF_VALUE + torch.zeros([1, 1, *image_size], device=device) + + bs = self.get_foreground_class_count(class_vector, point_coords) + if patch_coords is not None: + # if during validation and perform enable based point-validation. + if labels is not None and label_set is not None: + # if labels is not None, sample from labels for each patch. + if val_point_sampler is None: + # TODO: think about how to refactor this part. + val_point_sampler = self.sample_points_patch_val + point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords, label_set) + if prompt_class[0].item() == 0: # type: ignore + point_labels[0] = -1 # type: ignore + labels, prev_mask = None, None + elif point_coords is not None: + # If not performing patch-based point only validation, use user provided click points for inference. + # the point clicks is in original image space, convert it to current patch-coordinate space. + point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels) # type: ignore + + if point_coords is not None and point_labels is not None: + # remove points that used for padding purposes (point_label = -1) + mapping_index = ((point_labels != -1).sum(1) > 0).to(torch.bool) + if mapping_index.any(): + point_coords = point_coords[mapping_index] + point_labels = point_labels[mapping_index] + if prompt_class is not None: + prompt_class = prompt_class[mapping_index] + else: + if self.auto_freeze or (class_vector is None and patch_coords is None): + # if auto_freeze, point prompt must exist to allow loss backward + # in training, class_vector and point cannot both be None due to loss.backward() + mapping_index.fill_(True) + else: + point_coords, point_labels = None, None + + if point_coords is None and class_vector is None: + return self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device) + + if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None: + out, out_auto = self.image_embeddings, None + else: + out, out_auto = self.image_encoder( + input_images, with_point=point_coords is not None, with_label=class_vector is not None + ) + # release memory + input_images = None # type: ignore + + # force releasing memories that set to None + torch.cuda.empty_cache() + if class_vector is not None: + logits, _ = self.class_head(out_auto, class_vector) + if point_coords is not None: + point_logits = self.point_head(out, point_coords, point_labels, class_vector=prompt_class) + if patch_coords is None: + logits = self.gaussian_combine( + logits, point_logits, point_coords, point_labels, mapping_index, radius # type: ignore + ) + else: + # during validation use largest component + logits = self.connected_components_combine( + logits, point_logits, point_coords, point_labels, mapping_index # type: ignore + ) + else: + logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device, dtype=out.dtype) + logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class) + if prev_mask is not None and patch_coords is not None: + logits = self.connected_components_combine( + prev_mask[patch_coords].transpose(1, 0).to(logits.device), + logits[mapping_index], + point_coords, # type: ignore + point_labels, # type: ignore + mapping_index, + ) + + if kwargs.get("keep_cache", False) and class_vector is None: + self.image_embeddings = out.detach() 
+ return logits + + + class PointMappingSAM(nn.Module): + def __init__(self, feature_size: int, max_prompt: int = 32, n_classes: int = 512, last_supported: int = 132): + """Interactive point head used for VISTA3D. + Adapted from segment anything: + `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`. + + Args: + feature_size: feature channel from encoder. + max_prompt: max prompt number in each forward iteration. + n_classes: number of classes the model can potentially support. This is the maximum number of class embeddings. + last_supported: number of classes the model supports; this value should match the trained model weights. + """ + super().__init__() + transformer_dim = feature_size + self.max_prompt = max_prompt + self.feat_downsample = nn.Sequential( + nn.Conv3d(in_channels=feature_size, out_channels=feature_size, kernel_size=3, stride=2, padding=1), + nn.InstanceNorm3d(feature_size), + nn.GELU(), + nn.Conv3d(in_channels=feature_size, out_channels=transformer_dim, kernel_size=3, stride=1, padding=1), + nn.InstanceNorm3d(feature_size), + ) + + self.mask_downsample = nn.Conv3d(in_channels=2, out_channels=2, kernel_size=3, stride=2, padding=1) + + self.transformer = TwoWayTransformer(depth=2, embedding_dim=transformer_dim, mlp_dim=512, num_heads=4) + self.pe_layer = PositionEmbeddingRandom(transformer_dim // 2) + self.point_embeddings = nn.ModuleList([nn.Embedding(1, transformer_dim), nn.Embedding(1, transformer_dim)]) + self.not_a_point_embed = nn.Embedding(1, transformer_dim) + self.special_class_embed = nn.Embedding(1, transformer_dim) + self.mask_tokens = nn.Embedding(1, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose3d(transformer_dim, transformer_dim, kernel_size=3, stride=2, padding=1, output_padding=1), + nn.InstanceNorm3d(transformer_dim), + nn.GELU(), + nn.Conv3d(transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1), + ) + + self.output_hypernetworks_mlps = MLP(transformer_dim, transformer_dim, transformer_dim, 3) + # class embedding + self.n_classes = n_classes + self.last_supported = last_supported + self.class_embeddings = nn.Embedding(n_classes, feature_size) + self.zeroshot_embed = nn.Embedding(1, transformer_dim) + self.supported_embed = nn.Embedding(1, transformer_dim) + + def forward( + self, + out: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + class_vector: torch.Tensor | None = None, + ): + """Args: + out: feature from encoder, [1, C, H, W, D] + point_coords: point coordinates, [B, N, 3] + point_labels: point labels, [B, N] + class_vector: class prompts, [B] + """ + # downsample out + out_low = self.feat_downsample(out) + out_shape = tuple(out.shape[-3:]) + # release memory + out = None # type: ignore + torch.cuda.empty_cache() + # embed points + points = point_coords + 0.5 # Shift to center of pixel + point_embedding = self.pe_layer.forward_with_coords(points, out_shape) # type: ignore + point_embedding[point_labels == -1] = 0.0 + point_embedding[point_labels == -1] += self.not_a_point_embed.weight + point_embedding[point_labels == 0] += self.point_embeddings[0].weight + point_embedding[point_labels == 1] += self.point_embeddings[1].weight + point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight + point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight + output_tokens = self.mask_tokens.weight + + output_tokens = 
output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1) + if class_vector is None: + tokens_all = torch.cat( + ( + output_tokens, + point_embedding, + self.supported_embed.weight.unsqueeze(0).expand(point_embedding.size(0), -1, -1), + ), + dim=1, + ) + # tokens_all = torch.cat((output_tokens, point_embedding), dim=1) + else: + class_embeddings = [] + for i in class_vector: + if i > self.last_supported: + class_embeddings.append(self.zeroshot_embed.weight) + else: + class_embeddings.append(self.supported_embed.weight) + tokens_all = torch.cat((output_tokens, point_embedding, torch.stack(class_embeddings)), dim=1) + # cross attention + masks = [] + max_prompt = self.max_prompt + for i in range(int(np.ceil(tokens_all.shape[0] / max_prompt))): + # remove variables from previous loop iterations to reduce peak memory for self.transformer + src, upscaled_embedding, hyper_in = None, None, None + torch.cuda.empty_cache() + idx = (i * max_prompt, min((i + 1) * max_prompt, tokens_all.shape[0])) + tokens = tokens_all[idx[0] : idx[1]] + src = torch.repeat_interleave(out_low, tokens.shape[0], dim=0) + pos_src = torch.repeat_interleave(self.pe_layer(out_low.shape[-3:]).unsqueeze(0), tokens.shape[0], dim=0) + b, c, h, w, d = src.shape + hs, src = self.transformer(src, pos_src, tokens) + mask_tokens_out = hs[:, :1, :] + hyper_in = self.output_hypernetworks_mlps(mask_tokens_out) + src = src.transpose(1, 2).view(b, c, h, w, d) # type: ignore + upscaled_embedding = self.output_upscaling(src) + b, c, h, w, d = upscaled_embedding.shape + mask = hyper_in @ upscaled_embedding.view(b, c, h * w * d) + masks.append(mask.view(-1, 1, h, w, d)) + + return torch.vstack(masks) + + + class ClassMappingClassify(nn.Module): + """Class head that performs automatic segmentation based on class vector.""" + + def __init__(self, n_classes: int, feature_size: int, use_mlp: bool = True): + """Args: + n_classes: maximum number of class embeddings. + feature_size: class embedding size. + use_mlp: use an MLP to further map the class embedding. + """ + super().__init__() + self.use_mlp = use_mlp + if use_mlp: + self.mlp = nn.Sequential( + nn.Linear(feature_size, feature_size), + nn.InstanceNorm1d(1), + nn.GELU(), + nn.Linear(feature_size, feature_size), + ) + self.class_embeddings = nn.Embedding(n_classes, feature_size) + self.image_post_mapping = nn.Sequential( + UnetrBasicBlock( + spatial_dims=3, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name="instance", + res_block=True, + ), + UnetrBasicBlock( + spatial_dims=3, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name="instance", + res_block=True, + ), + ) + + def forward(self, src: torch.Tensor, class_vector: torch.Tensor): + b, c, h, w, d = src.shape + src = self.image_post_mapping(src) + class_embedding = self.class_embeddings(class_vector) + if self.use_mlp: + class_embedding = self.mlp(class_embedding) + # [b, 1, feat] @ [1, feat, dim]; the batch dimension becomes the class_embedding batch dimension.
+ masks = [] + for i in range(b): + mask = class_embedding @ src[[i]].view(1, c, h * w * d) + masks.append(mask.view(-1, 1, h, w, d)) + + return torch.cat(masks, 1), class_embedding + + + class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: tuple | str = "relu", + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`. + + Args: + depth: number of layers in the transformer. + embedding_dim: the channel dimension for the input embeddings. + num_heads: the number of heads for multihead attention. Must divide embedding_dim. + mlp_dim: the channel dimension internal to the MLP block. + activation: the activation to use in the MLP block. + attention_downsample_rate: the rate at which to downsample the image before projecting. + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, image_embedding: torch.Tensor, image_pe: torch.Tensor, point_embedding: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + image_embedding: image to attend to. Should be shape + B x embedding_dim x h x w x d for any h, w and d. + image_pe: the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding: the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding. + torch.Tensor: the processed image_embedding. + """ + # B x C x H x W x D -> B x (HWD) x C == B x N_image_tokens x C + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer(queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + + class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: tuple | str = "relu", + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs.
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`. + + Args: + embedding_dim: the channel dimension of the embeddings. + num_heads: the number of heads in the attention layers. + mlp_dim: the hidden dimension of the mlp block. + activation: the activation of the mlp block. + skip_first_layer_pe: skip the PE on the first layer. + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(hidden_size=embedding_dim, mlp_dim=mlp_dim, act=activation, dropout_mode="vista3d") + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`. + + Args: + embedding_dim: the channel dimension of the embeddings. + num_heads: the number of heads in the attention layers. + downsample_rate: the rate at which to downsample the image before projecting. 
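+ + Example (an illustrative sketch, not from the original code): + + >>> attn = Attention(embedding_dim=256, num_heads=4, downsample_rate=2) # internal dim is 128 + >>> q = k = v = torch.randn(2, 16, 256) + >>> attn(q, k, v).shape + torch.Size([2, 16, 256])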
+ """ + + def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + if not self.internal_dim % num_heads == 0: + raise ValueError("num_heads must divide embedding_dim.") + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + # B x N_heads x N_tokens x C_per_head + return x.transpose(1, 2) + + def _recombine_heads(self, x: torch.Tensor) -> torch.Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + # B x N_tokens x C + return x.reshape(b, n_tokens, n_heads * c_per_head) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py`. + + Args: + num_pos_feats: the number of positional encoding features. + scale: the scale of the positional encoding. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((3, num_pos_feats))) + + def _pe_encoding(self, coords: torch.torch.Tensor) -> torch.torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + # [bs=1,N=2,2] @ [2,128] + # [bs=1, N=2, 128] + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... 
x d_n x C shape + # [bs, N, 128+128=256] + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w, d = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w, d), device=device, dtype=torch.float32) + x_embed = grid.cumsum(dim=0) - 0.5 + y_embed = grid.cumsum(dim=1) - 0.5 + z_embed = grid.cumsum(dim=2) - 0.5 + x_embed = x_embed / h + y_embed = y_embed / w + z_embed = z_embed / d + pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1)) + # C x H x W x D + return pe.permute(3, 0, 1, 2) + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[0] + coords[:, :, 1] = coords[:, :, 1] / image_size[1] + coords[:, :, 2] = coords[:, :, 2] / image_size[2] + # B x N x C + return self._pe_encoding(coords.to(torch.float)) + + + class MLP(nn.Module): + """ + Multi-layer perceptron. This class is only used for `PointMappingSAM`. + Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`. + + Args: + input_dim: the input dimension. + hidden_dim: the hidden dimension. + output_dim: the output dimension. + num_layers: the number of layers. + sigmoid_output: whether to apply a sigmoid activation to the output. + """ + + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.sigmoid_output = sigmoid_output + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = torch.sigmoid(x) + return x diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e32bf6fc48..363fce91be 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1274,10 +1274,10 @@ def convert_points_to_disc( coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6) coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1) for b, n in np.ndindex(*point.shape[:2]): - point_bn = unsqueeze_right(point[b, n], 6) + point_bn = unsqueeze_right(point[b, n], 4) if point_label[b, n] > -1: channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 - pow_diff = torch.pow(coords[b, channel] - point_bn[b, n], 2) + pow_diff = torch.pow(coords[b, channel] - point_bn, 2) if disc: masks[b, channel] += pow_diff.sum(0) < radius**2 else: diff --git a/tests/test_segresnet_ds.py b/tests/test_segresnet_ds.py index 5372fcc8ae..eab7bac9a0 100644 --- a/tests/test_segresnet_ds.py +++ b/tests/test_segresnet_ds.py @@ -17,7 +17,7 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.nets import SegResNetDS +from monai.networks.nets import SegResNetDS, SegResNetDS2 from tests.utils import SkipIfBeforePyTorchVersion, test_script_save device = "cuda" if torch.cuda.is_available() else "cpu" @@ -71,7 +71,7 @@ ] -class TestResNetDS(unittest.TestCase): +class 
TestSegResNetDS(unittest.TestCase): @parameterized.expand(TEST_CASE_SEGRESNET_DS) def test_shape(self, input_param, input_shape, expected_shape): @@ -80,47 +80,71 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape).to(device)) self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + @parameterized.expand(TEST_CASE_SEGRESNET_DS) + def test_shape_ds2(self, input_param, input_shape, expected_shape): + net = SegResNetDS2(**input_param).to(device) + with eval_mode(net): + result = net(torch.randn(input_shape).to(device), with_label=False) + self.assertEqual(result[0].shape, expected_shape, msg=str(input_param)) + self.assertTrue(result[1] == []) + + result = net(torch.randn(input_shape).to(device), with_point=False) + self.assertEqual(result[1].shape, expected_shape, msg=str(input_param)) + self.assertTrue(result[0] == []) + @parameterized.expand(TEST_CASE_SEGRESNET_DS2) def test_shape2(self, input_param, input_shape, expected_shape): dsdepth = input_param.get("dsdepth", 1) - net = SegResNetDS(**input_param).to(device) - - net.train() - result = net(torch.randn(input_shape).to(device)) - if dsdepth > 1: - assert isinstance(result, list) - self.assertEqual(dsdepth, len(result)) - for i in range(dsdepth): - self.assertEqual( - result[i].shape, - expected_shape[:2] + tuple(e // (2**i) for e in expected_shape[2:]), - msg=str(input_param), - ) - else: - assert isinstance(result, torch.Tensor) - self.assertEqual(result.shape, expected_shape, msg=str(input_param)) - - net.eval() - result = net(torch.randn(input_shape).to(device)) - assert isinstance(result, torch.Tensor) - self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + for net in [SegResNetDS, SegResNetDS2]: + net = net(**input_param).to(device) + net.train() + if isinstance(net, SegResNetDS2): + result = net(torch.randn(input_shape).to(device), with_label=False)[0] + else: + result = net(torch.randn(input_shape).to(device)) + if dsdepth > 1: + assert isinstance(result, list) + self.assertEqual(dsdepth, len(result)) + for i in range(dsdepth): + self.assertEqual( + result[i].shape, + expected_shape[:2] + tuple(e // (2**i) for e in expected_shape[2:]), + msg=str(input_param), + ) + else: + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) + + if not isinstance(net, SegResNetDS2): + # eval mode of SegResNetDS2 has same output as training mode + # so only test eval mode for SegResNetDS + net.eval() + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, torch.Tensor) + self.assertEqual(result.shape, expected_shape, msg=str(input_param)) @parameterized.expand(TEST_CASE_SEGRESNET_DS3) def test_shape3(self, input_param, input_shape, expected_shapes): dsdepth = input_param.get("dsdepth", 1) - net = SegResNetDS(**input_param).to(device) - - net.train() - result = net(torch.randn(input_shape).to(device)) - assert isinstance(result, list) - self.assertEqual(dsdepth, len(result)) - for i in range(dsdepth): - self.assertEqual(result[i].shape, expected_shapes[i], msg=str(input_param)) + for net in [SegResNetDS, SegResNetDS2]: + net = net(**input_param).to(device) + net.train() + if isinstance(net, SegResNetDS2): + result = net(torch.randn(input_shape).to(device), with_point=False)[1] + else: + result = net(torch.randn(input_shape).to(device)) + assert isinstance(result, list) + self.assertEqual(dsdepth, len(result)) + for i in range(dsdepth): + self.assertEqual(result[i].shape, 
expected_shapes[i], msg=str(input_param)) def test_ill_arg(self): with self.assertRaises(ValueError): SegResNetDS(spatial_dims=4) + with self.assertRaises(ValueError): + SegResNetDS2(spatial_dims=4) + @SkipIfBeforePyTorchVersion((1, 10)) def test_script(self): input_param, input_shape, _ = TEST_CASE_SEGRESNET_DS[0] diff --git a/tests/test_vista3d.py b/tests/test_vista3d.py new file mode 100644 index 0000000000..d3b4e0c10e --- /dev/null +++ b/tests/test_vista3d.py @@ -0,0 +1,85 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets import VISTA3D, SegResNetDS2 +from monai.networks.nets.vista3d import ClassMappingClassify, PointMappingSAM +from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASES = [ + [{"encoder_embed_dim": 48, "in_channels": 1}, {}, (1, 1, 64, 64, 64), (1, 1, 64, 64, 64)], + [{"encoder_embed_dim": 48, "in_channels": 2}, {}, (1, 2, 64, 64, 64), (1, 1, 64, 64, 64)], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + {"class_vector": torch.tensor([1, 2, 3], device=device)}, + (1, 1, 64, 64, 64), + (3, 1, 64, 64, 64), + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + { + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0]], device=device), + }, + (1, 1, 64, 64, 64), + (1, 1, 64, 64, 64), + ], + [ + {"encoder_embed_dim": 48, "in_channels": 1}, + { + "class_vector": torch.tensor([1, 2], device=device), + "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]], device=device), + "point_labels": torch.tensor([[1, 0], [1, 0]], device=device), + }, + (1, 1, 64, 64, 64), + (2, 1, 64, 64, 64), + ], +] + + +@SkipIfBeforePyTorchVersion((1, 11)) +@skip_if_quick +class TestVista3d(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_vista3d_shape(self, args, input_params, input_shape, expected_shape): + segresnet = SegResNetDS2( + in_channels=args["in_channels"], + blocks_down=(1, 2, 2, 4, 4), + norm="instance", + out_channels=args["encoder_embed_dim"], + init_filters=args["encoder_embed_dim"], + dsdepth=1, + ) + point_head = PointMappingSAM(feature_size=args["encoder_embed_dim"], n_classes=512, last_supported=132) + class_head = ClassMappingClassify(n_classes=512, feature_size=args["encoder_embed_dim"], use_mlp=True) + net = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head).to(device) + with eval_mode(net): + result = net.forward( + torch.randn(input_shape).to(device), + point_coords=input_params.get("point_coords", None), + point_labels=input_params.get("point_labels", None), + class_vector=input_params.get("class_vector", None), + ) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main()