@@ -4,86 +4,78 @@
 
 from x_transformers import *
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+from timm.models.vision_transformer import VisionTransformer
+from timm.models.resnetv2 import ResNetV2
+from timm.models.layers import StdConv2dSame
 from einops import rearrange, repeat
 
 
-class ViTransformerWrapper(nn.Module):
-    def __init__(
-            self,
-            *,
-            max_width,
-            max_height,
-            patch_size,
-            attn_layers,
-            channels=1,
-            num_classes=None,
-            dropout=0.,
-            emb_dropout=0.
-    ):
+class Model(nn.Module):
+    def __init__(self, encoder: Encoder, decoder: AutoregressiveWrapper, args):
         super().__init__()
-        assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
-        assert max_width % patch_size == 0 and max_height % patch_size == 0, 'image dimensions must be divisible by the patch size'
-        dim = attn_layers.dim
-        num_patches = (max_width // patch_size)*(max_height // patch_size)
-        patch_dim = channels * patch_size ** 2
-
-        self.patch_size = patch_size
-        self.max_width = max_width
-        self.max_height = max_height
+        self.encoder = encoder
+        self.decoder = decoder
+        self.args = args
 
-        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.patch_to_embedding = nn.Linear(patch_dim, dim)
-        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.dropout = nn.Dropout(emb_dropout)
+    def forward(self, x: torch.Tensor):
+        return self.decoder.generate(torch.LongTensor([self.args.bos_token]*len(x)).to(x.device), self.args.max_seq_len, eos_token=self.args.eos_token, context=self.encoder(x))
 
-        self.attn_layers = attn_layers
-        self.norm = nn.LayerNorm(dim)
-        #self.mlp_head = FeedForward(dim, dim_out = num_classes, dropout = dropout) if exists(num_classes) else None
 
-    def forward(self, img, **kwargs):
-        p = self.patch_size
+class CustomVisionTransformer(VisionTransformer):
+    def __init__(self, img_size=224, *args, **kwargs):
+        super(CustomVisionTransformer, self).__init__(img_size=img_size, *args, **kwargs)
+        self.height, self.width = img_size
+        self.patch_size = 16
 
-        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
-        x = self.patch_to_embedding(x)
-        b, n, _ = x.shape
+    def forward_features(self, x):
+        B, c, h, w = x.shape
+        x = self.patch_embed(x)
 
-        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
+        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
         x = torch.cat((cls_tokens, x), dim=1)
-        h, w = torch.tensor(img.shape[2:])//p
-        pos_emb_ind = repeat(torch.arange(h)*(self.max_width//p-w), 'h -> (h w)', w=w)+torch.arange(h*w)
+        h, w = h//self.patch_size, w//self.patch_size
+        pos_emb_ind = repeat(torch.arange(h)*(self.width//self.patch_size-w), 'h -> (h w)', w=w)+torch.arange(h*w)
         pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
-        x += self.pos_embedding[:, pos_emb_ind]
-        x = self.dropout(x)
+        x += self.pos_embed[:, pos_emb_ind]
+        #x = x + self.pos_embed
+        x = self.pos_drop(x)
 
-        x = self.attn_layers(x, **kwargs)
-        x = self.norm(x)
+        for blk in self.blocks:
+            x = blk(x)
 
+        x = self.norm(x)
         return x
 
 
-class Model(nn.Module):
-    def __init__(self, encoder: Encoder, decoder: AutoregressiveWrapper, args):
+class CNNBackbone(nn.Module):
+    def __init__(self, feature_dim=512, channels=1, out_dim=128, depth=5, kernel_size=3, stride=1, padding=1, **kwargs):
         super().__init__()
-        self.encoder = encoder
-        self.decoder = decoder
-        self.args = args
+        dims = [channels]+[feature_dim]*(depth-1)+[out_dim]
+        layers = []
+        for i in range(depth):
+            layers.append(nn.Conv2d(dims[i], dims[i+1], kernel_size=kernel_size, stride=stride, padding=padding))
+            layers.append(nn.ReLU())
 
-    def forward(self, x: torch.Tensor):
-        return self.decoder.generate(torch.LongTensor([self.args.bos_token]*len(x)).to(x.device), self.args.max_seq_len, eos_token=self.args.eos_token, context=self.encoder(x))
+        self.model = nn.Sequential(*layers[:-1])
+
+    def forward(self, x):
+        return self.model(x)
 
 
 def get_model(args):
-    encoder = ViTransformerWrapper(
-        max_width=args.max_width,
-        max_height=args.max_height,
-        channels=args.channels,
-        patch_size=args.patch_size,
-        attn_layers=Encoder(
-            dim=args.dim,
-            depth=args.num_layers,
-            heads=args.heads,
-        )
-    ).to(args.device)
+    #backbone = CNNBackbone(args.backbone_dim, depth=args.backbone_depth, channels=args.channels)
+    backbone = ResNetV2(
+        layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=args.channels,
+        preact=False, stem_type='same', conv_layer=StdConv2dSame)
+    encoder = CustomVisionTransformer(img_size=(args.max_height, args.max_width),
+                                      patch_size=args.patch_size,
+                                      in_chans=args.channels,
+                                      num_classes=0,
+                                      embed_dim=args.dim,
+                                      depth=args.encoder_depth,
+                                      num_heads=args.heads,
+                                      hybrid_backbone=backbone
+                                      ).to(args.device)
 
     decoder = AutoregressiveWrapper(
         TransformerWrapper(
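The most interesting piece of the new encoder is how CustomVisionTransformer.forward_features re-indexes the learned positional embedding: an input smaller than the maximum canvas still picks the entries that correspond to its patches' coordinates in the full (max_height//patch_size) x (max_width//patch_size) grid, with index 0 reserved for the CLS token. Below is a minimal standalone sketch of just that index computation; the concrete numbers (max_width=672, patch_size=16, a 4x10 patch grid) are illustrative assumptions, not values taken from this commit.

import torch
from einops import repeat

max_width, patch_size = 672, 16
grid_cols = max_width // patch_size            # columns of the full positional grid

h, w = 4, 10                                   # patch grid of the current, smaller image
# row-major position of patch (i, j) in the full grid is i*grid_cols + j;
# arange(h*w) supplies i*w + j, and the repeated term adds the missing i*(grid_cols - w)
pos_emb_ind = repeat(torch.arange(h)*(grid_cols - w), 'h -> (h w)', w=w) + torch.arange(h*w)
# shift by one and prepend index 0 so the CLS token keeps its own embedding slot
pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()

assert pos_emb_ind[1] == 1 and pos_emb_ind[w + 1] == grid_cols + 1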
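The CNNBackbone introduced here (left commented out in get_model in favour of the ResNetV2 stem) is a plain stack of `depth` Conv2d+ReLU pairs with 3x3 kernels, stride 1, and padding 1, with the trailing ReLU dropped by the layers[:-1] slice, so it only changes the channel count and leaves the spatial size untouched. A quick shape check under that reading, assuming CNNBackbone is importable from the edited module; the input size is an arbitrary example.

import torch

# hypothetical usage; defaults taken from the constructor signature in the diff above
backbone = CNNBackbone(feature_dim=512, channels=1, out_dim=128, depth=5)
feats = backbone(torch.randn(2, 1, 64, 256))   # batch of 2 single-channel images
print(feats.shape)  # torch.Size([2, 128, 64, 256]): stride-1, padding-1 3x3 convs keep H and W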