2 changes: 1 addition & 1 deletion monai/networks/blocks/mlp.py
@@ -32,7 +32,7 @@ def __init__(
Args:
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
dropout_rate: faction of the input units to drop.
dropout_rate: fraction of the input units to drop.
act: activation type and arguments. Defaults to GELU. Also supports "GEGLU" and others.
dropout_mode: dropout mode, can be "vit" or "swin".
"vit" mode uses two dropout instances as implemented in
51 changes: 37 additions & 14 deletions monai/networks/blocks/patchembedding.py
@@ -19,12 +19,14 @@
import torch.nn.functional as F
from torch.nn import LayerNorm

from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
from monai.networks.layers import Conv, trunc_normal_
from monai.utils import ensure_tuple_rep, optional_import
from monai.utils import deprecated_arg, ensure_tuple_rep, optional_import
from monai.utils.module import look_up_option

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
SUPPORTED_EMBEDDING_TYPES = {"conv", "perceptron"}
SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"}
SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}


class PatchEmbeddingBlock(nn.Module):
@@ -35,18 +37,22 @@ class PatchEmbeddingBlock(nn.Module):
Example::

>>> from monai.networks.blocks import PatchEmbeddingBlock
>>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4, pos_embed="conv")
>>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4,
>>> proj_type="conv", pos_embed_type="sincos")

"""

@deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.")
def __init__(
self,
in_channels: int,
img_size: Sequence[int] | int,
patch_size: Sequence[int] | int,
hidden_size: int,
num_heads: int,
pos_embed: str,
pos_embed: str = "conv",
proj_type: str = "conv",
pos_embed_type: str = "learnable",
dropout_rate: float = 0.0,
spatial_dims: int = 3,
) -> None:
@@ -57,11 +63,12 @@ def __init__(
patch_size: dimension of patch size.
hidden_size: dimension of hidden layer.
num_heads: number of attention heads.
pos_embed: position embedding layer type.
dropout_rate: faction of the input units to drop.
proj_type: patch embedding layer type.
pos_embed_type: position embedding layer type.
dropout_rate: fraction of the input units to drop.
spatial_dims: number of spatial dimensions.


.. deprecated:: 1.4
``pos_embed`` is deprecated in favor of ``proj_type``.
"""

super().__init__()
@@ -72,24 +79,25 @@ def __init__(
if hidden_size % num_heads != 0:
raise ValueError(f"hidden size {hidden_size} should be divisible by num_heads {num_heads}.")

self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES)
self.proj_type = look_up_option(proj_type, SUPPORTED_PATCH_EMBEDDING_TYPES)
self.pos_embed_type = look_up_option(pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)

img_size = ensure_tuple_rep(img_size, spatial_dims)
patch_size = ensure_tuple_rep(patch_size, spatial_dims)
for m, p in zip(img_size, patch_size):
if m < p:
raise ValueError("patch_size should be smaller than img_size.")
if self.pos_embed == "perceptron" and m % p != 0:
if self.proj_type == "perceptron" and m % p != 0:
raise ValueError("patch_size should be divisible by img_size for perceptron.")
self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
self.patch_dim = int(in_channels * np.prod(patch_size))

self.patch_embeddings: nn.Module
if self.pos_embed == "conv":
if self.proj_type == "conv":
self.patch_embeddings = Conv[Conv.CONV, spatial_dims](
in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
)
elif self.pos_embed == "perceptron":
elif self.proj_type == "perceptron":
# for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
@@ -100,7 +108,22 @@ def __init__(
)
self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
self.dropout = nn.Dropout(dropout_rate)
trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)

if self.pos_embed_type == "none":
pass
elif self.pos_embed_type == "learnable":
trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
elif self.pos_embed_type == "sincos":
grid_size = []
for in_size, pa_size in zip(img_size, patch_size):
grid_size.append(in_size // pa_size)

with torch.no_grad():
pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
self.position_embeddings.data.copy_(pos_embeddings.float())
else:
raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")

self.apply(self._init_weights)

def _init_weights(self, m):
@@ -114,7 +137,7 @@ def _init_weights(self, m):

def forward(self, x):
x = self.patch_embeddings(x)
if self.pos_embed == "conv":
if self.proj_type == "conv":
x = x.flatten(2).transpose(-1, -2)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
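A minimal smoke test of the new arguments, assuming this branch is installed; the sizes here are illustrative, chosen so that `hidden_size` is divisible by `num_heads` and by 6 (the 3D sincos requirement):

```python
import torch

from monai.networks.blocks import PatchEmbeddingBlock

# 96 is divisible by num_heads=4 and by 6, as the 3D sincos path requires
block = PatchEmbeddingBlock(
    in_channels=4,
    img_size=32,
    patch_size=8,
    hidden_size=96,
    num_heads=4,
    proj_type="conv",
    pos_embed_type="sincos",
    spatial_dims=3,
)
x = torch.randn(1, 4, 32, 32, 32)
print(block(x).shape)  # torch.Size([1, 64, 96]) -- (32 // 8) ** 3 = 64 patches
```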
103 changes: 103 additions & 0 deletions monai/networks/blocks/pos_embed_utils.py
@@ -0,0 +1,103 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import collections.abc
from itertools import repeat
from typing import List, Union

import torch
import torch.nn as nn

__all__ = ["build_sincos_position_embedding"]


# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))

return parse


def build_sincos_position_embedding(
grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0
) -> torch.nn.Parameter:
"""
Builds a sin-cos position embedding based on the given grid size, embed dimension, spatial dimensions, and temperature.
Reference: https://github.com/cvlab-stonybrook/SelfMedMAE/blob/68d191dfcc1c7d0145db93a6a570362de29e3b30/lib/models/mae3d.py

Args:
grid_size (List[int]): The size of the grid in each spatial dimension.
embed_dim (int): The dimension of the embedding.
spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
temperature (float): The temperature for the sin-cos position embedding.

Returns:
pos_embed (nn.Parameter): The sin-cos position embedding as a fixed, non-trainable parameter.
"""

if spatial_dims == 2:
to_2tuple = _ntuple(2)
grid_size_t = to_2tuple(grid_size)
h, w = grid_size_t
grid_h = torch.arange(h, dtype=torch.float32)
grid_w = torch.arange(w, dtype=torch.float32)

grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing="ij")

assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"

pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = 1.0 / (temperature**omega)
out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
elif spatial_dims == 3:
to_3tuple = _ntuple(3)
grid_size_t = to_3tuple(grid_size)
h, w, d = grid_size_t
grid_h = torch.arange(h, dtype=torch.float32)
grid_w = torch.arange(w, dtype=torch.float32)
grid_d = torch.arange(d, dtype=torch.float32)

grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d, indexing="ij")

assert embed_dim % 6 == 0, "Embed dimension must be divisible by 6 for 3D sin-cos position embedding"

pos_dim = embed_dim // 6
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = 1.0 / (temperature**omega)
out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
out_d = torch.einsum("m,d->md", [grid_d.flatten(), omega])
pos_emb = torch.cat(
[
torch.sin(out_w),
torch.cos(out_w),
torch.sin(out_h),
torch.cos(out_h),
torch.sin(out_d),
torch.cos(out_d),
],
dim=1,
)[None, :, :]
else:
raise NotImplementedError(f"Spatial dimension size {spatial_dims} not implemented!")

pos_embed = nn.Parameter(pos_emb)
pos_embed.requires_grad = False

return pos_embed
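A quick sanity check of the helper's output shape and trainability, under the same divisibility assumption the assert above enforces:

```python
from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding

# a 4x4x4 grid of patches with a 96-dim embedding (96 % 6 == 0)
pos = build_sincos_position_embedding(grid_size=[4, 4, 4], embed_dim=96, spatial_dims=3)
print(pos.shape)          # torch.Size([1, 64, 96]) -- one row per grid position
print(pos.requires_grad)  # False: the embedding is returned frozen
```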
2 changes: 1 addition & 1 deletion monai/networks/blocks/selfattention.py
@@ -37,7 +37,7 @@ def __init__(
Args:
hidden_size (int): dimension of hidden layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): whether to make the attention matrix accessible. Defaults to False.

2 changes: 1 addition & 1 deletion monai/networks/blocks/transformerblock.py
@@ -37,7 +37,7 @@ def __init__(
hidden_size (int): dimension of hidden layer.
mlp_dim (int): dimension of feedforward layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): whether to make the attention matrix accessible. Defaults to False.

2 changes: 1 addition & 1 deletion monai/networks/nets/transchex.py
@@ -314,7 +314,7 @@ def __init__(
num_language_layers: number of language transformer layers.
num_vision_layers: number of vision transformer layers.
num_mixed_layers: number of mixed transformer layers.
drop_out: faction of the input units to drop.
drop_out: fraction of the input units to drop.

The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`.

13 changes: 9 additions & 4 deletions monai/networks/nets/unetr.py
@@ -18,7 +18,7 @@
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
from monai.networks.nets.vit import ViT
from monai.utils import ensure_tuple_rep
from monai.utils import deprecated_arg, ensure_tuple_rep


class UNETR(nn.Module):
@@ -27,6 +27,7 @@ class UNETR(nn.Module):
UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
"""

@deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.")
def __init__(
self,
in_channels: int,
@@ -37,6 +38,7 @@ def __init__(
mlp_dim: int = 3072,
num_heads: int = 12,
pos_embed: str = "conv",
proj_type: str = "conv",
norm_name: tuple | str = "instance",
conv_block: bool = True,
res_block: bool = True,
@@ -54,7 +56,7 @@ def __init__(
hidden_size: dimension of hidden layer. Defaults to 768.
mlp_dim: dimension of feedforward layer. Defaults to 3072.
num_heads: number of attention heads. Defaults to 12.
pos_embed: position embedding layer type. Defaults to "conv".
proj_type: patch embedding layer type. Defaults to "conv".
norm_name: feature normalization type and arguments. Defaults to "instance".
conv_block: if convolutional block is used. Defaults to True.
res_block: if residual block is used. Defaults to True.
@@ -63,6 +65,9 @@ def __init__(
qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False.
save_attn: whether to make the attention matrix in the self-attention block accessible. Defaults to False.

.. deprecated:: 1.4
``pos_embed`` is deprecated in favor of ``proj_type``.

Examples::

# for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm
@@ -72,7 +77,7 @@ def __init__(
>>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2)

# for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm
>>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')
>>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), proj_type='conv', norm_name='instance')

"""

@@ -98,7 +103,7 @@ def __init__(
mlp_dim=mlp_dim,
num_layers=self.num_layers,
num_heads=num_heads,
pos_embed=pos_embed,
proj_type=proj_type,
classification=self.classification,
dropout_rate=dropout_rate,
spatial_dims=spatial_dims,
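The decorator keeps the old keyword usable during the deprecation window. A sketch of the expected behavior, assuming MONAI's usual `deprecated_arg` semantics (emit a deprecation warning and forward the value to the new name):

```python
import warnings

from monai.networks.nets import UNETR

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # the old keyword is still accepted; its value is forwarded to `proj_type`
    net = UNETR(in_channels=1, out_channels=4, img_size=96, pos_embed="conv")

print(caught[0].message)  # mentions `pos_embed` and suggests `proj_type`
```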
22 changes: 16 additions & 6 deletions monai/networks/nets/vit.py
@@ -18,6 +18,7 @@

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.utils import deprecated_arg

__all__ = ["ViT"]

@@ -30,6 +31,7 @@ class ViT(nn.Module):
ViT supports TorchScript but only works with PyTorch 1.8 and later.
"""

@deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.")
def __init__(
self,
in_channels: int,
@@ -40,6 +42,8 @@ def __init__(
num_layers: int = 12,
num_heads: int = 12,
pos_embed: str = "conv",
proj_type: str = "conv",
pos_embed_type: str = "learnable",
classification: bool = False,
num_classes: int = 2,
dropout_rate: float = 0.0,
@@ -57,27 +61,32 @@ def __init__(
mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
num_layers (int, optional): number of transformer blocks. Defaults to 12.
num_heads (int, optional): number of attention heads. Defaults to 12.
pos_embed (str, optional): position embedding layer type. Defaults to "conv".
proj_type (str, optional): patch embedding layer type. Defaults to "conv".
pos_embed_type (str, optional): position embedding type. Defaults to "learnable".
classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
num_classes (int, optional): number of classes if classification is used. Defaults to 2.
dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
post_activation (str, optional): add a final activation function to the classification head
when `classification` is True. Defaults to "Tanh" for `nn.Tanh()`.
Set to other values to remove this function.
qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
save_attn (bool, optional): whether to make the attention matrix in the self-attention block accessible. Defaults to False.

.. deprecated:: 1.4
``pos_embed`` is deprecated in favor of ``proj_type``.

Examples::

# for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
>>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv')
>>> net = ViT(in_channels=1, img_size=(96,96,96), proj_type='conv', pos_embed_type='sincos')

# for 3-channel with image size of (128,128,128), 24 layers and classification backbone
>>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True)
>>> net = ViT(in_channels=3, img_size=(128,128,128), proj_type='conv', pos_embed_type='sincos', classification=True)

# for 3-channel with image size of (224,224), 12 layers and classification backbone
>>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2)
>>> net = ViT(in_channels=3, img_size=(224,224), proj_type='conv', pos_embed_type='sincos', classification=True,
>>> spatial_dims=2)

"""

@@ -96,7 +105,8 @@ def __init__(
patch_size=patch_size,
hidden_size=hidden_size,
num_heads=num_heads,
pos_embed=pos_embed,
proj_type=proj_type,
pos_embed_type=pos_embed_type,
dropout_rate=dropout_rate,
spatial_dims=spatial_dims,
)
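End to end, the new argument reaches the patch embedding block. A minimal check, assuming the diff above is applied (the default `hidden_size=768` is divisible by 6, so the 3D sincos path is valid):

```python
import torch

from monai.networks.nets import ViT

net = ViT(
    in_channels=1,
    img_size=(96, 96, 96),
    patch_size=(16, 16, 16),
    proj_type="conv",
    pos_embed_type="sincos",
)
x, hidden_states = net(torch.randn(1, 1, 96, 96, 96))
print(x.shape)  # torch.Size([1, 216, 768]) -- (96 // 16) ** 3 = 216 patches
```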