Skip to content

Commit

Permalink
Standardize naming of variables for number of classes (#67)
Browse files (browse the repository at this point in the history)
  • Loading branch information
NeelayS authored Feb 8, 2022
1 parent 041f89b commit 12d21a9
Show file tree
Hide file tree
Showing 11 changed files with 38 additions and 40 deletions.
14 changes: 7 additions & 7 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_pvt():
depths=[2, 2, 2, 2],
sr_ratios=[8, 4, 2, 1],
decoder_config=[512, 10],
num_classes=10,
n_classes=10,
)
out = model(img_3channels_224)
assert out.shape == (4, 10)
Expand All @@ -187,7 +187,7 @@ def test_pvt():
depths=[2, 2, 2, 2],
sr_ratios=[8, 4, 2, 1],
decoder_config=512,
num_classes=10,
n_classes=10,
)
out = model(img_3channels_224)
assert out.shape == (4, 10)
Expand All @@ -198,17 +198,17 @@ def test_pvt():
assert out.shape == (4, 1000)
del model

model = MODEL_REGISTRY.get("PVTClassificationV2")(num_classes=10)
model = MODEL_REGISTRY.get("PVTClassificationV2")(n_classes=10)
out = model(img_3channels_224)
assert out.shape == (4, 10)
del model

model = MODEL_REGISTRY.get("PVTClassificationV2")(num_classes=10)
model = MODEL_REGISTRY.get("PVTClassificationV2")(n_classes=10)
out = model(img_3channels_224)
assert out.shape == (4, 10)
del model

model = MODEL_REGISTRY.get("PVTClassification")(num_classes=12)
model = MODEL_REGISTRY.get("PVTClassification")(n_classes=12)
out = model(img_3channels_224)
assert out.shape == (4, 12)
del model
Expand Down Expand Up @@ -305,7 +305,7 @@ def test_cvt():
embedding_dim=768,
num_heads=1,
mlp_ratio=4.0,
num_classes=10,
n_classes=10,
p_dropout=0.5,
attn_dropout=0.3,
drop_path=0.2,
Expand Down Expand Up @@ -356,7 +356,7 @@ def test_cct():
embedding_dim=768,
num_heads=1,
mlp_ratio=4.0,
num_classes=10,
n_classes=10,
p_dropout=0.5,
attn_dropout=0.3,
drop_path=0.2,
Expand Down
10 changes: 5 additions & 5 deletions vformer/models/classification/cct.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class CCT(BaseClassificationModel):
Number of heads in each transformer layer
mlp_ratio:float
Ratio of mlp heads to embedding dimension
num_classes: int
n_classes: int
Number of classes for classification
p_dropout: float
Dropout probability
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(
head_dim=96,
num_heads=1,
mlp_ratio=4.0,
num_classes=1000,
n_classes=1000,
p_dropout=0.1,
attn_dropout=0.1,
drop_path=0.1,
Expand Down Expand Up @@ -163,10 +163,10 @@ def __init__(
assert (
decoder_config[0] == embedding_dim
), f"Configurations do not match for MLPDecoder, First element of `decoder_config` expected to be {embedding_dim}, got {decoder_config[0]} "
self.decoder = MLPDecoder(config=decoder_config, n_classes=num_classes)
self.decoder = MLPDecoder(config=decoder_config, n_classes=n_classes)

else:
self.decoder = MLPDecoder(config=embedding_dim, n_classes=num_classes)
self.decoder = MLPDecoder(config=embedding_dim, n_classes=n_classes)

def forward(self, x):
"""
Expand All @@ -178,7 +178,7 @@ def forward(self, x):
Returns
----------
torch.Tensor
Returns tensor of size `num_classes`
Returns tensor of size `n_classes`
"""
x = self.embedding(x)
Expand Down
2 changes: 1 addition & 1 deletion vformer/models/classification/convit.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def forward(self, x):
Returns
----------
torch.Tensor
Returns tensor of size `num_classes`
Returns tensor of size `n_classes`
"""
x = self.patch_embedding(x)
Expand Down
10 changes: 4 additions & 6 deletions vformer/models/classification/convvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ConvVT(nn.Module):
Number of input channels in image, default is 3
num_stages: int
Number of stages in encoder block, default is 3
num_classes: int
n_classes: int
Number of classes for classification, default is 1000
* The following are all in list of int/float with length num_stages
patch_size: list[int]
Expand Down Expand Up @@ -75,11 +75,11 @@ def __init__(
stride_q=[1, 1, 1],
in_channels=3,
num_stages=3,
num_classes=1000,
n_classes=1000,
):
super().__init__()

self.num_classes = num_classes
self.n_classes = n_classes

self.num_stages = num_stages
self.stages = []
Expand Down Expand Up @@ -111,9 +111,7 @@ def __init__(

# Classifier head
self.head = (
nn.Linear(embedding_dim[-1], num_classes)
if num_classes > 0
else nn.Identity()
nn.Linear(embedding_dim[-1], n_classes) if n_classes > 0 else nn.Identity()
)
trunc_normal_(self.head.weight, std=0.02)

Expand Down
2 changes: 1 addition & 1 deletion vformer/models/classification/cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def forward(self, img):
Returns
----------
torch.Tensor
Returns tensor of size `num_classes`
Returns tensor of size `n_classes`
"""
emb_s = self.s(img)
Expand Down
10 changes: 5 additions & 5 deletions vformer/models/classification/cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class CVT(BaseClassificationModel):
Number of heads in each transformer layer, default is 1
mlp_ratio:float
Ratio of mlp heads to embedding dimension, default is 4.0
num_classes: int
n_classes: int
Number of classes for classification, default is 1000
p_dropout: float
Dropout probability, default is 0.0
Expand All @@ -57,7 +57,7 @@ def __init__(
num_layers=1,
num_heads=1,
mlp_ratio=4.0,
num_classes=1000,
n_classes=1000,
p_dropout=0.1,
attn_dropout=0.1,
drop_path=0.1,
Expand Down Expand Up @@ -149,9 +149,9 @@ def __init__(
assert (
decoder_config[0] == embedding_dim
), f"Configurations do not match for MLPDecoder, First element of `decoder_config` expected to be {embedding_dim}, got {decoder_config[0]} "
self.decoder = MLPDecoder(config=decoder_config, n_classes=num_classes)
self.decoder = MLPDecoder(config=decoder_config, n_classes=n_classes)
else:
self.decoder = MLPDecoder(config=embedding_dim, n_classes=num_classes)
self.decoder = MLPDecoder(config=embedding_dim, n_classes=n_classes)

def forward(self, x):
"""
Expand All @@ -163,7 +163,7 @@ def forward(self, x):
Returns
----------
torch.Tensor
Returns tensor of size `num_classes`
Returns tensor of size `n_classes`
"""

Expand Down
16 changes: 8 additions & 8 deletions vformer/models/classification/pyramid.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class PVTClassification(nn.Module):
List of patch size
in_channels: int
Input channels in image, default=3
num_classes: int
n_classes: int
Number of classes for classification
embed_dims: int
Patch Embedding dimension
Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(
img_size=224,
patch_size=[7, 3, 3, 3],
in_channels=3,
num_classes=1000,
n_classes=1000,
embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8],
mlp_ratio=[4, 4, 4, 4],
Expand Down Expand Up @@ -159,9 +159,9 @@ def __init__(
assert (
decoder_config[0] == embed_dims[-1]
), f"Configurations do not match for MLPDecoder, First element of `decoder_config` expected to be {embed_dims[-1]}, got {decoder_config[0]} "
self.decoder = MLPDecoder(config=decoder_config, n_classes=num_classes)
self.decoder = MLPDecoder(config=decoder_config, n_classes=n_classes)
else:
self.decoder = MLPDecoder(config=embed_dims[-1], n_classes=num_classes)
self.decoder = MLPDecoder(config=embed_dims[-1], n_classes=n_classes)

def forward(self, x):
"""
Expand All @@ -173,7 +173,7 @@ def forward(self, x):
Returns
----------
torch.Tensor
Returns tensor of size `num_classes`
Returns tensor of size `n_classes`
"""
B = x.shape[0]
Expand Down Expand Up @@ -216,7 +216,7 @@ class PVTClassificationV2(PVTClassification):
List of patch size
in_channels: int
Input channels in image, default is 3
num_classes: int
n_classes: int
Number of classes for classification
embedding_dims: int
Patch Embedding dimension
Expand Down Expand Up @@ -255,7 +255,7 @@ def __init__(
img_size=224,
patch_size=[7, 3, 3, 3],
in_channels=3,
num_classes=1000,
n_classes=1000,
embedding_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8],
mlp_ratio=[4, 4, 4, 4],
Expand All @@ -276,7 +276,7 @@ def __init__(
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
num_classes=num_classes,
n_classes=n_classes,
embed_dims=embedding_dims,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
Expand Down
2 changes: 1 addition & 1 deletion vformer/models/classification/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def forward(self, x):
Returns
----------
torch.Tensor
Returns tensor of size `num_classes`
Returns tensor of size `n_classes`
"""
x = self.patch_embed(x)
Expand Down
2 changes: 1 addition & 1 deletion vformer/models/classification/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def forward(self, x):
Returns
----------
torch.Tensor
Returns tensor of size `num_classes`
Returns tensor of size `n_classes`
"""
x = self.patch_embedding(x)
Expand Down
4 changes: 2 additions & 2 deletions vformer/models/dense/PVT/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class PVTDetection(nn.Module):
List of patch size
in_channels: int
Input channels in image, default=3
num_classes: int
n_classes: int
Number of classes for classification
embedding_dims: int
Patch Embedding dimension
Expand Down Expand Up @@ -197,7 +197,7 @@ class PVTDetectionV2(PVTDetection):
List of patch size
in_channels: int
Input channels in image, default=3
num_classes: int
n_classes: int
Number of classes for classification
embedding_dims: int
Patch Embedding dimension
Expand Down
6 changes: 3 additions & 3 deletions vformer/models/dense/dpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def __init__(self, scale_factor, mode, align_corners=False):
self.align_corners = align_corners

def forward(self, x):
""" Forward pass """
"""Forward pass"""

x = self.interp(
x,
Expand Down Expand Up @@ -596,7 +596,7 @@ def __init__(self, features, activation=nn.GELU, bn=True):
self.skip_add = nn.quantized.FloatFunctional()

def forward(self, x):
""" forward pass"""
"""forward pass"""
out = self.activation(x)
out = self.conv1(out)
if self.bn == True:
Expand Down Expand Up @@ -651,7 +651,7 @@ def __init__(
self.skip_add = nn.quantized.FloatFunctional()

def forward(self, *xs):
"""Forward pass """
"""Forward pass"""
output = xs[0]

if len(xs) == 2:
Expand Down

0 comments on commit 12d21a9

Please sign in to comment.