
Commit 281cb01

6973 sincos pos embed (#6986)
Fixes #6973

### Description

Adds support for sincos positional embedding to the monai.networks.blocks.patchembedding.PatchEmbedding class. This pull request corresponds to the open issue #6973.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested with the `make html` command in the `docs/` folder.

Signed-off-by: NoTody <howardwong1780@gmail.com>
1 parent 56ca224 commit 281cb01
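
For context, a minimal usage sketch of the feature this commit adds. Argument values are illustrative (not taken from the MONAI tests), and the expected output shape is inferred from the patch grid:

```python
import torch

from monai.networks.blocks import PatchEmbeddingBlock

# A 32^3 volume split into 8^3 patches -> 4*4*4 = 64 tokens of width hidden_size.
# hidden_size must be divisible by num_heads and, for the 3D sincos embedding, by 6.
block = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(32, 32, 32),
    patch_size=(8, 8, 8),
    hidden_size=96,
    num_heads=4,
    proj_type="conv",         # patch projection layer ("conv" or "perceptron")
    pos_embed_type="sincos",  # new option: "none", "learnable", or "sincos"
)

x = torch.randn(2, 1, 32, 32, 32)
print(block(x).shape)  # expected: torch.Size([2, 64, 96])
```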

File tree

10 files changed: +215 additions, -53 deletions


monai/networks/blocks/mlp.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def __init__(
         Args:
             hidden_size: dimension of hidden layer.
             mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
-            dropout_rate: faction of the input units to drop.
+            dropout_rate: fraction of the input units to drop.
             act: activation type and arguments. Defaults to GELU. Also supports "GEGLU" and others.
             dropout_mode: dropout mode, can be "vit" or "swin".
                 "vit" mode uses two dropout instances as implemented in

monai/networks/blocks/patchembedding.py

Lines changed: 37 additions & 14 deletions
@@ -19,12 +19,14 @@
 import torch.nn.functional as F
 from torch.nn import LayerNorm

+from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
 from monai.networks.layers import Conv, trunc_normal_
-from monai.utils import ensure_tuple_rep, optional_import
+from monai.utils import deprecated_arg, ensure_tuple_rep, optional_import
 from monai.utils.module import look_up_option

 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
-SUPPORTED_EMBEDDING_TYPES = {"conv", "perceptron"}
+SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"}
+SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}


 class PatchEmbeddingBlock(nn.Module):

@@ -35,18 +37,22 @@ class PatchEmbeddingBlock(nn.Module):
     Example::

         >>> from monai.networks.blocks import PatchEmbeddingBlock
-        >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4, pos_embed="conv")
+        >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4,
+        >>>                     proj_type="conv", pos_embed_type="sincos")

     """

+    @deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.")
     def __init__(
         self,
         in_channels: int,
         img_size: Sequence[int] | int,
         patch_size: Sequence[int] | int,
         hidden_size: int,
         num_heads: int,
-        pos_embed: str,
+        pos_embed: str = "conv",
+        proj_type: str = "conv",
+        pos_embed_type: str = "learnable",
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
     ) -> None:

@@ -57,11 +63,12 @@ def __init__(
             patch_size: dimension of patch size.
             hidden_size: dimension of hidden layer.
             num_heads: number of attention heads.
-            pos_embed: position embedding layer type.
-            dropout_rate: faction of the input units to drop.
+            proj_type: patch embedding layer type.
+            pos_embed_type: position embedding layer type.
+            dropout_rate: fraction of the input units to drop.
             spatial_dims: number of spatial dimensions.
-
-
+        .. deprecated:: 1.4
+            ``pos_embed`` is deprecated in favor of ``proj_type``.
         """

         super().__init__()

@@ -72,24 +79,25 @@ def __init__(
         if hidden_size % num_heads != 0:
             raise ValueError(f"hidden size {hidden_size} should be divisible by num_heads {num_heads}.")

-        self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES)
+        self.proj_type = look_up_option(proj_type, SUPPORTED_PATCH_EMBEDDING_TYPES)
+        self.pos_embed_type = look_up_option(pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)

         img_size = ensure_tuple_rep(img_size, spatial_dims)
         patch_size = ensure_tuple_rep(patch_size, spatial_dims)
         for m, p in zip(img_size, patch_size):
             if m < p:
                 raise ValueError("patch_size should be smaller than img_size.")
-            if self.pos_embed == "perceptron" and m % p != 0:
+            if self.proj_type == "perceptron" and m % p != 0:
                 raise ValueError("patch_size should be divisible by img_size for perceptron.")
         self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
         self.patch_dim = int(in_channels * np.prod(patch_size))

         self.patch_embeddings: nn.Module
-        if self.pos_embed == "conv":
+        if self.proj_type == "conv":
             self.patch_embeddings = Conv[Conv.CONV, spatial_dims](
                 in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
             )
-        elif self.pos_embed == "perceptron":
+        elif self.proj_type == "perceptron":
             # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
             chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
             from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)

@@ -100,7 +108,22 @@ def __init__(
             )
         self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
         self.dropout = nn.Dropout(dropout_rate)
-        trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
+
+        if self.pos_embed_type == "none":
+            pass
+        elif self.pos_embed_type == "learnable":
+            trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
+        elif self.pos_embed_type == "sincos":
+            grid_size = []
+            for in_size, pa_size in zip(img_size, patch_size):
+                grid_size.append(in_size // pa_size)
+
+            with torch.no_grad():
+                pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
+                self.position_embeddings.data.copy_(pos_embeddings.float())
+        else:
+            raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")
+
         self.apply(self._init_weights)

     def _init_weights(self, m):

@@ -114,7 +137,7 @@ def _init_weights(self, m):

     def forward(self, x):
         x = self.patch_embeddings(x)
-        if self.pos_embed == "conv":
+        if self.proj_type == "conv":
             x = x.flatten(2).transpose(-1, -2)
         embeddings = x + self.position_embeddings
         embeddings = self.dropout(embeddings)
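
A sketch of what the new `sincos` branch above does, assuming the helper import path introduced by this commit: the grid size is `img_size // patch_size` per dimension, and the generated table is copied into `position_embeddings`.

```python
import torch

from monai.networks.blocks import PatchEmbeddingBlock
from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding

block = PatchEmbeddingBlock(
    in_channels=1, img_size=(16, 16, 16), patch_size=(4, 4, 4),
    hidden_size=48, num_heads=4, proj_type="conv", pos_embed_type="sincos",
)

# Rebuild the same table manually: a grid of 16 // 4 = 4 patches per dimension.
ref = build_sincos_position_embedding(grid_size=[4, 4, 4], embed_dim=48, spatial_dims=3)
print(torch.allclose(block.position_embeddings.data, ref))  # expected: True
```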
monai/networks/blocks/pos_embed_utils.py

Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import collections.abc
+from itertools import repeat
+from typing import List, Union
+
+import torch
+import torch.nn as nn
+
+__all__ = ["build_sincos_position_embedding"]
+
+
+# From PyTorch internals
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+            return tuple(x)
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+def build_sincos_position_embedding(
+    grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0
+) -> torch.nn.Parameter:
+    """
+    Builds a sin-cos position embedding based on the given grid size, embed dimension, spatial dimensions, and temperature.
+    Reference: https://github.com/cvlab-stonybrook/SelfMedMAE/blob/68d191dfcc1c7d0145db93a6a570362de29e3b30/lib/models/mae3d.py
+
+    Args:
+        grid_size (List[int]): The size of the grid in each spatial dimension.
+        embed_dim (int): The dimension of the embedding.
+        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
+        temperature (float): The temperature for the sin-cos position embedding.
+
+    Returns:
+        pos_embed (nn.Parameter): The sin-cos position embedding as a learnable parameter.
+    """
+
+    if spatial_dims == 2:
+        to_2tuple = _ntuple(2)
+        grid_size_t = to_2tuple(grid_size)
+        h, w = grid_size_t
+        grid_h = torch.arange(h, dtype=torch.float32)
+        grid_w = torch.arange(w, dtype=torch.float32)
+
+        grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing="ij")
+
+        assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
+
+        pos_dim = embed_dim // 4
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1.0 / (temperature**omega)
+        out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
+        out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
+        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
+    elif spatial_dims == 3:
+        to_3tuple = _ntuple(3)
+        grid_size_t = to_3tuple(grid_size)
+        h, w, d = grid_size_t
+        grid_h = torch.arange(h, dtype=torch.float32)
+        grid_w = torch.arange(w, dtype=torch.float32)
+        grid_d = torch.arange(d, dtype=torch.float32)
+
+        grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d, indexing="ij")
+
+        assert embed_dim % 6 == 0, "Embed dimension must be divisible by 6 for 3D sin-cos position embedding"
+
+        pos_dim = embed_dim // 6
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1.0 / (temperature**omega)
+        out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
+        out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
+        out_d = torch.einsum("m,d->md", [grid_d.flatten(), omega])
+        pos_emb = torch.cat(
+            [
+                torch.sin(out_w),
+                torch.cos(out_w),
+                torch.sin(out_h),
+                torch.cos(out_h),
+                torch.sin(out_d),
+                torch.cos(out_d),
+            ],
+            dim=1,
+        )[None, :, :]
+    else:
+        raise NotImplementedError("Spatial Dimension Size {spatial_dims} Not Implemented!")
+
+    pos_embed = nn.Parameter(pos_emb)
+    pos_embed.requires_grad = False
+
+    return pos_embed
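
A quick check of the helper on its own (import path as added in this commit): for 2D the embedding dimension must be divisible by 4, and the result is a frozen parameter of shape (1, H*W, embed_dim).

```python
from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding

pe = build_sincos_position_embedding(grid_size=[14, 14], embed_dim=64, spatial_dims=2)
print(pe.shape)          # expected: torch.Size([1, 196, 64])
print(pe.requires_grad)  # expected: False
```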

monai/networks/blocks/selfattention.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def __init__(
         Args:
             hidden_size (int): dimension of hidden layer.
             num_heads (int): number of attention heads.
-            dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
+            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.

monai/networks/blocks/transformerblock.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def __init__(
             hidden_size (int): dimension of hidden layer.
             mlp_dim (int): dimension of feedforward layer.
             num_heads (int): number of attention heads.
-            dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
+            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.

monai/networks/nets/transchex.py

Lines changed: 1 addition & 1 deletion
@@ -314,7 +314,7 @@ def __init__(
             num_language_layers: number of language transformer layers.
             num_vision_layers: number of vision transformer layers.
             num_mixed_layers: number of mixed transformer layers.
-            drop_out: faction of the input units to drop.
+            drop_out: fraction of the input units to drop.

         The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`.

monai/networks/nets/unetr.py

Lines changed: 9 additions & 4 deletions
@@ -18,7 +18,7 @@
 from monai.networks.blocks.dynunet_block import UnetOutBlock
 from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
 from monai.networks.nets.vit import ViT
-from monai.utils import ensure_tuple_rep
+from monai.utils import deprecated_arg, ensure_tuple_rep


 class UNETR(nn.Module):

@@ -27,6 +27,7 @@ class UNETR(nn.Module):
     UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
     """

+    @deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.")
     def __init__(
         self,
         in_channels: int,

@@ -37,6 +38,7 @@ def __init__(
         mlp_dim: int = 3072,
         num_heads: int = 12,
         pos_embed: str = "conv",
+        proj_type: str = "conv",
         norm_name: tuple | str = "instance",
         conv_block: bool = True,
         res_block: bool = True,

@@ -54,7 +56,7 @@ def __init__(
             hidden_size: dimension of hidden layer. Defaults to 768.
             mlp_dim: dimension of feedforward layer. Defaults to 3072.
             num_heads: number of attention heads. Defaults to 12.
-            pos_embed: position embedding layer type. Defaults to "conv".
+            proj_type: patch embedding layer type. Defaults to "conv".
             norm_name: feature normalization type and arguments. Defaults to "instance".
             conv_block: if convolutional block is used. Defaults to True.
             res_block: if residual block is used. Defaults to True.

@@ -63,6 +65,9 @@ def __init__(
             qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False.
             save_attn: to make accessible the attention in self attention block. Defaults to False.

+        .. deprecated:: 1.4
+            ``pos_embed`` is deprecated in favor of ``proj_type``.
+
         Examples::

             # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm

@@ -72,7 +77,7 @@ def __init__(
             >>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2)

             # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm
-            >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')
+            >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), proj_type='conv', norm_name='instance')

         """

@@ -98,7 +103,7 @@ def __init__(
             mlp_dim=mlp_dim,
             num_layers=self.num_layers,
             num_heads=num_heads,
-            pos_embed=pos_embed,
+            proj_type=proj_type,
             classification=self.classification,
             dropout_rate=dropout_rate,
             spatial_dims=spatial_dims,
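
A hedged sketch of the renamed argument in UNETR; note that UNETR only gains `proj_type` in this commit, not `pos_embed_type`, and the values below are illustrative:

```python
from monai.networks.nets import UNETR

# `proj_type` replaces the deprecated `pos_embed` keyword for the patch projection.
net = UNETR(
    in_channels=1,
    out_channels=2,
    img_size=(96, 96, 96),
    feature_size=16,
    proj_type="conv",
    norm_name="instance",
)
print(sum(p.numel() for p in net.parameters()))  # parameter count, just to confirm construction
```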

monai/networks/nets/vit.py

Lines changed: 16 additions & 6 deletions
@@ -18,6 +18,7 @@

 from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
 from monai.networks.blocks.transformerblock import TransformerBlock
+from monai.utils import deprecated_arg

 __all__ = ["ViT"]


@@ -30,6 +31,7 @@ class ViT(nn.Module):
     ViT supports Torchscript but only works for Pytorch after 1.8.
     """

+    @deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.")
     def __init__(
         self,
         in_channels: int,

@@ -40,6 +42,8 @@ def __init__(
         num_layers: int = 12,
         num_heads: int = 12,
         pos_embed: str = "conv",
+        proj_type: str = "conv",
+        pos_embed_type: str = "learnable",
         classification: bool = False,
         num_classes: int = 2,
         dropout_rate: float = 0.0,

@@ -57,27 +61,32 @@ def __init__(
             mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
             num_layers (int, optional): number of transformer blocks. Defaults to 12.
             num_heads (int, optional): number of attention heads. Defaults to 12.
-            pos_embed (str, optional): position embedding layer type. Defaults to "conv".
+            proj_type (str, optional): patch embedding layer type. Defaults to "conv".
+            pos_embed_type (str, optional): position embedding type. Defaults to "learnable".
             classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
             num_classes (int, optional): number of classes if classification is used. Defaults to 2.
-            dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
+            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
             post_activation (str, optional): add a final acivation function to the classification head
                 when `classification` is True. Default to "Tanh" for `nn.Tanh()`.
                 Set to other values to remove this function.
             qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
             save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.

+        .. deprecated:: 1.4
+            ``pos_embed`` is deprecated in favor of ``proj_type``.
+
         Examples::

             # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
-            >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv')
+            >>> net = ViT(in_channels=1, img_size=(96,96,96), proj_type='conv', pos_embed_type='sincos')

             # for 3-channel with image size of (128,128,128), 24 layers and classification backbone
-            >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True)
+            >>> net = ViT(in_channels=3, img_size=(128,128,128), proj_type='conv', pos_embed_type='sincos', classification=True)

             # for 3-channel with image size of (224,224), 12 layers and classification backbone
-            >>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2)
+            >>> net = ViT(in_channels=3, img_size=(224,224), proj_type='conv', pos_embed_type='sincos', classification=True,
+            >>>           spatial_dims=2)

         """

@@ -96,7 +105,8 @@ def __init__(
             patch_size=patch_size,
             hidden_size=hidden_size,
             num_heads=num_heads,
-            pos_embed=pos_embed,
+            proj_type=proj_type,
+            pos_embed_type=pos_embed_type,
             dropout_rate=dropout_rate,
             spatial_dims=spatial_dims,
         )
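
A sketch combining the two new ViT arguments; sizes are chosen so the default hidden_size of 768 divides by 6, as the 3D sincos embedding requires, and the output shape is inferred rather than taken from the tests:

```python
import torch

from monai.networks.nets import ViT

vit = ViT(
    in_channels=1,
    img_size=(96, 96, 96),
    patch_size=(16, 16, 16),
    proj_type="conv",
    pos_embed_type="sincos",
)

x = torch.randn(1, 1, 96, 96, 96)
out, hidden_states = vit(x)  # ViT returns the final tokens and the per-layer hidden states
print(out.shape)  # expected: torch.Size([1, 216, 768]) -> 6*6*6 patches
```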
