Commit 607517b (1 parent: 2f2a219): 1 changed file with 97 additions and 0 deletions.
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    """Feed-forward network used inside each transformer block:
    Linear -> activation -> Dropout -> Linear -> Dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
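
A minimal shape check for Mlp might look like the following; the batch, token, and embedding sizes are illustrative assumptions, not values taken from the commit.

# Usage sketch (assumed sizes): Mlp keeps the embedding dimension by default.
import torch

mlp = Mlp(in_features=768, hidden_features=3072, drop=0.1)
tokens = torch.randn(2, 197, 768)    # (batch, tokens, embed_dim)
out = mlp(tokens)
assert out.shape == (2, 197, 768)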


class Attention(nn.Module):
    """Multi-head self-attention. With return_relation=True, forward returns the
    softmaxed query-key attention map and a value-value "relation" map instead of
    the projected output."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, return_relation=False):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        if return_relation:
            return attn, ((v @ v.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
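
A small sketch of exercising both forward paths of Attention; the dimensions and variable names are assumptions for illustration.

# Usage sketch (assumed sizes): the relation path returns two (B, heads, N, N) maps.
import torch

attn = Attention(dim=768, num_heads=12)
tokens = torch.randn(2, 197, 768)            # (batch, tokens, embed_dim)
out = attn(tokens)                           # projected output: (2, 197, 768)
qk, vv = attn(tokens, return_relation=True)  # each map: (2, 12, 197, 197)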


class Block(nn.Module):
    """Pre-norm transformer block: multi-head self-attention followed by an MLP,
    each wrapped in a residual connection with optional stochastic depth (DropPath)."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_relation=False):
        if return_relation:
            qk, vv = self.attn(self.norm1(x), return_relation=True)
            return qk, vv
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        # assert H == self.img_size[0] and W == self.img_size[1], \
        #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
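
Putting the pieces together, here is a hedged end-to-end sketch; the image size, depth, and hyperparameters are assumptions for illustration, not values fixed by the commit.

# End-to-end sketch (assumed configuration): patchify an image, run it through a
# small stack of Blocks, then pull the relation maps from the last block.
import torch
import torch.nn as nn

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
blocks = nn.ModuleList([Block(dim=768, num_heads=12) for _ in range(4)])

images = torch.randn(2, 3, 224, 224)            # (batch, channels, height, width)
x = embed(images)                               # -> (2, 196, 768): 14 x 14 patch tokens
for blk in blocks[:-1]:
    x = blk(x)                                  # standard attention + MLP forward
qk, vv = blocks[-1](x, return_relation=True)    # relation maps: (2, 12, 196, 196) each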