
Add Video Vision Transformer implementation #62

Merged
merged 17 commits into from
Feb 1, 2022
Changes from 1 commit
added model2 of vivit (spatial-transformer and temporal transformer in series) with simple linear embedding.
abhi-glitchhg committed Jan 23, 2022
commit b7fdeb389199ed1e451151505bbcd477172db4ba
21 changes: 21 additions & 0 deletions tests/test_models.py
@@ -461,3 +461,24 @@ def test_dpt():
out = model(img)
assert out.shape == (4, 384, 384)
del model


def test_Vivit():
Suggested change
def test_Vivit():
def test_ViViT():

img = torch.randn([1, 16, 3, 224, 224])
from vformer.models.classification import ViViT

model = ViViT(
img_size=224,
in_channels=3,
patch_size=16,
embedding_dim=192,
depth=4,
num_heads=3,
head_dim=64,
num_frames=16,
n_classes=10,
)

# model = MODEL_REGISTRY.get("ViViT")(img_size=224,in_channels=3,patch_size=16,embedding_dim=192,depth=4,num_heads=3,head_dim=64,num_frames=1,n_classes=10)

out = model(img)
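
Note: the new test ends at the forward pass. A hedged continuation in the style of the neighbouring tests (e.g. test_dpt above), not part of this commit, would check the output shape and free the model:

    # hypothetical continuation mirroring the surrounding tests; not in this commit
    assert out.shape == (1, 10)  # one clip, n_classes=10
    del model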
4 changes: 2 additions & 2 deletions vformer/common/base_model.py
@@ -25,12 +25,12 @@ def __init__(self, img_size, patch_size, in_channels=3, pool="cls"):
img_height % patch_height == 0 and img_width % patch_width == 0
), "Image dimensions must be divisible by the patch size."

n_patches = (img_height // patch_height) * (img_width // patch_width)
num_patches = (img_height // patch_height) * (img_width // patch_width)
patch_dim = in_channels * patch_height * patch_width

self.patch_height = patch_height
self.patch_width = patch_width
self.n_patches = n_patches
self.num_patches = num_patches
self.patch_dim = patch_dim

assert pool in {
1 change: 1 addition & 0 deletions vformer/encoder/embedding/__init__.py
@@ -3,3 +3,4 @@
from .overlappatch import OverlapPatchEmbed
from .patch import PatchEmbedding
from .pos_embedding import *
from .video_patch_embeddings import *
8 changes: 3 additions & 5 deletions vformer/encoder/embedding/video_patch_embeddings.py
@@ -19,8 +19,6 @@ class LinearVideoEmbedding(nn.Module):
Height of the patch
patch_width: int
Width of the patch
patch_dim: int
Dimension of the patch

"""

@@ -31,8 +29,8 @@ def __init__(
patch_width,
patch_dim,
):
super().__init__()

super().__init__()
self.patch_embedding = nn.Sequential(
Rearrange(
"b t c (h ph) (w pw) -> b t (h w) (ph pw c)",
@@ -82,15 +80,15 @@ def __init__(self, embedding_dim, tubelet_t, tubelet_h, tubelet_w, in_channels):
tubelet_dim = in_channels * tubelet_h * tubelet_w * tubelet_t
self.tubelet_embedding = nn.Sequential(
Rearrange(
"b c (t pt) (h ph) (w pw) -> b t (h w) (pt ph pw c)",
"b (t pt) c (h ph) (w pw) -> b t (h w) (pt ph pw c)",
pt=tubelet_t,
ph=tubelet_h,
pw=tubelet_w,
),
nn.Linear(tubelet_dim, embedding_dim),
)

def forwar(self, x):
def forward(self, x):
"""

Parameters
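
Note: the corrected Rearrange pattern assumes the clip arrives as (batch, frames, channels, height, width). A minimal shape check of the tubelet folding, with illustrative tubelet sizes that are not taken from this PR:

import torch
from einops.layers.torch import Rearrange

# Fold a (b, t, c, h, w) clip into tubelets of size (pt, ph, pw); sizes are illustrative.
video = torch.randn(2, 16, 3, 224, 224)
to_tubelets = Rearrange(
    "b (t pt) c (h ph) (w pw) -> b t (h w) (pt ph pw c)",
    pt=2,
    ph=16,
    pw=16,
)
tokens = to_tubelets(video)
print(tokens.shape)  # torch.Size([2, 8, 196, 1536]); 1536 = 2 * 16 * 16 * 3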
1 change: 1 addition & 0 deletions vformer/models/classification/__init__.py
@@ -5,3 +5,4 @@
from .swin import SwinTransformer
from .vanilla import VanillaViT
from .visformer import *
from .vivit import ViViT
2 changes: 1 addition & 1 deletion vformer/models/classification/cross.py
@@ -24,7 +24,7 @@ def __init__(
)

self.pos_embedding = nn.Parameter(
torch.randn(1, self.n_patches + 1, latent_dim)
torch.randn(1, self.num_patches + 1, latent_dim)
)
self.cls_token = nn.Parameter(torch.randn(1, 1, latent_dim))
self.embedding_dropout = nn.Dropout(p_dropout_embedding)
2 changes: 1 addition & 1 deletion vformer/models/classification/vanilla.py
@@ -67,7 +67,7 @@ def __init__(
)

self.pos_embedding = PosEmbedding(
shape=self.n_patches + 1,
shape=self.num_patches + 1,
dim=embedding_dim,
drop=p_dropout_embedding,
sinusoidal=False,
111 changes: 111 additions & 0 deletions vformer/models/classification/vivit.py
@@ -0,0 +1,111 @@
import torch
import torch.nn as nn
from einops import rearrange, repeat

from ...common.base_model import BaseClassificationModel
from ...decoder.mlp import MLPDecoder
from ...encoder.embedding import LinearVideoEmbedding, PosEmbedding, TubeletEmbedding
from ...encoder.vanilla import VanillaEncoder
from ...utils.registry import MODEL_REGISTRY


@MODEL_REGISTRY.register()
class ViViT(BaseClassificationModel):
def __init__(
self,
img_size,
in_channels,
patch_size,
embedding_dim,
num_frames,
depth,
num_heads,
head_dim,
n_classes,
mlp_dim=None,
pool="cls",
p_dropout=0.0,
attn_dropout=0.0,
drop_path_rate=0.02,
):
super(ViViT, self).__init__(
img_size=img_size,
in_channels=in_channels,
patch_size=patch_size,
pool=pool,
)

patch_dim = in_channels * patch_size ** 2
self.patch_embedding = LinearVideoEmbedding(
embedding_dim=embedding_dim,
patch_height=patch_size,
patch_width=patch_size,
patch_dim=patch_dim,
)

self.pos_embedding = PosEmbedding(
shape=[num_frames, self.num_patches + 1], dim=embedding_dim, drop=p_dropout
)

self.space_token = nn.Parameter(
torch.randn(1, 1, embedding_dim)
) # this is similar to using cls token in vanilla vision transformer
self.spatial_transformer = VanillaEncoder(
embedding_dim=embedding_dim,
depth=depth,
num_heads=num_heads,
head_dim=head_dim,
mlp_dim=mlp_dim,
p_dropout=p_dropout,
attn_dropout=attn_dropout,
drop_path_rate=drop_path_rate,
)

self.time_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
self.temporal_transformer = VanillaEncoder(
embedding_dim=embedding_dim,
depth=depth,
num_heads=num_heads,
head_dim=head_dim,
mlp_dim=mlp_dim,
p_dropout=p_dropout,
attn_dropout=attn_dropout,
drop_path_rate=drop_path_rate,
)

self.decoder = MLPDecoder(
config=[
embedding_dim,
],
n_classes=n_classes,
)

def forward(self, x):

x = self.patch_embedding(x)

(
b,
t,
n,
d,
) = x.shape  # shape of x: (batch_size, num_frames, num_patches, embedding_dim)
cls_space_tokens = repeat(self.space_token, "() n d -> b t n d", b=b, t=t)

x = torch.cat((cls_space_tokens, x), dim=2)  # prepend the spatial cls token per frame
x = self.pos_embedding(x)

x = rearrange(x, "b t n d -> (b t) n d")
x = self.spatial_transformer(x)
x = rearrange(x[:, 0], "(b t) ... -> b t ...", b=b)

cls_temporal_tokens = repeat(self.time_token, "() n d -> b n d", b=b)
x = torch.cat((cls_temporal_tokens, x), dim=1)

x = self.temporal_transformer(x)

x = x.mean(dim=1) if self.pool == "mean" else x[:, 0]

x = self.decoder(x)

return x
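
For reference, a minimal usage sketch mirroring the hyperparameters of the test added in this PR: the model expects clips of shape (batch, num_frames, in_channels, img_size, img_size) and returns logits of shape (batch, n_classes).

import torch

from vformer.models.classification import ViViT

model = ViViT(
    img_size=224,
    in_channels=3,
    patch_size=16,
    embedding_dim=192,
    depth=4,
    num_heads=3,
    head_dim=64,
    num_frames=16,
    n_classes=10,
)
video = torch.randn(1, 16, 3, 224, 224)  # (batch, frames, channels, height, width)
logits = model(video)
print(logits.shape)  # expected: torch.Size([1, 10])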