Skip to content

Commit df656fe

Browse files
committed
complete learnable memory ViT, for efficient fine-tuning and potentially plays into continual learning
1 parent 4e6a42a commit df656fe

File tree

5 files changed

+282
-2
lines changed

5 files changed

+282
-2
lines changed

README.md

+64
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
- [Patch Merger](#patch-merger)
2929
- [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
3030
- [Parallel ViT](#parallel-vit)
31+
- [Learnable Memory ViT](#learnable-memory-vit)
3132
- [Dino](#dino)
3233
- [Accessing Attention](#accessing-attention)
3334
- [Research Ideas](#research-ideas)
@@ -903,6 +904,61 @@ img = torch.randn(4, 3, 256, 256)
903904
preds = v(img) # (4, 1000)
904905
```
905906

907+
## Learnable Memory ViT
908+
909+
<img src="./images/learnable-memory-vit.png" width="350px"></img>
910+
911+
This <a href="https://arxiv.org/abs/2203.15243">paper</a> shows that adding learnable memory tokens at each layer of a vision transformer can greatly enhance fine-tuning results (in addition to learnable task specific CLS token and adapter head).
912+
913+
You can use this with a specially modified `ViT` as follows
914+
915+
```python
916+
import torch
917+
from vit_pytorch.learnable_memory_vit import ViT, Adapter
918+
919+
# normal base ViT
920+
921+
v = ViT(
922+
image_size = 256,
923+
patch_size = 16,
924+
num_classes = 1000,
925+
dim = 1024,
926+
depth = 6,
927+
heads = 8,
928+
mlp_dim = 2048,
929+
dropout = 0.1,
930+
emb_dropout = 0.1
931+
)
932+
933+
img = torch.randn(4, 3, 256, 256)
934+
logits = v(img) # (4, 1000)
935+
936+
# do your usual training with ViT
937+
# ...
938+
939+
940+
# then, to finetune, just pass the ViT into the Adapter class
941+
# you can do this for multiple Adapters, as shown below
942+
943+
adapter1 = Adapter(
944+
vit = v,
945+
num_classes = 2, # number of output classes for this specific task
946+
num_memories_per_layer = 5 # number of learnable memories per layer, 10 was sufficient in paper
947+
)
948+
949+
logits1 = adapter1(img) # (4, 2) - predict 2 classes off frozen ViT backbone with learnable memories and task specific head
950+
951+
# yet another task to finetune on, this time with 4 classes
952+
953+
adapter2 = Adapter(
954+
vit = v,
955+
num_classes = 4,
956+
num_memories_per_layer = 10
957+
)
958+
959+
logits2 = adapter2(img) # (4, 4) - predict 4 classes off frozen ViT backbone with learnable memories and task specific head
960+
961+
```
906962

907963
## Dino
908964

@@ -1442,6 +1498,14 @@ Coming from computer vision and new to transformers? Here are some resources tha
14421498
}
14431499
```
14441500

1501+
```bibtex
1502+
@inproceedings{Sandler2022FinetuningIT,
1503+
title = {Fine-tuning Image Transformers using Learnable Memory},
1504+
author = {Mark Sandler and Andrey Zhmoginov and Max Vladymyrov and Andrew Jackson},
1505+
year = {2022}
1506+
}
1507+
```
1508+
14451509
```bibtex
14461510
@misc{vaswani2017attention,
14471511
title = {Attention Is All You Need},

images/learnable-memory-vit.png

108 KB
Loading

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vit-pytorch',
55
packages = find_packages(exclude=['examples']),
6-
version = '0.30.0',
6+
version = '0.31.0',
77
license='MIT',
88
description = 'Vision Transformer (ViT) - Pytorch',
99
author = 'Phil Wang',

vit_pytorch/learnable_memory_vit.py

+216
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
import torch
2+
from torch import nn
3+
import torch.nn.functional as F
4+
5+
from einops import rearrange, repeat
6+
from einops.layers.torch import Rearrange
7+
8+
# helpers
9+
10+
def exists(val):
11+
return val is not None
12+
13+
def pair(t):
14+
return t if isinstance(t, tuple) else (t, t)
15+
16+
# controlling freezing of layers
17+
18+
def set_module_requires_grad_(module, requires_grad):
19+
for param in module.parameters():
20+
param.requires_grad = requires_grad
21+
22+
def freeze_all_layers_(module):
23+
set_module_requires_grad_(module, False)
24+
25+
def unfreeze_all_layers_(module):
26+
set_module_requires_grad_(module, True)
27+
28+
# classes
29+
30+
class FeedForward(nn.Module):
31+
def __init__(self, dim, hidden_dim, dropout = 0.):
32+
super().__init__()
33+
self.net = nn.Sequential(
34+
nn.LayerNorm(dim),
35+
nn.Linear(dim, hidden_dim),
36+
nn.GELU(),
37+
nn.Dropout(dropout),
38+
nn.Linear(hidden_dim, dim),
39+
nn.Dropout(dropout)
40+
)
41+
def forward(self, x):
42+
return self.net(x)
43+
44+
class Attention(nn.Module):
45+
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
46+
super().__init__()
47+
inner_dim = dim_head * heads
48+
49+
self.heads = heads
50+
self.scale = dim_head ** -0.5
51+
self.norm = nn.LayerNorm(dim)
52+
53+
self.attend = nn.Softmax(dim = -1)
54+
self.dropout = nn.Dropout(dropout)
55+
56+
self.to_q = nn.Linear(dim, inner_dim, bias = False)
57+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
58+
59+
self.to_out = nn.Sequential(
60+
nn.Linear(inner_dim, dim),
61+
nn.Dropout(dropout)
62+
)
63+
64+
def forward(self, x, attn_mask = None, memories = None):
65+
x = self.norm(x)
66+
67+
x_kv = x # input for key / values projection
68+
69+
if exists(memories):
70+
# add memories to key / values if it is passed in
71+
memories = repeat(memories, 'n d -> b n d', b = x.shape[0]) if memories.ndim == 2 else memories
72+
x_kv = torch.cat((x_kv, memories), dim = 1)
73+
74+
qkv = (self.to_q(x), *self.to_kv(x_kv).chunk(2, dim = -1))
75+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
76+
77+
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
78+
79+
if exists(attn_mask):
80+
dots = dots.masked_fill(~attn_mask, -torch.finfo(dots.dtype).max)
81+
82+
attn = self.attend(dots)
83+
attn = self.dropout(attn)
84+
85+
out = torch.matmul(attn, v)
86+
out = rearrange(out, 'b h n d -> b n (h d)')
87+
return self.to_out(out)
88+
89+
class Transformer(nn.Module):
90+
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
91+
super().__init__()
92+
self.layers = nn.ModuleList([])
93+
for _ in range(depth):
94+
self.layers.append(nn.ModuleList([
95+
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
96+
FeedForward(dim, mlp_dim, dropout = dropout)
97+
]))
98+
99+
def forward(self, x, attn_mask = None, memories = None):
100+
for ind, (attn, ff) in enumerate(self.layers):
101+
layer_memories = memories[ind] if exists(memories) else None
102+
103+
x = attn(x, attn_mask = attn_mask, memories = layer_memories) + x
104+
x = ff(x) + x
105+
return x
106+
107+
class ViT(nn.Module):
108+
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
109+
super().__init__()
110+
image_height, image_width = pair(image_size)
111+
patch_height, patch_width = pair(patch_size)
112+
113+
assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
114+
115+
num_patches = (image_height // patch_height) * (image_width // patch_width)
116+
patch_dim = channels * patch_height * patch_width
117+
assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
118+
119+
self.to_patch_embedding = nn.Sequential(
120+
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
121+
nn.Linear(patch_dim, dim),
122+
)
123+
124+
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
125+
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
126+
self.dropout = nn.Dropout(emb_dropout)
127+
128+
self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
129+
130+
self.mlp_head = nn.Sequential(
131+
nn.LayerNorm(dim),
132+
nn.Linear(dim, num_classes)
133+
)
134+
135+
def img_to_tokens(self, img):
136+
x = self.to_patch_embedding(img)
137+
138+
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = x.shape[0])
139+
x = torch.cat((cls_tokens, x), dim = 1)
140+
141+
x += self.pos_embedding
142+
x = self.dropout(x)
143+
return x
144+
145+
def forward(self, img):
146+
x = self.img_to_tokens(img)
147+
148+
x = self.transformer(x)
149+
150+
cls_tokens = x[:, 0]
151+
return self.mlp_head(cls_tokens)
152+
153+
# adapter with learnable memories per layer, memory CLS token, and learnable adapter head
154+
155+
class Adapter(nn.Module):
156+
def __init__(
157+
self,
158+
*,
159+
vit,
160+
num_memories_per_layer = 10,
161+
num_classes = 2,
162+
):
163+
super().__init__()
164+
assert isinstance(vit, ViT)
165+
166+
# extract some model variables needed
167+
168+
dim = vit.cls_token.shape[-1]
169+
layers = len(vit.transformer.layers)
170+
num_patches = vit.pos_embedding.shape[-2]
171+
172+
self.vit = vit
173+
174+
# freeze ViT backbone - only memories will be finetuned
175+
176+
freeze_all_layers_(vit)
177+
178+
# learnable parameters
179+
180+
self.memory_cls_token = nn.Parameter(torch.randn(dim))
181+
self.memories_per_layer = nn.Parameter(torch.randn(layers, num_memories_per_layer, dim))
182+
183+
self.mlp_head = nn.Sequential(
184+
nn.LayerNorm(dim),
185+
nn.Linear(dim, num_classes)
186+
)
187+
188+
# specialized attention mask to preserve the output of the original ViT
189+
# it allows the memory CLS token to attend to all other tokens (and the learnable memory layer tokens), but not vice versa
190+
191+
attn_mask = torch.ones((num_patches, num_patches), dtype = torch.bool)
192+
attn_mask = F.pad(attn_mask, (1, num_memories_per_layer), value = False) # main tokens cannot attend to learnable memories per layer
193+
attn_mask = F.pad(attn_mask, (0, 0, 1, 0), value = True) # memory CLS token can attend to everything
194+
self.register_buffer('attn_mask', attn_mask)
195+
196+
def forward(self, img):
197+
b = img.shape[0]
198+
199+
tokens = self.vit.img_to_tokens(img)
200+
201+
# add task specific memory tokens
202+
203+
memory_cls_tokens = repeat(self.memory_cls_token, 'd -> b 1 d', b = b)
204+
tokens = torch.cat((memory_cls_tokens, tokens), dim = 1)
205+
206+
# pass memories along with image tokens through transformer for attending
207+
208+
out = self.vit.transformer(tokens, memories = self.memories_per_layer, attn_mask = self.attn_mask)
209+
210+
# extract memory CLS tokens
211+
212+
memory_cls_tokens = out[:, 0]
213+
214+
# pass through task specific adapter head
215+
216+
return self.mlp_head(memory_cls_tokens)

vit_pytorch/vit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def forward(self, img):
114114
x = self.to_patch_embedding(img)
115115
b, n, _ = x.shape
116116

117-
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
117+
cls_tokens = repeat(self.cls_token, '1 n d -> b n d', b = b)
118118
x = torch.cat((cls_tokens, x), dim=1)
119119
x += self.pos_embedding[:, :(n + 1)]
120120
x = self.dropout(x)

0 commit comments

Comments
 (0)