@@ -94,7 +94,7 @@ def forward(self, x):
 
 
 class PatchMerging(nn.Module):
-    def __init__(self, input_resolution, dim, out_dim, activation):
+    def __init__(self, input_resolution, dim, out_dim, activation, in_fmt='BCHW'):
         super().__init__()
         self.input_resolution = input_resolution
         self.dim = dim
@@ -104,18 +104,21 @@ def __init__(self, input_resolution, dim, out_dim, activation):
         self.conv2 = ConvNorm(out_dim, out_dim, 3, 2, 1, groups=out_dim)
         self.conv3 = ConvNorm(out_dim, out_dim, 1, 1, 0)
         self.output_resolution = (math.ceil(input_resolution[0] / 2), math.ceil(input_resolution[1] / 2))
+        self.in_fmt = in_fmt
+        assert self.in_fmt in ['BCHW', 'BLC']
 
     def forward(self, x):
-        if x.ndim == 3:
+        if self.in_fmt == 'BLC':
+            # (B, H * W, C) -> (B, C, H, W)
             H, W = self.input_resolution
-            B = len(x)
-            # (B, C, H, W)
+            B = x.shape[0]
             x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
         x = self.conv1(x)
         x = self.act(x)
         x = self.conv2(x)
         x = self.act(x)
         x = self.conv3(x)
+        # (B, C, H, W) -> (B, H * W, C)
         x = x.flatten(2).transpose(1, 2)
         return x
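To make the layout handling in the new forward concrete, here is a minimal standalone sketch of the two input formats PatchMerging now distinguishes. The sizes (B=2, C=64, H=W=56) are illustrative assumptions, not values taken from this diff:

import torch

B, C, H, W = 2, 64, 56, 56

# 'BLC' input: a token sequence of shape (B, H * W, C), as produced by
# the transformer blocks; forward() restores the spatial layout before
# the strided convolutions.
x_blc = torch.randn(B, H * W, C)
x_bchw = x_blc.view(B, H, W, -1).permute(0, 3, 1, 2)
assert x_bchw.shape == (B, C, H, W)

# After the conv stack, the result is always flattened back to tokens,
# so the module returns 'BLC' regardless of in_fmt.
y = x_bchw.flatten(2).transpose(1, 2)
assert y.shape == (B, H * W, C)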
@@ -369,6 +372,7 @@ class TinyVitStage(nn.Module):
         local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
         activation: the activation function. Default: nn.GELU
         out_dim: the output dimension of the layer. Default: dim
+        in_fmt: input format ('BCHW' or 'BLC'). Default: 'BCHW'
     """
 
     def __init__(
@@ -385,6 +389,7 @@ def __init__(
             local_conv_size=3,
             activation=nn.GELU,
             out_dim=None,
+            in_fmt='BCHW'
     ):
 
         super().__init__()
@@ -396,7 +401,7 @@ def __init__(
         # patch merging layer
         if downsample is not None:
             self.downsample = downsample(
-                input_resolution, dim=input_dim, out_dim=self.out_dim, activation=activation)
+                input_resolution, dim=input_dim, out_dim=self.out_dim, activation=activation, in_fmt=in_fmt)
             input_resolution = self.downsample.output_resolution
         else:
             self.downsample = nn.Identity()
@@ -483,6 +488,10 @@ def __init__(
            else:
                out_dim = embed_dims[stage_idx]
                drop_path_rate = dpr[sum(depths[:stage_idx]):sum(depths[:stage_idx + 1])]
+               if stage_idx == 1:
+                   in_fmt = 'BCHW'
+               else:
+                   in_fmt = 'BLC'
                stage = TinyVitStage(
                    num_heads=num_heads[stage_idx],
                    window_size=window_sizes[stage_idx],
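The stage_idx branch above encodes where each downsample sits in the network: only the first TinyVitStage is fed directly by the convolutional stage, which emits BCHW feature maps, while later stages receive the BLC token sequences produced by the preceding transformer blocks. A rough shape trace under assumed settings (224x224 input, embed_dims=(64, 128, 160, 320); illustrative, not from this diff):

# stage 0 (conv blocks):               output (B, 64, 56, 56)    BCHW
# stage 1 downsample, in_fmt='BCHW':   output (B, 28*28, 128)    BLC
# stage 2 downsample, in_fmt='BLC':    output (B, 14*14, 160)    BLC
# stage 3 downsample, in_fmt='BLC':    output (B, 7*7, 320)      BLC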
@@ -496,6 +505,7 @@ def __init__(
                    downsample=PatchMerging,
                    out_dim=out_dim,
                    activation=activation,
+                   in_fmt=in_fmt
                )
                input_resolution = (math.ceil(input_resolution[0] / 2), math.ceil(input_resolution[1] / 2))
                stride *= 2
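As a quick smoke test of the new parameter, something like the following should exercise both formats; the import path is an assumption about where this change lands, so adjust it to the actual module:

import torch
# Assumed import path; the diff does not show the file's module name.
from timm.models.tiny_vit import PatchMerging

# BCHW input, as produced by the convolutional stage:
merge = PatchMerging((56, 56), dim=64, out_dim=128, activation=torch.nn.GELU, in_fmt='BCHW')
print(merge(torch.randn(2, 64, 56, 56)).shape)    # expected: torch.Size([2, 784, 128])

# BLC input, as produced by a transformer stage:
merge = PatchMerging((56, 56), dim=64, out_dim=128, activation=torch.nn.GELU, in_fmt='BLC')
print(merge(torch.randn(2, 56 * 56, 64)).shape)   # expected: torch.Size([2, 784, 128])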