 import itertools
 from copy import deepcopy
 from functools import partial
+from typing import Dict

 import torch
 import torch.nn as nn
@@ -255,6 +256,8 @@ def forward(self, x):


 class Attention(nn.Module):
+    ab: Dict[str, torch.Tensor]
+
     def __init__(
             self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
         super().__init__()
@@ -286,20 +289,31 @@ def __init__(
                 idxs.append(attention_offsets[offset])
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
-        self.ab = None
+        self.ab = {}

     @torch.no_grad()
     def train(self, mode=True):
         super().train(mode)
-        self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
+        if mode and self.ab:
+            self.ab = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.ab:
+                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.ab[device_key]

     def forward(self, x):  # x (B,C,H,W)
         if self.use_conv:
             B, C, H, W = x.shape
             q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = (q.transpose(-2, -1) @ k) * self.scale + ab
+
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
+
             x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
         else:
             B, N, C = x.shape
@@ -308,15 +322,18 @@ def forward(self, x): # x (B,C,H,W)
             q = q.permute(0, 2, 1, 3)
             k = k.permute(0, 2, 1, 3)
             v = v.permute(0, 2, 1, 3)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = q @ k.transpose(-2, -1) * self.scale + ab
+
+            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
+
             x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
         x = self.proj(x)
         return x


 class AttentionSubsample(nn.Module):
+    ab: Dict[str, torch.Tensor]
+
     def __init__(
             self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
             act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
@@ -366,21 +383,30 @@ def __init__(
                 idxs.append(attention_offsets[offset])
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
-        self.ab = None
+        self.ab = {}  # per-device attention_biases cache

     @torch.no_grad()
     def train(self, mode=True):
         super().train(mode)
-        self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
+        if mode and self.ab:
+            self.ab = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.ab:
+                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.ab[device_key]

     def forward(self, x):
         if self.use_conv:
             B, C, H, W = x.shape
             k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
             q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)

-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = (q.transpose(-2, -1) @ k) * self.scale + ab
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)

             x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
@@ -391,8 +417,7 @@ def forward(self, x):
             v = v.permute(0, 2, 1, 3)  # BHNC
             q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)

-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = q @ k.transpose(-2, -1) * self.scale + ab
+            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)

             x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
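
For illustration only, the following is a minimal, self-contained sketch of the caching pattern this commit introduces. The module name BiasCacheDemo and the bias/index sizes are made up; only the ab dict, the train() override, and get_attention_biases() mirror the diff above. In eval mode the gathered bias tensor is cached once per device string; switching back to train mode clears the cache so training always reads fresh parameter values.

from typing import Dict

import torch
import torch.nn as nn


class BiasCacheDemo(nn.Module):
    # hypothetical demo module; only the ab/train/get_attention_biases logic mirrors the commit
    ab: Dict[str, torch.Tensor]  # per-device cache of gathered biases (used in eval only)

    def __init__(self, num_heads=2, num_offsets=4, n=3):
        super().__init__()
        # learnable biases (num_heads, num_offsets) and an (N, N) index buffer; sizes are arbitrary
        self.attention_biases = nn.Parameter(torch.zeros(num_heads, num_offsets))
        self.register_buffer('attention_bias_idxs', torch.randint(0, num_offsets, (n, n)))
        self.ab = {}

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and self.ab:
            self.ab = {}  # cached biases would go stale while training keeps updating them

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        if self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        device_key = str(device)
        if device_key not in self.ab:
            self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
        return self.ab[device_key]


m = BiasCacheDemo()
m.eval()
b1 = m.get_attention_biases(torch.device('cpu'))
b2 = m.get_attention_biases(torch.device('cpu'))
assert b1 is b2   # eval reuses the cached (num_heads, N, N) tensor
m.train()
assert not m.ab   # returning to train mode drops the cache

Keying the cache by str(device) keeps one gathered tensor per device, so a module replicated across devices never hands back a cached tensor that lives on the wrong one.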