Commit 7560ca5

Adding dd factory_kwargs to modules in timm/layers, initial model WIP in vision_transformer.py
1 parent e7bd97b · commit 7560ca5

37 files changed (+1337, -468 lines)

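The changes below all follow one pattern: each layer's __init__ grows optional device / dtype arguments, collects them into a dd dict, and forwards **dd into every call that creates parameters or submodules, mirroring PyTorch's own factory_kwargs convention. A minimal sketch of the idea outside timm (ToyBlock and its sizes are illustrative, not part of the commit):

    from typing import Type

    import torch
    import torch.nn as nn

    class ToyBlock(nn.Module):
        """Illustrative module using the dd factory-kwargs pattern from this commit."""

        def __init__(
                self,
                dim: int,
                norm_layer: Type[nn.Module] = nn.LayerNorm,
                device=None,
                dtype=None,
        ):
            super().__init__()
            dd = {'device': device, 'dtype': dtype}
            # Every weight-creating call receives **dd, so the module is built
            # directly on the requested device and in the requested dtype.
            self.fc = nn.Linear(dim, dim, **dd)
            self.norm = norm_layer(dim, **dd)
            self.pos = nn.Parameter(torch.zeros(1, dim, **dd))

    # Construct directly in bfloat16, no post-hoc .to() pass needed.
    block = ToyBlock(64, dtype=torch.bfloat16)
    print(block.fc.weight.dtype)  # torch.bfloat16

Constructing with device='meta' works the same way, which is what makes deferred or sharded initialization of large models practical.
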
timm/layers/attention.py

Lines changed: 19 additions & 13 deletions
@@ -36,6 +36,8 @@ def __init__(
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             norm_layer: Optional[Type[nn.Module]] = None,
+            device=None,
+            dtype=None
     ) -> None:
         """Initialize the Attention module.

@@ -50,6 +52,7 @@ def __init__(
             norm_layer: Normalization layer constructor for QK normalization if enabled
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
         if qk_norm or scale_norm:
             assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
@@ -58,12 +61,12 @@ def __init__(
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()

-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
+        self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.norm = norm_layer(dim) if scale_norm else nn.Identity()
-        self.proj = nn.Linear(dim, dim, bias=proj_bias)
+        self.norm = norm_layer(dim, **dd) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(
@@ -122,6 +125,8 @@ def __init__(
             scale_norm: bool = False,
             proj_bias: bool = True,
             rotate_half: bool = False,
+            device=None,
+            dtype=None,
     ):
         """Initialize the Attention module.

@@ -140,6 +145,7 @@ def __init__(
             rotate_half: Use 'half' ROPE layout instead of default 'interleaved'
         """
         super().__init__()
+        dd = {'device': device, 'dtype': dtype}
         if scale_norm or qk_norm:
             assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
         self.num_heads = num_heads
@@ -153,19 +159,19 @@ def __init__(
         self.rotate_half = rotate_half

         if qkv_fused:
-            self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
+            self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd)
             self.q_proj = self.k_proj = self.v_proj = None
         else:
             self.qkv = None
-            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
-            self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
-            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias)
+            self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
+            self.k_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
+            self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)

-        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
+        self.q_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(head_dim, **dd) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.norm = norm_layer(attn_dim) if scale_norm else nn.Identity()
-        self.proj = nn.Linear(attn_dim, dim, bias=proj_bias)
+        self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity()
+        self.proj = nn.Linear(attn_dim, dim, bias=proj_bias, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(

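With the constructor change above, an attention block can be materialized directly where (and in the precision) it will run. A hypothetical usage sketch, assuming the post-commit signature; dim and num_heads are pre-existing arguments implied by the assert shown above, and the values are illustrative:

    import torch
    from timm.layers.attention import Attention

    attn = Attention(
        dim=768,
        num_heads=12,
        device='cpu',            # or 'cuda', 'meta', ...
        dtype=torch.bfloat16,
    )
    print(attn.qkv.weight.dtype)  # torch.bfloat16
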
timm/layers/attention2d.py

Lines changed: 38 additions & 9 deletions
@@ -33,22 +33,34 @@ def __init__(
             value_dim: int = 64,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
+            device=None,
+            dtype=None,
     ):
         """Initializer."""
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         dim_out = dim_out or dim
         self.num_heads = num_heads
         self.key_dim = key_dim
         self.value_dim = value_dim
         self.scale = key_dim ** -0.5

-        self.query_proj = nn.Parameter(torch.randn([self.num_heads, self.key_dim, dim]))
-        self.key_proj = nn.Parameter(torch.randn([dim, self.key_dim]))
-        self.value_proj = nn.Parameter(torch.randn([dim, self.value_dim]))
+        self.query_proj = nn.Parameter(torch.empty((self.num_heads, self.key_dim, dim), **dd))
+        self.key_proj = nn.Parameter(torch.empty((dim, self.key_dim), **dd))
+        self.value_proj = nn.Parameter(torch.empty((dim, self.value_dim), **dd))
         self.attn_drop = nn.Dropout(attn_drop)
-        self.out_proj = nn.Parameter(torch.randn([dim_out, self.num_heads, self.value_dim]))
+        self.out_proj = nn.Parameter(torch.empty((dim_out, self.num_heads, self.value_dim), **dd))
         self.proj_drop = nn.Dropout(proj_drop)

+        self.reset_parameters()
+
+    def reset_parameters(self):
+        scale = self.key_proj.shape[0] ** -0.5
+        nn.init.normal_(self.query_proj, std=scale)
+        nn.init.normal_(self.key_proj, std=scale)
+        nn.init.normal_(self.value_proj, std=scale)
+        nn.init.normal_(self.out_proj, std=self.out_proj.shape[0] ** -0.5)
+
     def _reshape_input(self, t):
         """Reshapes a tensor to three dimensions, keeping the first and last."""
         s = t.shape
@@ -108,6 +120,8 @@ def __init__(
             proj_drop: float = 0.,
             norm_layer: Type[nn.Module] = nn.BatchNorm2d,
             use_bias: bool = False,
+            device=None,
+            dtype=None,
     ):
         """Initializer.

@@ -119,6 +133,7 @@ def __init__(
             kv_stride: Key and value stride size.
             dw_kernel_size: Spatial dimension of the depthwise kernel.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         dim_out = dim_out or dim
         self.num_heads = num_heads
@@ -149,6 +164,7 @@ def __init__(
             self.num_heads * self.key_dim,
             kernel_size=1,
             bias=use_bias,
+            **dd,
         ))

         self.key = nn.Sequential()
@@ -161,6 +177,7 @@ def __init__(
                 dilation=dilation,
                 padding=padding,
                 depthwise=True,
+                **dd,
             ))
         self.key.add_module('norm', norm_layer(dim))
         self.key.add_module('proj', create_conv2d(
@@ -169,6 +186,7 @@ def __init__(
             kernel_size=1,
             padding=padding,
             bias=use_bias,
+            **dd,
         ))

         self.value = nn.Sequential()
@@ -181,29 +199,37 @@ def __init__(
                 dilation=dilation,
                 padding=padding,
                 depthwise=True,
+                **dd,
             ))
         self.value.add_module('norm', norm_layer(dim))
         self.value.add_module('proj', create_conv2d(
             dim,
             self.value_dim,
             kernel_size=1,
             bias=use_bias,
+            **dd,
         ))

         self.attn_drop = nn.Dropout(attn_drop)

         self.output = nn.Sequential()
         if self.has_query_strides:
-            self.output.add_module('upsample', nn.Upsample(scale_factor=self.query_strides, mode='bilinear', align_corners=False))
+            self.output.add_module('upsample', nn.Upsample(
+                scale_factor=self.query_strides,
+                mode='bilinear',
+                align_corners=False
+            ))
         self.output.add_module('proj', create_conv2d(
             self.value_dim * self.num_heads,
             dim_out,
             kernel_size=1,
             bias=use_bias,
+            **dd,
         ))
-        self.output.add_module('drop', nn.Dropout(proj_drop))
+        self.output.add_module('drop', nn.Dropout(proj_drop))

         self.einsum = False
+        self.init_weights()

     def init_weights(self):
         # using xavier appeared to improve stability for mobilenetv4 hybrid w/ this layer
@@ -304,8 +330,11 @@ def __init__(
             expand_first: bool = False,
             head_first: bool = False,
             attn_drop: float = 0.,
-            proj_drop: float = 0.
+            proj_drop: float = 0.,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         dim_out = dim_out or dim
         dim_attn = dim_out if expand_first else dim
@@ -314,9 +343,9 @@ def __init__(
         self.head_first = head_first
         self.fused_attn = use_fused_attn()

-        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
+        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias, **dd)
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias)
+        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x, attn_mask: Optional[torch.Tensor] = None):

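Note the second change in this file: the projection parameters switch from torch.randn(...) to torch.empty(..., **dd) plus an explicit reset_parameters(). The init statistics stay the same, but because torch.empty allocates nothing on the meta device, the module can be built shape-only and materialized later. A rough sketch of that workflow with a stand-in module (ProjBank is illustrative, not the timm class):

    import torch
    import torch.nn as nn

    class ProjBank(nn.Module):
        def __init__(self, dim: int, key_dim: int, device=None, dtype=None):
            super().__init__()
            dd = {'device': device, 'dtype': dtype}
            self.key_proj = nn.Parameter(torch.empty((dim, key_dim), **dd))
            self.reset_parameters()

        def reset_parameters(self):
            nn.init.normal_(self.key_proj, std=self.key_proj.shape[0] ** -0.5)

    # Build on the meta device (no storage allocated), then materialize.
    m = ProjBank(256, 64, device='meta')
    m = m.to_empty(device='cpu')   # allocate real storage
    m.reset_parameters()           # now actually initialize the weights
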
timm/layers/attention_pool.py

Lines changed: 12 additions & 9 deletions
@@ -32,7 +32,10 @@ def __init__(
             norm_layer: Optional[Type[nn.Module]] = None,
             act_layer: Optional[Type[nn.Module]] = nn.GELU,
             drop: float = 0.0,
+            device = None,
+            dtype = None
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         embed_dim = embed_dim or in_features
         out_features = out_features or in_features
@@ -46,28 +49,28 @@ def __init__(

         if pos_embed == 'abs':
             assert feat_size is not None
-            self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
+            self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features, **dd))
         else:
             self.pos_embed = None

         self.latent_dim = latent_dim or embed_dim
         self.latent_len = latent_len
-        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
+        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim, **dd))

-        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
-        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
+        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias, **dd)
+        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias, **dd)
         if qk_norm:
             qk_norm_layer = norm_layer or nn.LayerNorm
-            self.q_norm = qk_norm_layer(self.head_dim)
-            self.k_norm = qk_norm_layer(self.head_dim)
+            self.q_norm = qk_norm_layer(self.head_dim, **dd)
+            self.k_norm = qk_norm_layer(self.head_dim, **dd)
         else:
             self.q_norm = nn.Identity()
             self.k_norm = nn.Identity()
-        self.proj = nn.Linear(embed_dim, embed_dim)
+        self.proj = nn.Linear(embed_dim, embed_dim, **dd)
         self.proj_drop = nn.Dropout(drop)

-        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
-        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio), act_layer=act_layer)
+        self.norm = norm_layer(out_features, **dd) if norm_layer is not None else nn.Identity()
+        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio), act_layer=act_layer, **dd)

         self.init_weights()


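A side effect of calls like norm_layer(out_features, **dd) and Mlp(..., **dd) above is that any constructor passed in this way has to accept device / dtype keywords. Built-ins such as nn.LayerNorm already do; a custom norm layer would need to forward them, for example (illustrative, not from the commit):

    import torch
    import torch.nn as nn

    # nn.LayerNorm already accepts the factory kwargs...
    norm = nn.LayerNorm(256, device='cpu', dtype=torch.float32)

    # ...and a custom norm_layer used with these modules should too.
    class ScaleOnlyNorm(nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6, device=None, dtype=None):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
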
timm/layers/attention_pool2d.py

Lines changed: 20 additions & 14 deletions
@@ -44,7 +44,10 @@ def __init__(
             pool_type: str = 'token',
             class_token: bool = False,
             drop_rate: float = 0.,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         assert pool_type in ('', 'token')
         self.embed_dim = embed_dim = embed_dim or in_features
@@ -64,20 +67,20 @@ def __init__(
         self.fused_attn = use_fused_attn()

         if class_token:
-            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim, **dd))
         else:
             self.cls_token = None

         if qkv_separate:
-            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
-            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
-            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
             self.qkv = None
         else:
-            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias, **dd)
         self.drop = nn.Dropout(drop_rate)
-        self.proj = nn.Linear(embed_dim, self.out_features)
-        self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)
+        self.proj = nn.Linear(embed_dim, self.out_features, **dd)
+        self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size, **dd)

     def init_weights(self, zero_init_last: bool = False):
         if self.qkv is None:
@@ -171,7 +174,10 @@ def __init__(
             pool_type: str = 'token',
             class_token: bool = False,
             drop_rate: float = 0.,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         assert pool_type in ('', 'token')
         self.embed_dim = embed_dim = embed_dim or in_features
@@ -192,21 +198,21 @@ def __init__(
         self.fused_attn = use_fused_attn()

         if class_token:
-            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim, **dd))
         else:
             self.cls_token = None

         if qkv_separate:
-            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
-            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
-            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
+            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
+            self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
+            self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias, **dd)
             self.qkv = None
         else:
             self.q = self.k = self.v = None
-            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+            self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias, **dd)
         self.drop = nn.Dropout(drop_rate)
-        self.proj = nn.Linear(embed_dim, self.out_features)
-        self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))
+        self.proj = nn.Linear(embed_dim, self.out_features, **dd)
+        self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features, **dd))

         self.init_weights()


timm/layers/blur_pool.py

Lines changed: 9 additions & 4 deletions
@@ -36,7 +36,10 @@ def __init__(
             filt_size: int = 3,
             stride: int = 2,
             pad_mode: str = 'reflect',
+            device=None,
+            dtype=None
     ) -> None:
+        dd = {'device': device, 'dtype': dtype}
         super(BlurPool2d, self).__init__()
         assert filt_size > 1
         self.channels = channels
@@ -48,7 +51,7 @@ def __init__(
         # (0.5 + 0.5 x)^N => coefficients = C(N,k) / 2^N, k = 0..N
         coeffs = torch.tensor(
             [comb(filt_size - 1, k) for k in range(filt_size)],
-            dtype=torch.float32,
+            **dd,
         ) / (2 ** (filt_size - 1))  # normalise so coefficients sum to 1
         blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :]
         if channels is not None:
@@ -71,7 +74,9 @@ def create_aa(
         channels: Optional[int] = None,
         stride: int = 2,
         enable: bool = True,
-        noop: Optional[Type[nn.Module]] = nn.Identity
+        noop: Optional[Type[nn.Module]] = nn.Identity,
+        device=None,
+        dtype=None,
 ) -> nn.Module:
     """ Anti-aliasing """
     if not aa_layer or not enable:
@@ -82,9 +87,9 @@ def create_aa(
         if aa_layer == 'avg' or aa_layer == 'avgpool':
             aa_layer = nn.AvgPool2d
         elif aa_layer == 'blur' or aa_layer == 'blurpool':
-            aa_layer = BlurPool2d
+            aa_layer = partial(BlurPool2d, device=device, dtype=dtype)
         elif aa_layer == 'blurpc':
-            aa_layer = partial(BlurPool2d, pad_mode='constant')
+            aa_layer = partial(BlurPool2d, pad_mode='constant', device=device, dtype=dtype)

         else:
             assert False, f"Unknown anti-aliasing layer ({aa_layer})."

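In create_aa, the string-selected blur layers are now wrapped in functools.partial so the caller's device / dtype are bound into the constructor before it is invoked later, deeper in the network build. The same idea in isolation (build_downsample and its arguments are illustrative):

    from functools import partial

    import torch
    import torch.nn as nn

    def build_downsample(layer_cls, channels: int, device=None, dtype=None) -> nn.Module:
        # Bind the factory kwargs now; the bound constructor can be called later
        # without re-plumbing device/dtype through every call site.
        layer_cls = partial(layer_cls, device=device, dtype=dtype)
        return layer_cls(channels, channels, kernel_size=3, stride=2, padding=1)

    conv = build_downsample(nn.Conv2d, 32, dtype=torch.float16)
    print(conv.weight.dtype)  # torch.float16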