Commit 6ecc3f4

hybrid cnn vision transformer
1 parent ba22024 commit 6ecc3f4

File tree

3 files changed: +56 -59 lines changed

.gitignore (+2)

@@ -128,5 +128,7 @@ dmypy.json
 # Pyre type checker
 .pyre/
+notebooks/
+.ipynb_checkpoints/
 dataset/data
 wandb/

models.py (+51 -59)

@@ -4,86 +4,78 @@
 
 from x_transformers import *
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+from timm.models.vision_transformer import VisionTransformer
+from timm.models.resnetv2 import ResNetV2
+from timm.models.layers import StdConv2dSame
 from einops import rearrange, repeat
 
 
-class ViTransformerWrapper(nn.Module):
-    def __init__(
-        self,
-        *,
-        max_width,
-        max_height,
-        patch_size,
-        attn_layers,
-        channels=1,
-        num_classes=None,
-        dropout=0.,
-        emb_dropout=0.
-    ):
+class Model(nn.Module):
+    def __init__(self, encoder: Encoder, decoder: AutoregressiveWrapper, args):
         super().__init__()
-        assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
-        assert max_width % patch_size == 0 and max_height % patch_size == 0, 'image dimensions must be divisible by the patch size'
-        dim = attn_layers.dim
-        num_patches = (max_width // patch_size)*(max_height // patch_size)
-        patch_dim = channels * patch_size ** 2
-
-        self.patch_size = patch_size
-        self.max_width = max_width
-        self.max_height = max_height
+        self.encoder = encoder
+        self.decoder = decoder
+        self.args = args
 
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.patch_to_embedding = nn.Linear(patch_dim, dim)
-        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.dropout = nn.Dropout(emb_dropout)
+    def forward(self, x: torch.Tensor):
+        return self.decoder.generate(torch.LongTensor([self.args.bos_token]*len(x)).to(x.device), self.args.max_seq_len, eos_token=self.args.eos_token, context=self.encoder(x))
 
-        self.attn_layers = attn_layers
-        self.norm = nn.LayerNorm(dim)
-        #self.mlp_head = FeedForward(dim, dim_out = num_classes, dropout = dropout) if exists(num_classes) else None
 
-    def forward(self, img, **kwargs):
-        p = self.patch_size
+class CustomVisionTransformer(VisionTransformer):
+    def __init__(self, img_size=224, *args, **kwargs):
+        super(CustomVisionTransformer, self).__init__(img_size=img_size, *args, **kwargs)
+        self.height, self.width = img_size
+        self.patch_size = 16
 
-        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
-        x = self.patch_to_embedding(x)
-        b, n, _ = x.shape
+    def forward_features(self, x):
+        B, c, h, w = x.shape
+        x = self.patch_embed(x)
 
-        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
+        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
         x = torch.cat((cls_tokens, x), dim=1)
-        h, w = torch.tensor(img.shape[2:])//p
-        pos_emb_ind = repeat(torch.arange(h)*(self.max_width//p-w), 'h -> (h w)', w=w)+torch.arange(h*w)
+        h, w = h//self.patch_size, w//self.patch_size
+        pos_emb_ind = repeat(torch.arange(h)*(self.width//self.patch_size-w), 'h -> (h w)', w=w)+torch.arange(h*w)
         pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
-        x += self.pos_embedding[:, pos_emb_ind]
-        x = self.dropout(x)
+        x += self.pos_embed[:, pos_emb_ind]
+        #x = x + self.pos_embed
+        x = self.pos_drop(x)
 
-        x = self.attn_layers(x, **kwargs)
-        x = self.norm(x)
+        for blk in self.blocks:
+            x = blk(x)
 
+        x = self.norm(x)
         return x
 
 
-class Model(nn.Module):
-    def __init__(self, encoder: Encoder, decoder: AutoregressiveWrapper, args):
+class CNNBackbone(nn.Module):
+    def __init__(self, feature_dim=512, channels=1, out_dim=128, depth=5, kernel_size=3, stride=1, padding=1, **kwargs):
         super().__init__()
-        self.encoder = encoder
-        self.decoder = decoder
-        self.args = args
+        dims = [channels]+[feature_dim]*(depth-1)+[out_dim]
+        layers = []
+        for i in range(depth):
+            layers.append(nn.Conv2d(dims[i], dims[i+1], kernel_size=kernel_size, stride=stride, padding=padding))
+            layers.append(nn.ReLU())
 
-    def forward(self, x: torch.Tensor):
-        return self.decoder.generate(torch.LongTensor([self.args.bos_token]*len(x)).to(x.device), self.args.max_seq_len, eos_token=self.args.eos_token, context=self.encoder(x))
+        self.model = nn.Sequential(*layers[:-1])
+
+    def forward(self, x):
+        return self.model(x)
 
 
 def get_model(args):
-    encoder = ViTransformerWrapper(
-        max_width=args.max_width,
-        max_height=args.max_height,
-        channels=args.channels,
-        patch_size=args.patch_size,
-        attn_layers=Encoder(
-            dim=args.dim,
-            depth=args.num_layers,
-            heads=args.heads,
-        )
-    ).to(args.device)
+    #backbone = CNNBackbone(args.backbone_dim, depth=args.backbone_depth, channels=args.channels)
+    backbone = ResNetV2(
+        layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=args.channels,
+        preact=False, stem_type='same', conv_layer=StdConv2dSame)
+    encoder = CustomVisionTransformer(img_size=(args.max_height, args.max_width),
+                                      patch_size=args.patch_size,
+                                      in_chans=args.channels,
+                                      num_classes=0,
+                                      embed_dim=args.dim,
+                                      depth=args.encoder_depth,
+                                      num_heads=args.heads,
+                                      hybrid_backbone=backbone
+                                      ).to(args.device)
 
     decoder = AutoregressiveWrapper(
         TransformerWrapper(
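The notable piece of forward_features above is the positional-embedding indexing: instead of adding self.pos_embed wholesale, it selects the rows of the full-size embedding table that correspond to the actual patch grid, so inputs of varying height and width share one table, with index 0 reserved for the CLS token. A minimal sketch of the index arithmetic, using hypothetical toy sizes (a 2x2-patch image inside a grid that is 4 patches wide):

import torch
from einops import repeat

# Toy sizes for illustration only: max grid 4 patches wide,
# current image 2 patches high and 2 patches wide.
max_grid_w = 4
h, w = 2, 2

# Row i of the image starts at index i*max_grid_w in the row-major
# full grid, i.e. each row skips (max_grid_w - w) unused slots.
pos_emb_ind = repeat(torch.arange(h) * (max_grid_w - w), 'h -> (h w)', w=w) + torch.arange(h * w)
print(pos_emb_ind)  # tensor([0, 1, 4, 5])

# Shift by one and prepend 0 so the CLS token keeps embedding row 0.
pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()
print(pos_emb_ind)  # tensor([0, 1, 2, 5, 6])

Passing hybrid_backbone=backbone to timm's VisionTransformer is what makes this a hybrid encoder: the ResNetV2 stem runs first and its feature map, rather than raw pixels, is what gets patch-embedded.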

settings/default.yaml (+3)

@@ -21,6 +21,9 @@ channels: 1
 patch_size: 32
 # Encoder / Decoder
 dim: 128
+backbone_dim: 256
+backbone_depth: 4
+encoder_depth: 4
 num_layers: 4
 heads: 8
 num_tokens: 8000
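The new keys line up with get_model: backbone_dim and backbone_depth feed the (currently commented-out) CNNBackbone, while encoder_depth sets the depth of the CustomVisionTransformer, leaving num_layers for the decoder. As a sketch of how such a settings file could become the args namespace that get_model expects (the loader below is an assumption for illustration, not necessarily what this repo actually uses):

import yaml
from argparse import Namespace

# Hypothetical loader: read the YAML into an attribute-style namespace
# so fields like args.encoder_depth and args.heads resolve.
with open('settings/default.yaml') as f:
    args = Namespace(**yaml.safe_load(f))

# model = get_model(args)  # assumes the remaining keys (device, max_height, ...) are also set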
