Commit 325a6cc

Add dd to other ResNet based models, Res2Net, ResNeSt, SKNet

1 parent 90a35c8 · commit 325a6cc
File tree

4 files changed: +112 -31 lines

  timm/models/res2net.py
  timm/models/resnest.py
  timm/models/resnet.py
  timm/models/sknet.py
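The "dd" in the commit title is the two-entry dict of PyTorch factory kwargs, device and dtype, that each block now builds once and splats into every sub-layer it constructs. A minimal sketch of the pattern applied in the diffs below, using a toy block rather than the actual timm modules (ToyBlock is illustrative only):

# Minimal sketch of the device/dtype ("dd") pattern; ToyBlock is not a timm class.
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, inplanes, planes, device=None, dtype=None):
        # Bundle the factory kwargs once, then forward them to every sub-layer.
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.conv = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, bias=False, **dd)
        self.bn = nn.BatchNorm2d(planes, **dd)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

# device=None / dtype=None reproduces the old defaults; passing them changes
# where and in what precision the parameters are allocated at construction time.
block = ToyBlock(64, 64, device='meta')         # no real weight memory allocated
print(block.conv.weight.device)                 # meta
block = ToyBlock(64, 64, dtype=torch.bfloat16)  # params created directly in bf16
print(block.bn.weight.dtype)                    # torch.bfloat16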

timm/models/res2net.py

Lines changed: 19 additions & 8 deletions
@@ -35,8 +35,11 @@ def __init__(
             act_layer=nn.ReLU,
             norm_layer=None,
             attn_layer=None,
+            device=None,
+            dtype=None,
             **_,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super(Bottle2neck, self).__init__()
         self.scale = scale
         self.is_first = stride > 1 or downsample is not None
@@ -46,16 +49,24 @@ def __init__(
         outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation

-        self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
-        self.bn1 = norm_layer(width * scale)
+        self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False, **dd)
+        self.bn1 = norm_layer(width * scale, **dd)

         convs = []
         bns = []
         for i in range(self.num_scales):
             convs.append(nn.Conv2d(
-                width, width, kernel_size=3, stride=stride, padding=first_dilation,
-                dilation=first_dilation, groups=cardinality, bias=False))
-            bns.append(norm_layer(width))
+                width,
+                width,
+                kernel_size=3,
+                stride=stride,
+                padding=first_dilation,
+                dilation=first_dilation,
+                groups=cardinality,
+                bias=False,
+                **dd,
+            ))
+            bns.append(norm_layer(width, **dd))
         self.convs = nn.ModuleList(convs)
         self.bns = nn.ModuleList(bns)
         if self.is_first:
@@ -64,9 +75,9 @@ def __init__(
         else:
             self.pool = None

-        self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False)
-        self.bn3 = norm_layer(outplanes)
-        self.se = attn_layer(outplanes) if attn_layer is not None else None
+        self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False, **dd)
+        self.bn3 = norm_layer(outplanes, **dd)
+        self.se = attn_layer(outplanes, **dd) if attn_layer is not None else None

         self.relu = act_layer(inplace=True)
         self.downsample = downsample
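One consequence of the Bottle2neck change: norm_layer(width * scale, **dd) and attn_layer(outplanes, **dd) now forward the factory kwargs, so any callable supplied for those roles must accept device and dtype keyword arguments. Built-in layers such as nn.BatchNorm2d already do; a custom factory needs to pass them through, roughly as in this sketch (group_norm_factory is a hypothetical example, not part of timm):

from functools import partial
import torch.nn as nn

# nn.BatchNorm2d accepts device/dtype, so a partial still works as norm_layer:
norm_layer = partial(nn.BatchNorm2d, eps=1e-5)

# A hand-rolled factory must forward the kwargs explicitly
# (group_norm_factory is a hypothetical example, not part of timm):
def group_norm_factory(num_channels, device=None, dtype=None):
    return nn.GroupNorm(num_groups=1, num_channels=num_channels, device=device, dtype=dtype)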

timm/models/resnest.py

Lines changed: 30 additions & 9 deletions
@@ -42,7 +42,10 @@ def __init__(
             aa_layer=None,
             drop_block=None,
             drop_path=None,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super(ResNestBottleneck, self).__init__()
         assert reduce_first == 1  # not supported
         assert attn_layer is None, 'attn_layer is not supported'  # not supported
@@ -57,29 +60,47 @@ def __init__(
             avd_stride = 0
         self.radix = radix

-        self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
-        self.bn1 = norm_layer(group_width)
+        self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False, **dd)
+        self.bn1 = norm_layer(group_width, **dd)
         self.act1 = act_layer(inplace=True)
         self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None

         if self.radix >= 1:
             self.conv2 = SplitAttn(
-                group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
-                dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_layer=drop_block)
+                group_width,
+                group_width,
+                kernel_size=3,
+                stride=stride,
+                padding=first_dilation,
+                dilation=first_dilation,
+                groups=cardinality,
+                radix=radix,
+                norm_layer=norm_layer,
+                drop_layer=drop_block,
+                **dd,
+            )
             self.bn2 = nn.Identity()
             self.drop_block = nn.Identity()
             self.act2 = nn.Identity()
         else:
             self.conv2 = nn.Conv2d(
-                group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
-                dilation=first_dilation, groups=cardinality, bias=False)
-            self.bn2 = norm_layer(group_width)
+                group_width,
+                group_width,
+                kernel_size=3,
+                stride=stride,
+                padding=first_dilation,
+                dilation=first_dilation,
+                groups=cardinality,
+                bias=False,
+                **dd,
+            )
+            self.bn2 = norm_layer(group_width, **dd)
             self.drop_block = drop_block() if drop_block is not None else nn.Identity()
             self.act2 = act_layer(inplace=True)
         self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None

-        self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False)
-        self.bn3 = norm_layer(planes*4)
+        self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False, **dd)
+        self.bn3 = norm_layer(planes * 4, **dd)
         self.act3 = act_layer(inplace=True)
         self.downsample = downsample
         self.drop_path = drop_path

timm/models/resnet.py

Lines changed: 27 additions & 5 deletions
@@ -87,15 +87,29 @@ def __init__(
         use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)

         self.conv1 = nn.Conv2d(
-            inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
-            dilation=first_dilation, bias=False, **dd)
+            inplanes,
+            first_planes,
+            kernel_size=3,
+            stride=1 if use_aa else stride,
+            padding=first_dilation,
+            dilation=first_dilation,
+            bias=False,
+            **dd,
+        )
         self.bn1 = norm_layer(first_planes, **dd)
         self.drop_block = drop_block() if drop_block is not None else nn.Identity()
         self.act1 = act_layer(inplace=True)
         self.aa = create_aa(aa_layer, channels=first_planes, stride=stride, enable=use_aa, **dd)

         self.conv2 = nn.Conv2d(
-            first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False, **dd)
+            first_planes,
+            outplanes,
+            kernel_size=3,
+            padding=dilation,
+            dilation=dilation,
+            bias=False,
+            **dd,
+        )
         self.bn2 = norm_layer(outplanes, **dd)

         self.se = create_attn(attn_layer, outplanes, **dd)
@@ -196,8 +210,16 @@ def __init__(
         self.act1 = act_layer(inplace=True)

         self.conv2 = nn.Conv2d(
-            first_planes, width, kernel_size=3, stride=1 if use_aa else stride,
-            padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False, **dd)
+            first_planes,
+            width,
+            kernel_size=3,
+            stride=1 if use_aa else stride,
+            padding=first_dilation,
+            dilation=first_dilation,
+            groups=cardinality,
+            bias=False,
+            **dd,
+        )
         self.bn2 = norm_layer(width, **dd)
         self.drop_block = drop_block() if drop_block is not None else nn.Identity()
         self.act2 = act_layer(inplace=True)

timm/models/sknet.py

Lines changed: 36 additions & 9 deletions
@@ -40,23 +40,39 @@ def __init__(
             aa_layer=None,
             drop_block=None,
             drop_path=None,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super(SelectiveKernelBasic, self).__init__()

         sk_kwargs = sk_kwargs or {}
-        conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
+        conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer, **dd)
         assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
         assert base_width == 64, 'BasicBlock doest not support changing base width'
         first_planes = planes // reduce_first
         outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation

         self.conv1 = SelectiveKernel(
-            inplanes, first_planes, stride=stride, dilation=first_dilation,
-            aa_layer=aa_layer, drop_layer=drop_block, **conv_kwargs, **sk_kwargs)
+            inplanes,
+            first_planes,
+            stride=stride,
+            dilation=first_dilation,
+            aa_layer=aa_layer,
+            drop_layer=drop_block,
+            **conv_kwargs,
+            **sk_kwargs,
+        )
         self.conv2 = ConvNormAct(
-            first_planes, outplanes, kernel_size=3, dilation=dilation, apply_act=False, **conv_kwargs)
-        self.se = create_attn(attn_layer, outplanes)
+            first_planes,
+            outplanes,
+            kernel_size=3,
+            dilation=dilation,
+            apply_act=False,
+            **conv_kwargs,
+        )
+        self.se = create_attn(attn_layer, outplanes, **dd)
         self.act = act_layer(inplace=True)
         self.downsample = downsample
         self.drop_path = drop_path
@@ -101,22 +117,33 @@ def __init__(
             aa_layer=None,
             drop_block=None,
             drop_path=None,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super(SelectiveKernelBottleneck, self).__init__()

         sk_kwargs = sk_kwargs or {}
-        conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
+        conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer, **dd)
         width = int(math.floor(planes * (base_width / 64)) * cardinality)
         first_planes = width // reduce_first
         outplanes = planes * self.expansion
         first_dilation = first_dilation or dilation

         self.conv1 = ConvNormAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
         self.conv2 = SelectiveKernel(
-            first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
-            aa_layer=aa_layer, drop_layer=drop_block, **conv_kwargs, **sk_kwargs)
+            first_planes,
+            width,
+            stride=stride,
+            dilation=first_dilation,
+            groups=cardinality,
+            aa_layer=aa_layer,
+            drop_layer=drop_block,
+            **conv_kwargs,
+            **sk_kwargs,
+        )
         self.conv3 = ConvNormAct(width, outplanes, kernel_size=1, apply_act=False, **conv_kwargs)
-        self.se = create_attn(attn_layer, outplanes)
+        self.se = create_attn(attn_layer, outplanes, **dd)
         self.act = act_layer(inplace=True)
         self.downsample = downsample
         self.drop_path = drop_path
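With device and dtype bundled into conv_kwargs and forwarded to SelectiveKernel, ConvNormAct, and create_attn, the SK blocks can be built directly on a target device or in a non-default precision. A usage sketch, assuming a timm build that includes this commit and leaving the other constructor arguments at their defaults:

import torch
from timm.models.sknet import SelectiveKernelBasic

# Meta-device construction: shapes are tracked but no weight memory is allocated,
# which is handy for counting parameters or planning sharding.
block = SelectiveKernelBasic(64, 64, device='meta')
print(sum(p.numel() for p in block.parameters()))

# Create parameters directly in bfloat16; expect bf16 weights wherever the
# factory kwargs reach (any sub-layer that does not forward them stays float32).
block = SelectiveKernelBasic(64, 64, dtype=torch.bfloat16)
print({p.dtype for p in block.parameters()})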
