
CCT and CVT #39

Merged · 18 commits · Dec 3, 2021
Changed VanillaEncoder; CVT and CCT models now use VanillaEncoder.
abhi-glitchhg committed Dec 3, 2021
commit 62063a65e26bdaca0ff0c394cbe1f2dbd30e0e69
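
In short, this commit deletes the bespoke `CVTEncoderBlock` and has both CCT and CVT build their transformer blocks from the shared `VanillaEncoder`, which gains `attn_dropout` and `drop_path_rate` arguments. A minimal before/after sketch of the swap (the `mlp_dim` value here is illustrative, not taken from this commit):

```python
import torch
from vformer.encoder import VanillaEncoder

# Before: CVTEncoderBlock(dim=128, num_head=8, p_dropout=0.0, attn_dropout=0.0)
# After: a depth-1 VanillaEncoder plays the same role.
block = VanillaEncoder(
    latent_dim=128,      # was `dim` in CVTEncoderBlock
    depth=1,             # CCT/CVT stack single-block encoders themselves
    heads=8,             # was `num_head`
    dim_head=96,         # new knob, surfaced as `dim_head` in CCT/CVT
    mlp_dim=512,         # was `hidden_dim`; 512 is an illustrative value
    p_dropout=0.0,       # feed-forward dropout
    attn_dropout=0.0,    # attention dropout, now routed separately
    drop_path_rate=0.0,  # stochastic depth, disabled here
)
x = torch.randn(4, 3136, 128)  # (batch, tokens, dim), as in the deleted test
assert block(x).shape == x.shape  # the encoder preserves the token shape
```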
9 changes: 0 additions & 9 deletions tests/test_encoder.py
@@ -3,7 +3,6 @@

 from vformer.encoder import (
     CrossEncoder,
-    CVTEncoderBlock,
     PVTEncoder,
     SwinEncoder,
     SwinEncoderBlock,
@@ -88,11 +87,3 @@ def test_CrossEncoder():
     assert out[0].shape == test_tensor1.shape
     assert out[1].shape == test_tensor2.shape  # shape remains same
     del encoder
-
-
-def test_CVTEncoder():
-    test_tensor1 = torch.randn(4, 3136, 128)
-
-    encoder = CVTEncoderBlock(dim=128, num_head=8, p_dropout=0.0, attn_dropout=0.0)
-    out = encoder(test_tensor1)
-    assert out.shape == test_tensor1.shape
1 change: 0 additions & 1 deletion vformer/encoder/__init__.py
@@ -1,5 +1,4 @@
 from .cross import CrossEncoder
-from .cvt import CVTEncoderBlock
 from .embedding import *
 from .nn import FeedForward
 from .pyramid import PVTEncoder
75 changes: 0 additions & 75 deletions vformer/encoder/cvt.py

This file was deleted.

26 changes: 23 additions & 3 deletions vformer/encoder/vanilla.py
@@ -1,4 +1,5 @@
 import torch.nn as nn
+from timm.models.layers import DropPath
 
 from ..attention import VanillaSelfAttention
 from ..functional import PreNorm
@@ -21,9 +22,23 @@ class VanillaEncoder(nn.Module):
         Dimension of the hidden layer in the feed-forward layer
     p_dropout: float
         Dropout Probability
+    attn_dropout: float
+        Dropout Probability
+    drop_path_rate: float
+        Stochastic drop path rate
     """
 
-    def __init__(self, latent_dim, depth, heads, dim_head, mlp_dim, p_dropout=0.0):
+    def __init__(
+        self,
+        latent_dim,
+        depth,
+        heads,
+        dim_head,
+        mlp_dim,
+        p_dropout=0.0,
+        attn_dropout=0.0,
+        drop_path_rate=0.0,
+    ):
         super().__init__()
         self.encoder = nn.ModuleList([])
         for _ in range(depth):
@@ -36,7 +51,7 @@ def __init__(self, latent_dim, depth, heads, dim_head, mlp_dim, p_dropout=0.0):
                                 latent_dim,
                                 heads=heads,
                                 dim_head=dim_head,
-                                p_dropout=p_dropout,
+                                p_dropout=attn_dropout,
                             ),
                         ),
                         PreNorm(
@@ -46,10 +61,15 @@ def __init__(self, latent_dim, depth, heads, dim_head, mlp_dim, p_dropout=0.0):
                     ]
                 )
             )
+        self.drop_path = (
+            DropPath(drop_prob=drop_path_rate)
+            if drop_path_rate > 0.0
+            else nn.Identity()
+        )
 
     def forward(self, x):
         for attn, ff in self.encoder:
             x = attn(x) + x
-            x = ff(x) + x
+            x = self.drop_path(ff(x)) + x
 
         return x
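
Two things worth noting in this hunk: attention dropout is now controlled by `attn_dropout` rather than `p_dropout`, and the new forward pass applies drop path only around the feed-forward residual branch. `DropPath` is timm's stochastic-depth layer; as a rough sketch of what it does (not timm's actual implementation):

```python
import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Stochastic depth: randomly zero a residual branch per sample."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x  # no-op at inference or when disabled
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        return x * mask / keep_prob  # rescale to keep the expected value
```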
13 changes: 8 additions & 5 deletions vformer/models/classification/cct.py
@@ -4,7 +4,7 @@

 from ...common import BaseClassificationModel
 from ...decoder import MLPDecoder
-from ...encoder import CVTEmbedding, CVTEncoderBlock
+from ...encoder import CVTEmbedding, VanillaEncoder
 from ...utils import pair


@@ -58,6 +58,7 @@ def __init__(
         seq_pool=True,
         embedding_dim=768,
         num_layers=1,
+        dim_head=96,
         num_heads=1,
         mlp_ratio=4.0,
         num_classes=1000,
@@ -140,10 +141,12 @@ def __init__(
         dpr = [x.item() for x in torch.linspace(0, drop_path, num_layers)]
         self.encoder_blocks = nn.ModuleList(
             [
-                CVTEncoderBlock(
-                    dim=embedding_dim,
-                    num_head=num_heads,
-                    hidden_dim=hidden_dim,
+                VanillaEncoder(
+                    latent_dim=embedding_dim,
+                    heads=num_heads,
+                    depth=1,
+                    dim_head=dim_head,
+                    mlp_dim=hidden_dim,
                     p_dropout=p_dropout,
                     attn_dropout=attn_dropout,
                     drop_path_rate=dpr[i],
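
The `dpr` list above gives each block its own drop-path rate on a linear schedule from 0 to `drop_path`, so deeper blocks are dropped more aggressively. A quick worked example with illustrative values:

```python
import torch

drop_path, num_layers = 0.1, 4  # illustrative values
dpr = [x.item() for x in torch.linspace(0, drop_path, num_layers)]
print(dpr)  # [0.0, 0.0333..., 0.0667..., 0.1]: rate grows with depth
```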
26 changes: 20 additions & 6 deletions vformer/models/classification/cvt.py
@@ -4,7 +4,7 @@

 from ...common import BaseClassificationModel
 from ...decoder import MLPDecoder
-from ...encoder import CVTEmbedding, CVTEncoderBlock
+from ...encoder import CVTEmbedding, VanillaEncoder
 from ...utils import pair


@@ -51,6 +51,7 @@ def __init__(
         in_chans=3,
         seq_pool=True,
         embedding_dim=768,
+        dim_head=96,
         num_layers=1,
         num_heads=1,
         mlp_ratio=4.0,
@@ -128,18 +129,31 @@ def __init__(
         dpr = [x.item() for x in torch.linspace(0, drop_path, num_layers)]
         self.encoder_blocks = nn.ModuleList(
             [
-                CVTEncoderBlock(
-                    dim=embedding_dim,
-                    num_head=num_heads,
-                    hidden_dim=hidden_dim,
+                VanillaEncoder(
+                    latent_dim=embedding_dim,
+                    heads=num_heads,
+                    depth=1,
+                    mlp_dim=hidden_dim,
+                    dim_head=dim_head,
                     p_dropout=p_dropout,
                     attn_dropout=attn_dropout,
                     drop_path_rate=dpr[i],
                 )
                 for i in range(num_layers)
             ]
         )
-        self.decoder = MLPDecoder(config=decoder_config, n_classes=num_classes)
+        if decoder_config is not None:
+
+            if not isinstance(decoder_config, list) and not isinstance(
+                decoder_config, tuple
+            ):
+                decoder_config = [decoder_config]
+            assert (
+                decoder_config[0] == embedding_dim
+            ), f"Configurations do not match for MLPDecoder, First element of `decoder_config` expected to be {embedding_dim}, got {decoder_config[0]} "
+            self.decoder = MLPDecoder(config=decoder_config, n_classes=num_classes)
+        else:
+            self.decoder = MLPDecoder(config=embedding_dim, n_classes=num_classes)
 
     def forward(self, x):
         x = self.embedding(x)
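
The new `decoder_config` branch accepts a bare int as well as a list or tuple, and guards the first decoder width against `embedding_dim`. A standalone sketch of that normalization logic (the helper name is hypothetical, extracted here only for illustration):

```python
def normalize_decoder_config(decoder_config, embedding_dim):
    """Mirrors the new CVT logic: coerce to a list and validate the first width."""
    if decoder_config is None:
        return [embedding_dim]  # fall back to a single-width config
    if not isinstance(decoder_config, (list, tuple)):
        decoder_config = [decoder_config]
    assert decoder_config[0] == embedding_dim, (
        f"First element of `decoder_config` expected to be {embedding_dim}, "
        f"got {decoder_config[0]}"
    )
    return list(decoder_config)

print(normalize_decoder_config(768, 768))          # [768]
print(normalize_decoder_config((768, 1024), 768))  # [768, 1024]
```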