Commit 11ae795

Redo LeViT attention bias caching in a way that works with both torchscript and DataParallel
1 parent d400f1d

File tree

1 file changed: +37 -12


timm/models/levit.py

Lines changed: 37 additions & 12 deletions
@@ -26,6 +26,7 @@
 import itertools
 from copy import deepcopy
 from functools import partial
+from typing import Dict
 
 import torch
 import torch.nn as nn
@@ -255,6 +256,8 @@ def forward(self, x):
 
 
 class Attention(nn.Module):
+    ab: Dict[str, torch.Tensor]
+
     def __init__(
             self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
         super().__init__()
@@ -286,20 +289,31 @@ def __init__(
             idxs.append(attention_offsets[offset])
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
-        self.ab = None
+        self.ab = {}
 
     @torch.no_grad()
     def train(self, mode=True):
         super().train(mode)
-        self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
+        if mode and self.ab:
+            self.ab = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.ab:
+                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.ab[device_key]
 
     def forward(self, x):  # x (B,C,H,W)
         if self.use_conv:
             B, C, H, W = x.shape
             q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = (q.transpose(-2, -1) @ k) * self.scale + ab
+
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
+
             x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
         else:
             B, N, C = x.shape
@@ -308,15 +322,18 @@ def forward(self, x):  # x (B,C,H,W)
             q = q.permute(0, 2, 1, 3)
             k = k.permute(0, 2, 1, 3)
             v = v.permute(0, 2, 1, 3)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = q @ k.transpose(-2, -1) * self.scale + ab
+
+            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
+
             x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
         x = self.proj(x)
         return x
 
 
 class AttentionSubsample(nn.Module):
+    ab: Dict[str, torch.Tensor]
+
     def __init__(
             self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
             act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
@@ -366,21 +383,30 @@ def __init__(
             idxs.append(attention_offsets[offset])
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
-        self.ab = None
+        self.ab = {}  # per-device attention_biases cache
 
     @torch.no_grad()
     def train(self, mode=True):
         super().train(mode)
-        self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
+        if mode and self.ab:
+            self.ab = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.ab:
+                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.ab[device_key]
 
     def forward(self, x):
         if self.use_conv:
             B, C, H, W = x.shape
             k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
             q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
 
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = (q.transpose(-2, -1) @ k) * self.scale + ab
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
 
             x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
@@ -391,8 +417,7 @@ def forward(self, x):
             v = v.permute(0, 2, 1, 3)  # BHNC
             q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
 
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = q @ k.transpose(-2, -1) * self.scale + ab
+            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
 
             x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
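
Note: the change replaces the single cached tensor (formerly an Optional attribute toggled in train()) with a per-device dict behind a get_attention_biases helper. Below is a minimal standalone sketch of that pattern; the CachedBias class name, the toy index buffer, and the forward body are illustrative assumptions, not code from this commit.

import torch
import torch.nn as nn
from typing import Dict


class CachedBias(nn.Module):
    # TorchScript can only script a dict-valued module attribute if its type
    # is declared on the class, which is what the `ab: Dict[str, torch.Tensor]`
    # annotation in the diff above provides.
    ab: Dict[str, torch.Tensor]

    def __init__(self, num_heads: int = 4, num_points: int = 16):
        super().__init__()
        # Learnable biases plus a fixed index buffer, mirroring levit.py.
        self.attention_biases = nn.Parameter(torch.zeros(num_heads, num_points))
        self.register_buffer(
            'attention_bias_idxs',
            torch.randint(num_points, (num_points, num_points)))
        self.ab = {}  # per-device cache of the gathered bias tensor

    @torch.no_grad()
    def train(self, mode: bool = True):
        # Switching back to train mode invalidates the cache, since the
        # optimizer may change attention_biases between forward passes.
        super().train(mode)
        if mode and self.ab:
            self.ab = {}
        return self

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        if self.training:
            # Recompute every call so gradients reach attention_biases.
            return self.attention_biases[:, self.attention_bias_idxs]
        # In eval mode, cache one gathered tensor per device. Keying by
        # str(device) means a replica running on cuda:1 never reuses a
        # tensor cached for cuda:0, and the string key keeps the dict type
        # simple for TorchScript.
        device_key = str(device)
        if device_key not in self.ab:
            self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
        return self.ab[device_key]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, num_heads, num_points, num_points) attention logits.
        return x + self.get_attention_biases(x.device)

With this structure, both torch.jit.script(CachedBias().eval()) and nn.DataParallel(CachedBias().cuda()) should be usable, which is the combination the commit message targets: the class-level Dict annotation is what lets TorchScript type the attribute, and the str(device) key keeps each replica's cached tensor on its own GPU.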

0 commit comments