Commit

extend mc-dropout to vgg and wideresnet
badrmarani committed Oct 13, 2023
1 parent df636a3 commit 73f81f3
Showing 8 changed files with 159 additions and 10 deletions.
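
For orientation, here is a minimal construction sketch of the API extended by this commit, mirroring the calls exercised in the new tests below. It assumes the package layout shown in the diff and is not part of the commit itself:

from torch_uncertainty.models.vgg.std import vgg13
from torch_uncertainty.models.wideresnet.std import wideresnet28x10

# VGG-13 with MC dropout: three stochastic members sampled at evaluation time.
vgg = vgg13(
    in_channels=1,
    num_classes=10,
    style="cifar",
    num_estimators=3,
    enable_last_layer_dropout=True,
)

# WideResNet-28x10 with the same MC-dropout options.
wrn = wideresnet28x10(
    in_channels=1,
    num_classes=10,
    dropout_rate=0.3,
    groups=1,
    style="imagenet",
    num_estimators=3,
    enable_last_layer_dropout=False,
)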
44 changes: 44 additions & 0 deletions tests/models/test_monte_carlo_dropout.py
@@ -0,0 +1,44 @@
from torch_uncertainty.models.resnet.std import resnet34
from torch_uncertainty.models.utils import enable_dropout
from torch_uncertainty.models.vgg.std import vgg11
from torch_uncertainty.models.wideresnet.std import wideresnet28x10


class TestMonteCarloDropout:
    """Testing Monte Carlo dropout on the std models."""

    def test_resnet(self):
        resnet34(1, 10, 0.5, num_estimators=10)

        model = resnet34(1, 10, 1)
        model.eval()

        enable_dropout(model)

        for m in model.modules():
            if m.__class__.__name__.startswith("Dropout"):
                assert m.training

    def test_vgg(self):
        vgg11(1, 10, dropout=0.5, num_estimators=10)

        model = vgg11(1, 10, dropout=0.5, num_estimators=10)
        model.eval()

        enable_dropout(model)

        for m in model.modules():
            if m.__class__.__name__.startswith("Dropout"):
                assert m.training

    def test_wideresnet(self):
        wideresnet28x10(1, 10, dropout_rate=0.5, num_estimators=10)

        model = wideresnet28x10(1, 10, dropout_rate=0.5, num_estimators=10)
        model.eval()

        enable_dropout(model)

        for m in model.modules():
            if m.__class__.__name__.startswith("Dropout"):
                assert m.training
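
The enable_dropout helper imported above is not part of this diff. A minimal sketch consistent with what these tests assert (dropout modules switched back to training mode on an otherwise eval-mode model) could look as follows; note that the library version also accepts a second argument for the last-layer dropout, as seen in the ResNet forward pass further down, whose exact behavior is an assumption here:

from torch import nn


def enable_dropout(model: nn.Module) -> None:
    # Sketch only: re-enable every dropout module so its forward pass stays
    # stochastic while the rest of the model remains in eval mode.
    for m in model.modules():
        if m.__class__.__name__.startswith("Dropout"):
            m.train()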
16 changes: 16 additions & 0 deletions tests/models/test_vggs.py
@@ -8,6 +8,22 @@ class TestStdVGG:
    def test_main(self):
        vgg13(1, 10, style="cifar")

    def test_mc_dropout(self):
        vgg13(
            in_channels=1,
            num_classes=10,
            style="cifar",
            num_estimators=3,
            enable_last_layer_dropout=True,
        )
        vgg13(
            in_channels=1,
            num_classes=10,
            style="cifar",
            num_estimators=3,
            enable_last_layer_dropout=False,
        )


class TestPackedVGG:
    """Testing the VGG packed class."""
23 changes: 23 additions & 0 deletions tests/models/test_wideresnets.py
@@ -4,6 +4,29 @@
from torch_uncertainty.models.wideresnet.masked import masked_wideresnet28x10
from torch_uncertainty.models.wideresnet.mimo import mimo_wideresnet28x10
from torch_uncertainty.models.wideresnet.packed import packed_wideresnet28x10
from torch_uncertainty.models.wideresnet.std import wideresnet28x10


class TestMonteCarloDropoutResnet:
    """Testing the WideResNet MC Dropout."""

    def test_main(self):
        wideresnet28x10(
            in_channels=1,
            num_classes=2,
            groups=1,
            style="imagenet",
            num_estimators=3,
            enable_last_layer_dropout=True,
        )
        wideresnet28x10(
            in_channels=1,
            num_classes=2,
            groups=1,
            style="imagenet",
            num_estimators=3,
            enable_last_layer_dropout=False,
        )


class TestPackedResnet:
29 changes: 20 additions & 9 deletions torch_uncertainty/models/resnet/std.py
@@ -1,5 +1,5 @@
# fmt: off
from typing import List, Optional, Type, Union
from typing import List, Type, Union

import torch.nn.functional as F
from torch import Tensor, nn
@@ -199,8 +199,8 @@ def __init__(
        dropout_rate: float,
        groups: int,
        style: str = "imagenet",
        num_estimators: Optional[int] = None,
        enable_last_layer_dropout: Optional[bool] = False,
        num_estimators: int = None,
        enable_last_layer_dropout: bool = False,
    ) -> None:
        super().__init__()

@@ -307,7 +307,8 @@ def _make_layer(
    def forward(self, x: Tensor) -> Tensor:
        if self.num_estimators is not None:
            if not self.training:
                enable_dropout(self, self.enable_last_layer_dropout)
                if self.enable_last_layer_dropout is not None:
                    enable_dropout(self, self.enable_last_layer_dropout)
            x = x.repeat(self.num_estimators, 1, 1, 1)

        out = F.relu(self.bn1(self.conv1(x)))
@@ -328,7 +329,8 @@ def resnet18(
    dropout_rate: float = 0,
    groups: int = 1,
    style: str = "imagenet",
    num_estimators: Optional[int] = None,
    num_estimators: int = None,
    enable_last_layer_dropout: bool = False,
) -> _ResNet:
    """ResNet-18 from `Deep Residual Learning for Image Recognition
    <https://arxiv.org/pdf/1512.03385.pdf>`_.
@@ -353,6 +355,7 @@ def resnet18(
        groups=groups,
        style=style,
        num_estimators=num_estimators,
        enable_last_layer_dropout=enable_last_layer_dropout,
    )


@@ -362,7 +365,8 @@ def resnet34(
    dropout_rate: float = 0,
    groups: int = 1,
    style: str = "imagenet",
    num_estimators: Optional[int] = None,
    num_estimators: int = None,
    enable_last_layer_dropout: bool = False,
) -> _ResNet:
    """ResNet-34 from `Deep Residual Learning for Image Recognition
    <https://arxiv.org/pdf/1512.03385.pdf>`_.
@@ -387,6 +391,7 @@ def resnet34(
        groups=groups,
        style=style,
        num_estimators=num_estimators,
        enable_last_layer_dropout=enable_last_layer_dropout,
    )


@@ -396,7 +401,8 @@ def resnet50(
    dropout_rate: float = 0,
    groups: int = 1,
    style: str = "imagenet",
    num_estimators: Optional[int] = None,
    num_estimators: int = None,
    enable_last_layer_dropout: bool = False,
) -> _ResNet:
    """ResNet-50 from `Deep Residual Learning for Image Recognition
    <https://arxiv.org/pdf/1512.03385.pdf>`_.
@@ -421,6 +427,7 @@ def resnet50(
        groups=groups,
        style=style,
        num_estimators=num_estimators,
        enable_last_layer_dropout=enable_last_layer_dropout,
    )


@@ -430,7 +437,8 @@ def resnet101(
    dropout_rate: float = 0,
    groups: int = 1,
    style: str = "imagenet",
    num_estimators: Optional[int] = None,
    num_estimators: int = None,
    enable_last_layer_dropout: bool = False,
) -> _ResNet:
    """ResNet-101 from `Deep Residual Learning for Image Recognition
    <https://arxiv.org/pdf/1512.03385.pdf>`_.
@@ -455,6 +463,7 @@ def resnet101(
        groups=groups,
        style=style,
        num_estimators=num_estimators,
        enable_last_layer_dropout=enable_last_layer_dropout,
    )


@@ -464,7 +473,8 @@ def resnet152(
    dropout_rate: float = 0,
    groups: int = 1,
    style: str = "imagenet",
    num_estimators: Optional[int] = None,
    num_estimators: int = None,
    enable_last_layer_dropout: bool = False,
) -> _ResNet:
    """ResNet-152 from `Deep Residual Learning for Image Recognition
    <https://arxiv.org/pdf/1512.03385.pdf>`_.
@@ -490,4 +500,5 @@ def resnet152(
        groups=groups,
        style=style,
        num_estimators=num_estimators,
        enable_last_layer_dropout=enable_last_layer_dropout,
    )
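
Since the forward pass above repeats the input num_estimators times along the batch dimension at evaluation time, the logits come back stacked estimator-first. A small inference sketch, not part of the commit; the reshape assumes the ordering produced by x.repeat, and the input size is arbitrary:

import torch

from torch_uncertainty.models.resnet.std import resnet34

num_estimators, batch_size, num_classes = 10, 4, 10
# Mirrors the construction used in test_monte_carlo_dropout.py above.
model = resnet34(1, num_classes, 0.5, num_estimators=num_estimators)
model.eval()  # dropout is re-enabled inside forward() via enable_dropout

x = torch.randn(batch_size, 1, 64, 64)
with torch.no_grad():
    logits = model(x)  # shape: (num_estimators * batch_size, num_classes)

# Group the Monte Carlo samples and average the class probabilities.
probs = logits.reshape(num_estimators, batch_size, num_classes).softmax(dim=-1)
mean_probs = probs.mean(dim=0)  # predictive mean per original input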
19 changes: 19 additions & 0 deletions torch_uncertainty/models/vgg/base.py
@@ -6,6 +6,7 @@
from torch import nn

from ...layers.packed import PackedConv2d, PackedLinear
from ..utils import enable_dropout


# fmt: on
@@ -30,6 +31,14 @@ def __init__(
        self.conv2d_layer = conv2d_layer
        self.norm = norm
        self.groups = groups

        if self.conv2d_layer == PackedConv2d:
            self.num_estimators = model_kwargs.get("num_estimators")
        else:
            self.num_estimators = model_kwargs.pop("num_estimators")
        self.enable_last_layer_dropout = model_kwargs.pop(
            "enable_last_layer_dropout", False
        )
        self.model_kwargs = model_kwargs

        self.features = self._make_layers(vgg_cfg)
@@ -62,7 +71,17 @@ def __init__(
        self._init_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if (
            self.num_estimators is not None
            and self.linear_layer != PackedLinear
        ):
            if not self.training:
                if self.enable_last_layer_dropout is not None:
                    enable_dropout(self, self.enable_last_layer_dropout)
            x = x.repeat(self.num_estimators, 1, 1, 1)

        x = self.features(x)

        if self.linear_layer == PackedLinear:
            x = rearrange(
                x,
16 changes: 16 additions & 0 deletions torch_uncertainty/models/vgg/std.py
@@ -17,6 +17,8 @@ def vgg11(
    groups: int = 1,
    dropout: float = 0.5,
    style: str = "imagenet",
    num_estimators: int = None,
    enable_last_layer_dropout: bool = False,
) -> VGG:
    return _vgg(
        cfgs["A"],
@@ -26,6 +28,8 @@ def vgg11(
        groups=groups,
        dropout=dropout,
        style=style,
        num_estimators=num_estimators,
        enable_last_layer_dropout=enable_last_layer_dropout,
    )


@@ -36,6 +40,8 @@ def vgg13(
    groups: int = 1,
    dropout: float = 0.5,
    style: str = "imagenet",
    num_estimators: int = None,
    enable_last_layer_dropout: bool = False,
) -> VGG:
    return _vgg(
        cfgs["B"],
@@ -45,6 +51,8 @@ def vgg13(
        groups=groups,
        dropout=dropout,
        style=style,
        num_estimators=num_estimators,
        enable_last_layer_dropout=enable_last_layer_dropout,
    )


@@ -55,6 +63,8 @@ def vgg16(
    groups: int = 1,
    dropout: float = 0.5,
    style: str = "imagenet",
    num_estimators: int = None,
    enable_last_layer_dropout: bool = False,
) -> VGG:
    return _vgg(
        cfgs["D"],
@@ -64,6 +74,8 @@ def vgg16(
        groups=groups,
        dropout=dropout,
        style=style,
        num_estimators=num_estimators,
        enable_last_layer_dropout=enable_last_layer_dropout,
    )


@@ -74,6 +86,8 @@ def vgg19(
    groups: int = 1,
    dropout: float = 0.5,
    style: str = "imagenet",
    num_estimators: int = None,
    enable_last_layer_dropout: bool = False,
) -> VGG:  # coverage: ignore
    return _vgg(
        cfgs["E"],
@@ -83,4 +97,6 @@ def vgg19(
        groups=groups,
        dropout=dropout,
        style=style,
        num_estimators=num_estimators,
        enable_last_layer_dropout=enable_last_layer_dropout,
    )
3 changes: 3 additions & 0 deletions torch_uncertainty/models/wideresnet/mimo.py
@@ -30,6 +30,7 @@ def __init__(
            dropout_rate=dropout_rate,
            groups=groups,
            style=style,
            num_estimators=num_estimators,
        )

        self.num_estimators = num_estimators
@@ -50,6 +51,7 @@ def mimo_wideresnet28x10(
    num_estimators: int,
    groups: int = 1,
    style: str = "imagenet",
    **model_kwargs,
) -> _MIMOWide:
    return _MIMOWide(
        depth=28,
@@ -60,4 +62,5 @@ def mimo_wideresnet28x10(
        dropout_rate=0.3,
        groups=groups,
        style=style,
        **model_kwargs,
    )