@@ -33,22 +33,34 @@ def __init__(
             value_dim: int = 64,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
+            device=None,
+            dtype=None,
     ):
         """Initializer."""
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         dim_out = dim_out or dim
         self.num_heads = num_heads
         self.key_dim = key_dim
         self.value_dim = value_dim
         self.scale = key_dim ** -0.5
 
-        self.query_proj = nn.Parameter(torch.randn([self.num_heads, self.key_dim, dim]))
-        self.key_proj = nn.Parameter(torch.randn([dim, self.key_dim]))
-        self.value_proj = nn.Parameter(torch.randn([dim, self.value_dim]))
+        self.query_proj = nn.Parameter(torch.empty((self.num_heads, self.key_dim, dim), **dd))
+        self.key_proj = nn.Parameter(torch.empty((dim, self.key_dim), **dd))
+        self.value_proj = nn.Parameter(torch.empty((dim, self.value_dim), **dd))
         self.attn_drop = nn.Dropout(attn_drop)
-        self.out_proj = nn.Parameter(torch.randn([dim_out, self.num_heads, self.value_dim]))
+        self.out_proj = nn.Parameter(torch.empty((dim_out, self.num_heads, self.value_dim), **dd))
         self.proj_drop = nn.Dropout(proj_drop)
 
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        scale = self.key_proj.shape[0] ** -0.5
+        nn.init.normal_(self.query_proj, std=scale)
+        nn.init.normal_(self.key_proj, std=scale)
+        nn.init.normal_(self.value_proj, std=scale)
+        nn.init.normal_(self.out_proj, std=self.out_proj.shape[0] ** -0.5)
+
     def _reshape_input(self, t):
         """Reshapes a tensor to three dimensions, keeping the first and last."""
         s = t.shape
@@ -108,6 +120,8 @@ def __init__(
             proj_drop: float = 0.,
             norm_layer: Type[nn.Module] = nn.BatchNorm2d,
             use_bias: bool = False,
+            device=None,
+            dtype=None,
     ):
         """Initializer.
 
@@ -119,6 +133,7 @@ def __init__(
             kv_stride: Key and value stride size.
             dw_kernel_size: Spatial dimension of the depthwise kernel.
         """
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         dim_out = dim_out or dim
         self.num_heads = num_heads
@@ -149,6 +164,7 @@ def __init__(
             self.num_heads * self.key_dim,
             kernel_size=1,
             bias=use_bias,
+            **dd,
         ))
 
         self.key = nn.Sequential()
@@ -161,6 +177,7 @@ def __init__(
                 dilation=dilation,
                 padding=padding,
                 depthwise=True,
+                **dd,
             ))
             self.key.add_module('norm', norm_layer(dim))
         self.key.add_module('proj', create_conv2d(
@@ -169,6 +186,7 @@ def __init__(
             kernel_size=1,
             padding=padding,
             bias=use_bias,
+            **dd,
         ))
 
         self.value = nn.Sequential()
@@ -181,29 +199,37 @@ def __init__(
                 dilation=dilation,
                 padding=padding,
                 depthwise=True,
+                **dd,
             ))
             self.value.add_module('norm', norm_layer(dim))
         self.value.add_module('proj', create_conv2d(
             dim,
             self.value_dim,
             kernel_size=1,
             bias=use_bias,
+            **dd,
         ))
 
         self.attn_drop = nn.Dropout(attn_drop)
 
         self.output = nn.Sequential()
         if self.has_query_strides:
-            self.output.add_module('upsample', nn.Upsample(scale_factor=self.query_strides, mode='bilinear', align_corners=False))
+            self.output.add_module('upsample', nn.Upsample(
+                scale_factor=self.query_strides,
+                mode='bilinear',
+                align_corners=False
+            ))
         self.output.add_module('proj', create_conv2d(
             self.value_dim * self.num_heads,
             dim_out,
             kernel_size=1,
             bias=use_bias,
+            **dd,
         ))
-        self.output.add_module('drop',  nn.Dropout(proj_drop))
+        self.output.add_module('drop', nn.Dropout(proj_drop))
 
         self.einsum = False
+        self.init_weights()
 
     def init_weights(self):
         # using xavier appeared to improve stability for mobilenetv4 hybrid w/ this layer
@@ -304,8 +330,11 @@ def __init__(
             expand_first: bool = False,
             head_first: bool = False,
             attn_drop: float = 0.,
-            proj_drop: float = 0.
+            proj_drop: float = 0.,
+            device=None,
+            dtype=None,
     ):
+        dd = {'device': device, 'dtype': dtype}
         super().__init__()
         dim_out = dim_out or dim
         dim_attn = dim_out if expand_first else dim
@@ -314,9 +343,9 @@ def __init__(
         self.head_first = head_first
         self.fused_attn = use_fused_attn()
 
-        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
+        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias, **dd)
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias)
+        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias, **dd)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
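
Every hunk above applies the same factory-kwargs pattern: parameters and submodules are allocated with an explicit device/dtype pair (collected in the `dd` dict), and random initialization is moved out of allocation into `reset_parameters()` / `init_weights()`, so tensors can be created directly where they will live. The following is a minimal, self-contained sketch of that pattern on a hypothetical `TinyProj` module (not part of this diff), assuming only standard `torch`/`torch.nn` APIs:

import torch
import torch.nn as nn


class TinyProj(nn.Module):
    """Hypothetical module illustrating the device/dtype factory-kwargs pattern."""

    def __init__(self, dim: int, dim_out: int, device=None, dtype=None):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        # Allocate uninitialized storage on the requested device/dtype,
        # then fill it in reset_parameters() (mirrors the torch.empty + init split above).
        self.weight = nn.Parameter(torch.empty((dim_out, dim), **dd))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight, std=self.weight.shape[1] ** -0.5)


# Construct directly on the desired device/dtype, no post-hoc .to() pass needed.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
m = TinyProj(64, 128, device=device, dtype=torch.float32)
print(m.weight.device, m.weight.dtype)

Using torch.empty plus an explicit reset step, rather than torch.randn at construction time, also avoids initializing weights twice when a checkpoint is loaded immediately afterwards and keeps the door open for deferred (meta-device) initialization.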