Skip to content

Commit 39aa44b

Browse files
seefun authored and rwightman committed
Fixing tinyvit trace issue
1 parent aea3b9c commit 39aa44b

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

timm/models/tiny_vit.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def forward(self, x):
9494

9595

9696
class PatchMerging(nn.Module):
97-
def __init__(self, input_resolution, dim, out_dim, activation):
97+
def __init__(self, input_resolution, dim, out_dim, activation, in_fmt='BCHW'):
9898
super().__init__()
9999
self.input_resolution = input_resolution
100100
self.dim = dim
@@ -104,18 +104,21 @@ def __init__(self, input_resolution, dim, out_dim, activation):
104104
self.conv2 = ConvNorm(out_dim, out_dim, 3, 2, 1, groups=out_dim)
105105
self.conv3 = ConvNorm(out_dim, out_dim, 1, 1, 0)
106106
self.output_resolution = (math.ceil(input_resolution[0] / 2), math.ceil(input_resolution[1] / 2))
107+
self.in_fmt = in_fmt
108+
assert self.in_fmt in ['BCHW', 'BLC']
107109

108110
def forward(self, x):
109-
if x.ndim == 3:
111+
if self.in_fmt == 'BLC':
112+
# (B, H * W, C) -> (B, C, H, W)
110113
H, W = self.input_resolution
111-
B = len(x)
112-
# (B, C, H, W)
114+
B = x.shape[0]
113115
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
114116
x = self.conv1(x)
115117
x = self.act(x)
116118
x = self.conv2(x)
117119
x = self.act(x)
118120
x = self.conv3(x)
121+
# (B, C, H, W) -> (B, H * W, C)
119122
x = x.flatten(2).transpose(1, 2)
120123
return x
121124

@@ -369,6 +372,7 @@ class TinyVitStage(nn.Module):
369372
local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
370373
activation: the activation function. Default: nn.GELU
371374
out_dim: the output dimension of the layer. Default: dim
375+
in_fmt: input format ('BCHW' or 'BLC'). Default: 'BCHW'
372376
"""
373377

374378
def __init__(
@@ -385,6 +389,7 @@ def __init__(
385389
local_conv_size=3,
386390
activation=nn.GELU,
387391
out_dim=None,
392+
in_fmt='BCHW'
388393
):
389394

390395
super().__init__()
@@ -396,7 +401,7 @@ def __init__(
396401
# patch merging layer
397402
if downsample is not None:
398403
self.downsample = downsample(
399-
input_resolution, dim=input_dim, out_dim=self.out_dim, activation=activation)
404+
input_resolution, dim=input_dim, out_dim=self.out_dim, activation=activation, in_fmt=in_fmt)
400405
input_resolution = self.downsample.output_resolution
401406
else:
402407
self.downsample = nn.Identity()
@@ -483,6 +488,10 @@ def __init__(
483488
else:
484489
out_dim = embed_dims[stage_idx]
485490
drop_path_rate = dpr[sum(depths[:stage_idx]):sum(depths[:stage_idx + 1])]
491+
if stage_idx == 1:
492+
in_fmt = 'BCHW'
493+
else:
494+
in_fmt = 'BLC'
486495
stage = TinyVitStage(
487496
num_heads=num_heads[stage_idx],
488497
window_size=window_sizes[stage_idx],
@@ -496,6 +505,7 @@ def __init__(
496505
downsample=PatchMerging,
497506
out_dim=out_dim,
498507
activation=activation,
508+
in_fmt=in_fmt
499509
)
500510
input_resolution = (math.ceil(input_resolution[0] / 2), math.ceil(input_resolution[1] / 2))
501511
stride *= 2

0 commit comments

Comments (0)