 import torch.nn.functional as F
 from torch.nn import LayerNorm

+from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
 from monai.networks.layers import Conv, trunc_normal_
-from monai.utils import ensure_tuple_rep, optional_import
+from monai.utils import deprecated_arg, ensure_tuple_rep, optional_import
 from monai.utils.module import look_up_option

 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
-SUPPORTED_EMBEDDING_TYPES = {"conv", "perceptron"}
+SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"}
+SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}


 class PatchEmbeddingBlock(nn.Module):
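The split into two constants keeps the patch-projection options separate from the position-embedding options; both are validated through `look_up_option`, which raises a `ValueError` for anything outside the set. A minimal standalone sketch of that validation, using only `monai.utils.module.look_up_option`:

```python
from monai.utils.module import look_up_option

SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"}
SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}

# a supported key is returned unchanged
assert look_up_option("sincos", SUPPORTED_POS_EMBEDDING_TYPES) == "sincos"

# an unsupported key raises ValueError and lists the available options
try:
    look_up_option("fourier", SUPPORTED_POS_EMBEDDING_TYPES)
except ValueError as err:
    print(err)
```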
@@ -35,18 +37,22 @@ class PatchEmbeddingBlock(nn.Module):
     Example::

         >>> from monai.networks.blocks import PatchEmbeddingBlock
-        >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4, pos_embed="conv")
+        >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4,
+        >>>                     proj_type="conv", pos_embed_type="sincos")

     """

+    @deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.")
     def __init__(
         self,
         in_channels: int,
         img_size: Sequence[int] | int,
         patch_size: Sequence[int] | int,
         hidden_size: int,
         num_heads: int,
-        pos_embed: str,
+        pos_embed: str = "conv",
+        proj_type: str = "conv",
+        pos_embed_type: str = "learnable",
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
     ) -> None:
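The `@deprecated_arg` decorator keeps old call sites working: passing `pos_embed` emits a deprecation warning and the value is forwarded to `proj_type`. A hedged usage sketch (the warning category and wording are determined by `monai.utils.deprecated_arg`; sizes are illustrative):

```python
import warnings

from monai.networks.blocks import PatchEmbeddingBlock

# new-style call: patch projection and position embedding are chosen independently
block = PatchEmbeddingBlock(
    in_channels=1, img_size=(32, 32, 32), patch_size=(8, 8, 8),
    hidden_size=96, num_heads=8, proj_type="conv", pos_embed_type="sincos",
)

# old-style call still constructs the block, but warns and remaps pos_embed -> proj_type
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    PatchEmbeddingBlock(
        in_channels=1, img_size=(32, 32, 32), patch_size=(8, 8, 8),
        hidden_size=96, num_heads=8, pos_embed="perceptron",
    )
print([str(w.message) for w in caught])
```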
@@ -57,11 +63,12 @@ def __init__(
             patch_size: dimension of patch size.
             hidden_size: dimension of hidden layer.
             num_heads: number of attention heads.
-            pos_embed: position embedding layer type.
-            dropout_rate: faction of the input units to drop.
+            proj_type: patch embedding layer type.
+            pos_embed_type: position embedding layer type.
+            dropout_rate: fraction of the input units to drop.
             spatial_dims: number of spatial dimensions.
-
-
+            .. deprecated:: 1.4
+                ``pos_embed`` is deprecated in favor of ``proj_type``.
         """

         super().__init__()
@@ -72,24 +79,25 @@ def __init__(
         if hidden_size % num_heads != 0:
             raise ValueError(f"hidden size {hidden_size} should be divisible by num_heads {num_heads}.")

-        self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES)
+        self.proj_type = look_up_option(proj_type, SUPPORTED_PATCH_EMBEDDING_TYPES)
+        self.pos_embed_type = look_up_option(pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)

         img_size = ensure_tuple_rep(img_size, spatial_dims)
         patch_size = ensure_tuple_rep(patch_size, spatial_dims)
         for m, p in zip(img_size, patch_size):
             if m < p:
                 raise ValueError("patch_size should be smaller than img_size.")
-            if self.pos_embed == "perceptron" and m % p != 0:
+            if self.proj_type == "perceptron" and m % p != 0:
                 raise ValueError("patch_size should be divisible by img_size for perceptron.")
         self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
         self.patch_dim = int(in_channels * np.prod(patch_size))

         self.patch_embeddings: nn.Module
-        if self.pos_embed == "conv":
+        if self.proj_type == "conv":
             self.patch_embeddings = Conv[Conv.CONV, spatial_dims](
                 in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
             )
-        elif self.pos_embed == "perceptron":
+        elif self.proj_type == "perceptron":
             # for 3d: "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)"
             chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
             from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
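The perceptron branch assembles the einops pattern per dimensionality; the comment in the hunk gives the 3D form, and the 2D analogue simply drops the depth axis. A standalone sketch of what that `Rearrange` does (assuming `einops` is installed; the sizes are hypothetical):

```python
import torch
from einops.layers.torch import Rearrange

# 2D analogue of the pattern in the comment above: 8x8 patches of a 4-channel image
patchify = Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=8, p2=8)
x = torch.zeros(2, 4, 32, 32)  # batch=2, channels=4, 32x32 image
print(patchify(x).shape)       # torch.Size([2, 16, 256]): 4*4 patches, 8*8*4 values each
```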
@@ -100,7 +108,22 @@ def __init__(
             )
         self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
         self.dropout = nn.Dropout(dropout_rate)
-        trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
+
+        if self.pos_embed_type == "none":
+            pass
+        elif self.pos_embed_type == "learnable":
+            trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
+        elif self.pos_embed_type == "sincos":
+            grid_size = []
+            for in_size, pa_size in zip(img_size, patch_size):
+                grid_size.append(in_size // pa_size)
+
+            with torch.no_grad():
+                pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
+                self.position_embeddings.data.copy_(pos_embeddings.float())
+        else:
+            raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")
+
         self.apply(self._init_weights)

     def _init_weights(self, m):
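Note that the sincos branch only fills in the values of `position_embeddings`: the `torch.no_grad()` block keeps the copy out of autograd, but the parameter itself stays trainable (`requires_grad=True`) unless the caller freezes it. A quick check sketch (a 2D example is used here so the hidden size divides evenly into the per-axis sin/cos factorization):

```python
import torch
from monai.networks.blocks import PatchEmbeddingBlock

block = PatchEmbeddingBlock(
    in_channels=1, img_size=(32, 32), patch_size=(8, 8), hidden_size=48,
    num_heads=4, proj_type="conv", pos_embed_type="sincos", spatial_dims=2,
)
pe = block.position_embeddings
print(pe.shape)                     # torch.Size([1, 16, 48]): one row per patch
print(pe.requires_grad)             # True: sincos sets the values but does not freeze them
print(bool(pe.abs().max() <= 1.0))  # True: sine/cosine entries lie in [-1, 1]
```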
@@ -114,7 +137,7 @@ def _init_weights(self, m):

     def forward(self, x):
         x = self.patch_embeddings(x)
-        if self.pos_embed == "conv":
+        if self.proj_type == "conv":
             x = x.flatten(2).transpose(-1, -2)
         embeddings = x + self.position_embeddings
         embeddings = self.dropout(embeddings)
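End to end, the conv path emits a feature map of shape `(B, hidden_size, h', w'[, d'])`, which `flatten(2).transpose(-1, -2)` turns into `(B, n_patches, hidden_size)`; the perceptron path already produces tokens, so it skips the reshape. A shape-check sketch with the default 3D settings and the default learnable position embedding:

```python
import torch
from monai.networks.blocks import PatchEmbeddingBlock

block = PatchEmbeddingBlock(
    in_channels=4, img_size=32, patch_size=8, hidden_size=32,
    num_heads=4, proj_type="conv",  # pos_embed_type defaults to "learnable"
)
x = torch.randn(2, 4, 32, 32, 32)  # spatial_dims defaults to 3
print(block(x).shape)              # torch.Size([2, 64, 32]): (32 // 8) ** 3 = 64 patches
```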