Skip to content

Make Googlnet & InceptionNet scriptable #1349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 27, 2019
31 changes: 15 additions & 16 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,20 @@ def get_available_video_models():
return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


# model_name, expected to script without error
torchub_models = {
"deeplabv3_resnet101": True,
"mobilenet_v2": True,
"resnext50_32x4d": True,
"fcn_resnet101": True,
"googlenet": False,
"densenet121": True,
"resnet18": True,
"alexnet": True,
"shufflenet_v2_x1_0": True,
"squeezenet1_0": True,
"vgg11": True,
"inception_v3": False,
}
torchub_models = [
"deeplabv3_resnet101",
"mobilenet_v2",
"resnext50_32x4d",
"fcn_resnet101",
"googlenet",
"densenet121",
"resnet18",
"alexnet",
"shufflenet_v2_x1_0",
"squeezenet1_0",
"vgg11",
"inception_v3",
]


class Tester(unittest.TestCase):
Expand All @@ -55,7 +54,7 @@ def check_script(self, model, name):
tb = traceback.format_exc()
scriptable = False
msg = str(e) + str(tb)
self.assertEqual(torchub_models[name], scriptable, msg)
self.assertTrue(scriptable, msg)

def _test_classification_model(self, name, input_shape):
# passing num_class equal to a number other than 1000 helps in making the test
Expand Down
39 changes: 34 additions & 5 deletions torchvision/models/googlenet.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
from __future__ import division

import warnings
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit.annotations import Optional
from torch import Tensor
from .utils import load_state_dict_from_url

__all__ = ['GoogLeNet', 'googlenet']
__all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"]

model_urls = {
# GoogLeNet ported from TensorFlow
'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
}

_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor],
'aux_logits1': Optional[Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _GoogLeNetOutputs set here for backwards compat
_GoogLeNetOutputs = GoogLeNetOutputs


def googlenet(pretrained=False, progress=True, **kwargs):
Expand Down Expand Up @@ -51,6 +61,7 @@ def googlenet(pretrained=False, progress=True, **kwargs):


class GoogLeNet(nn.Module):
__constants__ = ['aux_logits', 'transform_input']

def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
super(GoogLeNet, self).__init__()
Expand Down Expand Up @@ -102,6 +113,7 @@ def _initialize_weights(self):
nn.init.constant_(m.bias, 0)

def forward(self, x):
# type: (Tensor) -> GoogLeNetOutputs
if self.transform_input:
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
Expand All @@ -128,17 +140,22 @@ def forward(self, x):
# N x 480 x 14 x 14
x = self.inception4a(x)
# N x 512 x 14 x 14
if self.training and self.aux_logits:
aux_defined = self.training and self.aux_logits
if aux_defined:
aux1 = self.aux1(x)
else:
aux1 = None

x = self.inception4b(x)
# N x 512 x 14 x 14
x = self.inception4c(x)
# N x 512 x 14 x 14
x = self.inception4d(x)
# N x 528 x 14 x 14
if self.training and self.aux_logits:
if aux_defined:
aux2 = self.aux2(x)
else:
aux2 = None

x = self.inception4e(x)
# N x 832 x 14 x 14
Expand All @@ -156,12 +173,24 @@ def forward(self, x):
x = self.dropout(x)
x = self.fc(x)
# N x 1000 (num_classes)
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
return GoogLeNetOutputs(x, aux2, aux1)
else:
return self.eager_outputs(x, aux2, aux1)

@torch.jit.unused
def eager_outputs(self, x, aux2, aux1):
# type: (Tensor, Optional[Tensor], Optional[Tensor]) -> GoogLeNetOutputs
if self.training and self.aux_logits:
return _GoogLeNetOutputs(x, aux2, aux1)
return x
else:
return x


class Inception(nn.Module):
__constants__ = ['branch2', 'branch3', 'branch4']

def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
super(Inception, self).__init__()
Expand Down
30 changes: 26 additions & 4 deletions torchvision/models/inception.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
from __future__ import division

from collections import namedtuple
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit.annotations import Optional
from .utils import load_state_dict_from_url


__all__ = ['Inception3', 'inception_v3']
__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']


model_urls = {
# Inception v3 ported from TensorFlow
'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}

_InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs.__annotations__ = {'logits': torch.Tensor, 'aux_logits': Optional[torch.Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
_InceptionOutputs = InceptionOutputs


def inception_v3(pretrained=False, progress=True, **kwargs):
Expand Down Expand Up @@ -128,8 +137,11 @@ def forward(self, x):
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
if self.training and self.aux_logits:
aux_defined = self.training and self.aux_logits
if aux_defined:
aux = self.AuxLogits(x)
else:
aux = None
# N x 768 x 17 x 17
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
Expand All @@ -146,8 +158,18 @@ def forward(self, x):
# N x 2048
x = self.fc(x)
# N x 1000 (num_classes)
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted InceptionNet always returns InceptionOutputs Tuple")
return InceptionOutputs(x, aux)
else:
return self.eager_outputs(x, aux)

@torch.jit.unused
def eager_outputs(self, x, aux):
# type: (torch.Tensor, Optional[torch.Tensor]) -> InceptionOutputs
if self.training and self.aux_logits:
return _InceptionOutputs(x, aux)
return InceptionOutputs(x, aux)
return x


Expand Down