Skip to content

Commit

Permalink
👕 Standardize WideResNets
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Oct 10, 2023
1 parent e673fca commit f70b200
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 38 deletions.
18 changes: 9 additions & 9 deletions torch_uncertainty/models/wideresnet/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(
groups: int = 1,
):
super().__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = BatchConv2d(
in_planes,
planes,
Expand All @@ -34,7 +33,7 @@ def __init__(
bias=False,
)
self.dropout = nn.Dropout(p=dropout_rate)
self.bn2 = nn.BatchNorm2d(planes)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = BatchConv2d(
planes,
planes,
Expand All @@ -59,10 +58,13 @@ def __init__(
),
)

self.bn2 = nn.BatchNorm2d(planes)

def forward(self, x):
out = self.dropout(self.conv1(F.relu(self.bn1(x))))
out = self.conv2(F.relu(self.bn2(out)))
out = F.relu(self.bn1(self.dropout(self.conv1(x))))
out = self.conv2(out)
out += self.shortcut(x)
out = F.relu(self.bn2(out))
return out


Expand Down Expand Up @@ -111,6 +113,8 @@ def __init__(
bias=True,
)

self.bn1 = nn.BatchNorm2d(nStages[0])

if style == "imagenet":
self.optional_pool = nn.MaxPool2d(
kernel_size=3, stride=2, padding=1
Expand Down Expand Up @@ -145,7 +149,6 @@ def __init__(
num_estimators=self.num_estimators,
groups=groups,
)
self.bn1 = nn.BatchNorm2d(nStages[3])

self.pool = nn.AdaptiveAvgPool2d(output_size=1)
self.flatten = nn.Flatten(1)
Expand Down Expand Up @@ -186,17 +189,14 @@ def _wide_layer(

def forward(self, x):
out = x.repeat(self.num_estimators, 1, 1, 1)
out = self.conv1(out)
out = F.relu(self.bn1(self.conv1(out)))
out = self.optional_pool(out)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = F.relu(self.bn1(out))

out = self.pool(out)
out = self.flatten(out)
out = self.linear(out)

return out


Expand Down
23 changes: 12 additions & 11 deletions torch_uncertainty/models/wideresnet/masked.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# fmt: off
from typing import Type

import torch.nn.functional as F
from torch import nn

Expand All @@ -22,7 +24,6 @@ def __init__(
groups: int = 1,
):
super().__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = MaskedConv2d(
in_planes,
planes,
Expand All @@ -34,7 +35,7 @@ def __init__(
groups=groups,
)
self.dropout = nn.Dropout(p=dropout_rate)
self.bn2 = nn.BatchNorm2d(planes)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = MaskedConv2d(
planes,
planes,
Expand All @@ -60,11 +61,13 @@ def __init__(
groups=groups,
),
)
self.bn2 = nn.BatchNorm2d(planes)

def forward(self, x):
out = self.dropout(self.conv1(F.relu(self.bn1(x))))
out = self.conv2(F.relu(self.bn2(out)))
out = F.relu(self.bn1(self.dropout(self.conv1(x))))
out = self.conv2(out)
out += self.shortcut(x)
out = F.relu(self.bn2(out))
return out


Expand All @@ -86,7 +89,7 @@ def __init__(
self.in_planes = 16

assert (depth - 4) % 6 == 0, "Wide-resnet depth should be 6n+4."
n = (depth - 4) / 6
n = (depth - 4) // 6
k = widen_factor

nStages = [16, 16 * k, 32 * k, 64 * k]
Expand All @@ -112,6 +115,8 @@ def __init__(
groups=1,
)

self.bn1 = nn.BatchNorm2d(nStages[0])

if style == "imagenet":
self.optional_pool = nn.MaxPool2d(
kernel_size=3, stride=2, padding=1
Expand Down Expand Up @@ -149,7 +154,6 @@ def __init__(
scale=scale,
groups=groups,
)
self.bn1 = nn.BatchNorm2d(nStages[3])

self.pool = nn.AdaptiveAvgPool2d(output_size=1)
self.flatten = nn.Flatten(1)
Expand All @@ -160,7 +164,7 @@ def __init__(

def _wide_layer(
self,
block: nn.Module,
block: Type[nn.Module],
planes: int,
num_blocks: int,
dropout_rate: float,
Expand Down Expand Up @@ -190,17 +194,14 @@ def _wide_layer(

def forward(self, x):
out = x.repeat(self.num_estimators, 1, 1, 1)
out = self.conv1(out)
out = F.relu(self.bn1(self.conv1(out)))
out = self.optional_pool(out)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = F.relu(self.bn1(out))

out = self.pool(out)
out = self.flatten(out)
out = self.linear(out)

return out


Expand Down
20 changes: 10 additions & 10 deletions torch_uncertainty/models/wideresnet/packed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ def __init__(
planes: int,
dropout_rate: float,
stride: int = 1,
alpha: float = 2,
alpha: int = 2,
num_estimators: int = 4,
gamma: int = 1,
groups: int = 1,
):
super().__init__()
self.bn1 = nn.BatchNorm2d(alpha * in_planes)
self.conv1 = PackedConv2d(
in_planes,
planes,
Expand All @@ -39,7 +38,7 @@ def __init__(
bias=False,
)
self.dropout = nn.Dropout(p=dropout_rate)
self.bn2 = nn.BatchNorm2d(alpha * planes)
self.bn1 = nn.BatchNorm2d(alpha * planes)
self.conv2 = PackedConv2d(
planes,
planes,
Expand Down Expand Up @@ -67,11 +66,13 @@ def __init__(
bias=True,
),
)
self.bn2 = nn.BatchNorm2d(alpha * planes)

def forward(self, x):
out = self.dropout(self.conv1(F.relu(self.bn1(x))))
out = self.conv2(F.relu(self.bn2(out)))
out = F.relu(self.bn1(self.dropout(self.conv1(x))))
out = self.conv2(out)
out += self.shortcut(x)
out = F.relu(self.bn2(out))
return out


Expand Down Expand Up @@ -128,6 +129,8 @@ def __init__(
first=True,
)

self.bn1 = nn.BatchNorm2d(nStages[0] * alpha)

if style == "imagenet":
self.optional_pool = nn.MaxPool2d(
kernel_size=3, stride=2, padding=1
Expand Down Expand Up @@ -168,7 +171,6 @@ def __init__(
gamma=gamma,
groups=groups,
)
self.bn1 = nn.BatchNorm2d(nStages[3] * alpha, momentum=0.9)

self.pool = nn.AdaptiveAvgPool2d(output_size=1)
self.flatten = nn.Flatten(1)
Expand All @@ -188,7 +190,7 @@ def _wide_layer(
num_blocks: int,
dropout_rate: float,
stride: int,
alpha: float,
alpha: int,
num_estimators: int,
gamma: int,
groups: int,
Expand All @@ -214,19 +216,17 @@ def _wide_layer(
return nn.Sequential(*layers)

def forward(self, x):
out = self.conv1(x)
out = F.relu(self.bn1(self.conv1(x)))
out = self.optional_pool(out)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = F.relu(self.bn1(out))
out = rearrange(
out, "e (m c) h w -> (m e) c h w", m=self.num_estimators
)
out = self.pool(out)
out = self.flatten(out)
out = self.linear(out)

return out


Expand Down
16 changes: 8 additions & 8 deletions torch_uncertainty/models/wideresnet/std.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def __init__(
groups=1,
):
super().__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = nn.Conv2d(
in_planes,
planes,
Expand All @@ -30,7 +29,7 @@ def __init__(
bias=False,
)
self.dropout = nn.Dropout(p=dropout_rate)
self.bn2 = nn.BatchNorm2d(planes)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes,
planes,
Expand All @@ -52,11 +51,13 @@ def __init__(
bias=True,
),
)
self.bn2 = nn.BatchNorm2d(planes)

def forward(self, x):
out = self.dropout(self.conv1(F.relu(self.bn1(x))))
out = self.conv2(F.relu(self.bn2(out)))
out = F.relu(self.bn1(self.dropout(self.conv1(x))))
out = self.conv2(out)
out += self.shortcut(x)
out = F.relu(self.bn2(out))
return out


Expand Down Expand Up @@ -101,6 +102,8 @@ def __init__(
bias=True,
)

self.bn1 = nn.BatchNorm2d(nStages[0])

if style == "imagenet":
self.optional_pool = nn.MaxPool2d(
kernel_size=3, stride=2, padding=1
Expand Down Expand Up @@ -132,7 +135,6 @@ def __init__(
stride=2,
groups=groups,
)
self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)

self.pool = nn.AdaptiveAvgPool2d(output_size=1)
self.flatten = nn.Flatten(1)
Expand Down Expand Up @@ -169,16 +171,14 @@ def _wide_layer(
return nn.Sequential(*layers)

def forward(self, x):
out = self.conv1(x)
out = F.relu(self.bn1(self.conv1(x)))
out = self.optional_pool(out)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = F.relu(self.bn1(out))
out = self.pool(out)
out = self.flatten(out)
out = self.linear(out)

return out


Expand Down

0 comments on commit f70b200

Please sign in to comment.