Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Googlnet & InceptionNet scriptable #1349

Merged
merged 12 commits into from
Sep 27, 2019
32 changes: 15 additions & 17 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,20 @@ def get_available_video_models():
return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


# model_name, expected to script without error
torchub_models = {
"deeplabv3_resnet101": True,
"mobilenet_v2": True,
"resnext50_32x4d": True,
"fcn_resnet101": True,
"googlenet": False,
"densenet121": True,
"resnet18": True,
"alexnet": True,
"shufflenet_v2_x1_0": True,
"squeezenet1_0": True,
"vgg11": True,
"inception_v3": False,
}

torchub_models = [
"deeplabv3_resnet101",
"mobilenet_v2",
"resnext50_32x4d",
"fcn_resnet101",
"googlenet",
"densenet121",
"resnet18",
"alexnet",
"shufflenet_v2_x1_0",
"squeezenet1_0",
"vgg11",
"inception_v3",
]

class Tester(unittest.TestCase):
def check_script(self, model, name):
Expand All @@ -55,7 +53,7 @@ def check_script(self, model, name):
tb = traceback.format_exc()
scriptable = False
msg = str(e) + str(tb)
self.assertEqual(torchub_models[name], scriptable, msg)
self.assertEqual(scriptable, scriptable, msg)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this change is what you want to do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since they are all scriptable now, we don't need to access torchub_models and just assert that scriptable is True

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you then make

self.assertEqual(scriptable, True, msg)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh oops - good catch


def _test_classification_model(self, name, input_shape):
# passing num_class equal to a number other than 1000 helps in making the test
Expand Down
28 changes: 25 additions & 3 deletions torchvision/models/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit.annotations import Optional
from torch import Tensor
from .utils import load_state_dict_from_url

__all__ = ['GoogLeNet', 'googlenet']
Expand All @@ -13,6 +15,8 @@
}

_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
_GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor],
'aux_logits1': Optional[Tensor]}


def googlenet(pretrained=False, progress=True, **kwargs):
Expand Down Expand Up @@ -51,6 +55,7 @@ def googlenet(pretrained=False, progress=True, **kwargs):


class GoogLeNet(nn.Module):
__constants__ = ['aux_logits', 'transform_input']

def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
super(GoogLeNet, self).__init__()
Expand Down Expand Up @@ -128,17 +133,22 @@ def forward(self, x):
# N x 480 x 14 x 14
x = self.inception4a(x)
# N x 512 x 14 x 14
if self.training and self.aux_logits:
aux_defined = self.training and self.aux_logits
if aux_defined:
aux1 = self.aux1(x)
else:
aux1 = None

x = self.inception4b(x)
# N x 512 x 14 x 14
x = self.inception4c(x)
# N x 512 x 14 x 14
x = self.inception4d(x)
# N x 528 x 14 x 14
if self.training and self.aux_logits:
if aux_defined:
aux2 = self.aux2(x)
else:
aux2 = None

x = self.inception4e(x)
# N x 832 x 14 x 14
Expand All @@ -156,12 +166,24 @@ def forward(self, x):
x = self.dropout(x)
x = self.fc(x)
# N x 1000 (num_classes)
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
return _GoogLeNetOutputs(x, aux2, aux1)
else:
return self.eager_outputs(x, aux2, aux1)

@torch.jit.unused
def eager_outputs(self, x, aux2, aux1):
# type: (Tensor, Optional[Tensor], Optional[Tensor]) -> GoogLeNetOutputs # noqa: 177
if self.training and self.aux_logits:
return _GoogLeNetOutputs(x, aux2, aux1)
return x
else:
return x


class Inception(nn.Module):
__constants__ = ['branch2', 'branch3', 'branch4']

def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
super(Inception, self).__init__()
Expand Down
18 changes: 17 additions & 1 deletion torchvision/models/inception.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from collections import namedtuple
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit.annotations import Optional
from .utils import load_state_dict_from_url


Expand All @@ -14,6 +16,7 @@
}

_InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
_InceptionOutputs.__annotations__ = {'logits': torch.Tensor, 'aux_logits': Optional[torch.Tensor]}


def inception_v3(pretrained=False, progress=True, **kwargs):
Expand Down Expand Up @@ -128,8 +131,11 @@ def forward(self, x):
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
if self.training and self.aux_logits:
aux_defined = self.training and self.aux_logits
if aux_defined:
aux = self.AuxLogits(x)
else:
aux = None
# N x 768 x 17 x 17
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
Expand All @@ -146,6 +152,16 @@ def forward(self, x):
# N x 2048
x = self.fc(x)
# N x 1000 (num_classes)
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted InceptionNet always returns InceptionOutputs Tuple")
return _InceptionOutputs(x, aux)
else:
return self.eager_outputs(x, aux)

@torch.jit.unused
def eager_outputs(self, x, aux):
# type: (torch.Tensor, Optional[torch.Tensor]) -> _InceptionOutputs
if self.training and self.aux_logits:
return _InceptionOutputs(x, aux)
return x
Expand Down