import torch
import torch.nn as nn

from x_transformers import Encoder
from einops import rearrange, repeat


class ViTransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        max_width,
        max_height,
        patch_size,
        attn_layers,
        channels=1,
        num_classes=None,
        dropout=0.,
        emb_dropout=0.
    ):
        super().__init__()
        assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
        assert max_width % patch_size == 0 and max_height % patch_size == 0, 'image dimensions must be divisible by the patch size'
        dim = attn_layers.dim
        num_patches = (max_width // patch_size) * (max_height // patch_size)
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size
        self.max_width = max_width
        self.max_height = max_height

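        # learned positional-embedding table sized for the largest supported patch grid
        # (max_width/patch_size x max_height/patch_size) plus one slot for the CLS token;
        # smaller inputs index a subset of this table in forward()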
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.attn_layers = attn_layers
        self.norm = nn.LayerNorm(dim)
        # self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None

    def forward(self, img, **kwargs):
        p = self.patch_size

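        # split the image into non-overlapping p x p patches, flatten each patch
        # into a vector of length channels * p * p, then project it to dim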
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

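        # prepend a learnable CLS token to every sequence in the batch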
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
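        # build positional-embedding indices that respect the 2D patch layout:
        # a patch at (row, col) receives the embedding it would get at the same
        # position in the full max_width-wide grid (index row * (max_width // p) + col),
        # shifted by one so that index 0 stays reserved for the CLS token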
        h, w = torch.tensor(img.shape[2:]) // p
        pos_emb_ind = repeat(torch.arange(h)*(self.max_width//p-w), 'h -> (h w)', w=w) + torch.arange(h*w)
        pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()
        x += self.pos_embedding[:, pos_emb_ind]
        x = self.dropout(x)

        x = self.attn_layers(x, **kwargs)
        x = self.norm(x)

        return x


def get_encoder(args):
    return ViTransformerWrapper(
        max_width=args.max_width,
        max_height=args.max_height,
        channels=args.channels,
        patch_size=args.patch_size,
        emb_dropout=args.get('emb_dropout', 0),
        attn_layers=Encoder(
            dim=args.dim,
            depth=args.num_layers,
            heads=args.heads,
        )
    )
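

# --- usage sketch (illustrative, not part of the original module) ---
# get_encoder expects `args` to support both attribute access (args.dim) and
# dict-style .get(); the small _Args stand-in below mimics that. All values here
# are hypothetical placeholders, chosen only to satisfy the divisibility assert.
if __name__ == '__main__':
    class _Args(dict):
        __getattr__ = dict.__getitem__  # attribute access falls back to dict lookup

    args = _Args(max_width=672, max_height=192, channels=1, patch_size=16,
                 dim=256, num_layers=4, heads=8)
    encoder = get_encoder(args)

    img = torch.randn(2, 1, 64, 256)  # (batch, channels, height, width), both sides divisible by patch_size
    out = encoder(img)
    print(out.shape)  # torch.Size([2, 65, 256]): 1 CLS token + (64//16)*(256//16) patch tokens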