Skip to content
31 changes: 15 additions & 16 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,20 @@ def get_available_video_models():
return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


# model_name, expected to script without error
torchub_models = {
"deeplabv3_resnet101": True,
"mobilenet_v2": True,
"resnext50_32x4d": True,
"fcn_resnet101": True,
"googlenet": False,
"densenet121": True,
"resnet18": True,
"alexnet": True,
"shufflenet_v2_x1_0": True,
"squeezenet1_0": True,
"vgg11": True,
"inception_v3": False,
}
torchub_models = [
"deeplabv3_resnet101",
"mobilenet_v2",
"resnext50_32x4d",
"fcn_resnet101",
"googlenet",
"densenet121",
"resnet18",
"alexnet",
"shufflenet_v2_x1_0",
"squeezenet1_0",
"vgg11",
"inception_v3",
]


class Tester(unittest.TestCase):
Expand All @@ -55,7 +54,7 @@ def check_script(self, model, name):
tb = traceback.format_exc()
scriptable = False
msg = str(e) + str(tb)
self.assertEqual(torchub_models[name], scriptable, msg)
self.assertTrue(scriptable, msg)

def _test_classification_model(self, name, input_shape):
# passing num_class equal to a number other than 1000 helps in making the test
Expand Down
39 changes: 34 additions & 5 deletions torchvision/models/googlenet.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
from __future__ import division

import warnings
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit.annotations import Optional
from torch import Tensor
from .utils import load_state_dict_from_url

__all__ = ['GoogLeNet', 'googlenet']
__all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"]

model_urls = {
# GoogLeNet ported from TensorFlow
'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
}

_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor],
'aux_logits1': Optional[Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _GoogLeNetOutputs set here for backwards compat
_GoogLeNetOutputs = GoogLeNetOutputs


def googlenet(pretrained=False, progress=True, **kwargs):
Expand Down Expand Up @@ -51,6 +61,7 @@ def googlenet(pretrained=False, progress=True, **kwargs):


class GoogLeNet(nn.Module):
__constants__ = ['aux_logits', 'transform_input']

def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
super(GoogLeNet, self).__init__()
Expand Down Expand Up @@ -102,6 +113,7 @@ def _initialize_weights(self):
nn.init.constant_(m.bias, 0)

def forward(self, x):
# type: (Tensor) -> GoogLeNetOutputs
if self.transform_input:
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
Expand All @@ -128,17 +140,22 @@ def forward(self, x):
# N x 480 x 14 x 14
x = self.inception4a(x)
# N x 512 x 14 x 14
if self.training and self.aux_logits:
aux_defined = self.training and self.aux_logits
if aux_defined:
aux1 = self.aux1(x)
else:
aux1 = None

x = self.inception4b(x)
# N x 512 x 14 x 14
x = self.inception4c(x)
# N x 512 x 14 x 14
x = self.inception4d(x)
# N x 528 x 14 x 14
if self.training and self.aux_logits:
if aux_defined:
aux2 = self.aux2(x)
else:
aux2 = None

x = self.inception4e(x)
# N x 832 x 14 x 14
Expand All @@ -156,12 +173,24 @@ def forward(self, x):
x = self.dropout(x)
x = self.fc(x)
# N x 1000 (num_classes)
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
return GoogLeNetOutputs(x, aux2, aux1)
else:
return self.eager_outputs(x, aux2, aux1)

@torch.jit.unused
def eager_outputs(self, x, aux2, aux1):
# type: (Tensor, Optional[Tensor], Optional[Tensor]) -> GoogLeNetOutputs
if self.training and self.aux_logits:
return _GoogLeNetOutputs(x, aux2, aux1)
return x
else:
return x


class Inception(nn.Module):
__constants__ = ['branch2', 'branch3', 'branch4']

def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
super(Inception, self).__init__()
Expand Down
30 changes: 26 additions & 4 deletions torchvision/models/inception.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
from __future__ import division

from collections import namedtuple
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit.annotations import Optional
from .utils import load_state_dict_from_url


__all__ = ['Inception3', 'inception_v3']
__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']


model_urls = {
# Inception v3 ported from TensorFlow
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}

_InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs.__annotations__ = {'logits': torch.Tensor, 'aux_logits': Optional[torch.Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
_InceptionOutputs = InceptionOutputs


def inception_v3(pretrained=False, progress=True, **kwargs):
Expand Down Expand Up @@ -128,8 +137,11 @@ def forward(self, x):
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
if self.training and self.aux_logits:
aux_defined = self.training and self.aux_logits
if aux_defined:
aux = self.AuxLogits(x)
else:
aux = None
# N x 768 x 17 x 17
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
Expand All @@ -146,8 +158,18 @@ def forward(self, x):
# N x 2048
x = self.fc(x)
# N x 1000 (num_classes)
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted InceptionNet always returns InceptionOutputs Tuple")
return InceptionOutputs(x, aux)
else:
return self.eager_outputs(x, aux)

@torch.jit.unused
def eager_outputs(self, x, aux):
# type: (torch.Tensor, Optional[torch.Tensor]) -> InceptionOutputs
if self.training and self.aux_logits:
return _InceptionOutputs(x, aux)
return InceptionOutputs(x, aux)
return x


Expand Down