Skip to content

Commit

Permalink
make googlnet scriptable
Browse files Browse the repository at this point in the history
  • Loading branch information
eellison authored and fmassa committed Sep 20, 2019
1 parent 85ffd93 commit 46dd08c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
2 changes: 1 addition & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_available_video_models():
"mobilenet_v2": True,
"resnext50_32x4d": True,
"fcn_resnet101": True,
"googlenet": False,
"googlenet": True,
"densenet121": True,
"resnet18": True,
"alexnet": True,
Expand Down
33 changes: 26 additions & 7 deletions torchvision/models/googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from .utils import load_state_dict_from_url

__all__ = ['GoogLeNet', 'googlenet']
__all__ = ['GoogLeNet', 'googlenet', "_GoogLeNetOutputs"]

model_urls = {
# GoogLeNet ported from TensorFlow
'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
}

_GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])

_GoogLeNetOutputs.__annotations__ = {'logits': torch.Tensor, 'aux_logits2': Optional[torch.Tensor],
'aux_logits1': Optional[torch.Tensor]}

def googlenet(pretrained=False, progress=True, **kwargs):
r"""GoogLeNet (Inception v1) model architecture from
Expand Down Expand Up @@ -51,6 +53,7 @@ def googlenet(pretrained=False, progress=True, **kwargs):


class GoogLeNet(nn.Module):
__constants__ = ['aux_logits', 'transform_input']

def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
super(GoogLeNet, self).__init__()
Expand Down Expand Up @@ -101,6 +104,14 @@ def _initialize_weights(self):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

@torch.jit.unused
def eager_outputs(self, x, aux2, aux1):
# type: (Tensor, Optional[Tensor], Optional[Tensor]) -> _GoogLeNetOutputs
if self.training and self.aux_logits:
return _GoogLeNetOutputs(x, aux2, aux1)
else:
return x

def forward(self, x):
if self.transform_input:
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
Expand Down Expand Up @@ -128,17 +139,22 @@ def forward(self, x):
# N x 480 x 14 x 14
x = self.inception4a(x)
# N x 512 x 14 x 14
if self.training and self.aux_logits:
aux_defined = self.training and self.aux_logits
if aux_defined:
aux1 = self.aux1(x)
else:
aux1 = None

x = self.inception4b(x)
# N x 512 x 14 x 14
x = self.inception4c(x)
# N x 512 x 14 x 14
x = self.inception4d(x)
# N x 528 x 14 x 14
if self.training and self.aux_logits:
if aux_defined:
aux2 = self.aux2(x)
else:
aux2 = None

x = self.inception4e(x)
# N x 832 x 14 x 14
Expand All @@ -156,12 +172,15 @@ def forward(self, x):
x = self.dropout(x)
x = self.fc(x)
# N x 1000 (num_classes)
if self.training and self.aux_logits:
if torch.jit.is_scripting():
if not aux_defined:
warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
return _GoogLeNetOutputs(x, aux2, aux1)
return x

else:
return self.eager_outputs(x, aux2, aux1)

class Inception(nn.Module):
__constants__ = ['branch2', 'branch3', 'branch4']

def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
super(Inception, self).__init__()
Expand Down

0 comments on commit 46dd08c

Please sign in to comment.