Add Video Vision Transformer implementation #62

Merged
merged 17 commits into from
Feb 1, 2022
added tests
abhi-glitchhg committed Jan 31, 2022
commit 8e7ac6e94c0dc0dba457d257ce6c4def1ac6c1df
15 changes: 15 additions & 0 deletions tests/test_encoder.py
@@ -92,3 +92,18 @@ def test_CrossEncoder():
    assert out[0].shape == test_tensor1.shape
    assert out[1].shape == test_tensor2.shape
    del encoder

def test_TubeletEmbedding():
    from vformer.encoder.embedding.video_patch_embeddings import TubeletEmbedding
Member

Can you please move this import to the top of the file since we've been doing that for the rest of the test functions?


    test_tensor = torch.randn(7,20,3,224,224)  # batch_size, time, in_channels, height, width
    embedding = TubeletEmbedding(embedding_dim=192,tubelet_w=16,tubelet_t=5,tubelet_h=16,in_channels=3)
    out = embedding(test_tensor)
    assert out.shape == (7,4,196,192)  # batch, time/tubelet_t, (height/tubelet_h)*(width/tubelet_w), embedding_dim
    del embedding

    test_tensor = torch.randn(11,15,1,28,28)
    embedding=TubeletEmbedding(96,5,7,7,1)
    out = embedding(test_tensor)
    assert out.shape == (11,3,16,96)
    del embedding
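
The expected shapes in both assertions follow from dividing each axis by its tubelet size. Below is a minimal sketch of that arithmetic. The helper `expected_tubelet_shape` is hypothetical and not part of vformer, and it assumes non-overlapping tubelets whose sizes divide the input evenly, which holds for both test inputs here.

```python
def expected_tubelet_shape(input_shape, embedding_dim, tubelet_t, tubelet_h, tubelet_w):
    """Hypothetical helper (not part of vformer): predict the output shape of a
    tubelet embedding, assuming non-overlapping tubelets that tile the input exactly."""
    batch, time, _in_channels, height, width = input_shape
    if time % tubelet_t or height % tubelet_h or width % tubelet_w:
        raise ValueError("tubelet sizes must divide time, height and width evenly")
    n_time = time // tubelet_t                                  # temporal tokens
    n_space = (height // tubelet_h) * (width // tubelet_w)      # spatial tokens per time step
    return (batch, n_time, n_space, embedding_dim)


# The two cases exercised by the test in this commit.
print(expected_tubelet_shape((7, 20, 3, 224, 224), 192, 5, 16, 16))  # (7, 4, 196, 192)
print(expected_tubelet_shape((11, 15, 1, 28, 28), 96, 5, 7, 7))      # (11, 3, 16, 96)
```

For the first case, 20 / 5 = 4 temporal tokens and (224 / 16) * (224 / 16) = 196 spatial tokens; for the second, 15 / 5 = 3 and (28 / 7) * (28 / 7) = 16.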