Skip to content

Commit 68f6f50

Browse files
committed
Add CaffeResnet101
1 parent f46394a commit 68f6f50

File tree

5 files changed

+194
-2
lines changed

5 files changed

+194
-2
lines changed

README.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ The goal of this repo is:
66
- to access pretrained ConvNets with a unique interface/API inspired by torchvision.
77

88
News:
9+
- 22/03/2018: CaffeResNet101 (good for localization with FasterRCNN)
910
- 21/03/2018: NASNet Mobile thanks to [Veronika Yurchuk](https://github.com/veronikayurchuk) and [Anastasiia](https://github.com/DagnyT)
1011
- 25/01/2018: DualPathNetworks thanks to [Ross Wightman](https://github.com/rwightman/pytorch-dpn-pretrained), Xception thanks to [T Standley](https://github.com/tstandley/Xception-PyTorch), improved TransformImage API
1112
- 13/01/2018: `pip install pretrainedmodels`, `pretrainedmodels.model_names`, `pretrainedmodels.pretrained_settings`
@@ -32,6 +33,7 @@ News:
3233
- [Available models](https://github.com/Cadene/pretrained-models.pytorch#available-models)
3334
- [AlexNet](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
3435
- [BNInception](https://github.com/Cadene/pretrained-models.pytorch#bninception)
36+
- [CaffeResNet101](https://github.com/Cadene/pretrained-models.pytorch#caffe-resnet)
3537
- [DenseNet121](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
3638
- [DenseNet161](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
3739
- [DenseNet169](https://github.com/Cadene/pretrained-models.pytorch#torchvision)
@@ -42,6 +44,7 @@ News:
4244
- [DualPathNet98](https://github.com/Cadene/pretrained-models.pytorch#dualpathnetworks)
4345
- [DualPathNet107](https://github.com/Cadene/pretrained-models.pytorch#dualpathnetworks)
4446
- [DualPathNet113](https://github.com/Cadene/pretrained-models.pytorch#dualpathnetworks)
47+
- [FBResNet152](https://github.com/Cadene/pretrained-models.pytorch#facebook-resnet)
4548
- [InceptionResNetV2](https://github.com/Cadene/pretrained-models.pytorch#inception)
4649
- [InceptionV3](https://github.com/Cadene/pretrained-models.pytorch#inception)
4750
- [InceptionV4](https://github.com/Cadene/pretrained-models.pytorch#inception)
@@ -107,7 +110,7 @@ import pretrainedmodels
107110

108111
```python
109112
print(pretrainedmodels.model_names)
110-
> ['fbresnet152', 'bninception', 'resnext101_32x4d', 'resnext101_64x4d', 'inceptionv4', 'inceptionresnetv2', 'alexnet', 'densenet121', 'densenet169', 'densenet201', 'densenet161', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'inceptionv3', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19', 'nasnetalarge', 'nasnetamobile']
113+
> ['fbresnet152', 'bninception', 'resnext101_32x4d', 'resnext101_64x4d', 'inceptionv4', 'inceptionresnetv2', 'alexnet', 'densenet121', 'densenet169', 'densenet201', 'densenet161', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'inceptionv3', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19', 'nasnetalarge', 'nasnetamobile', 'cafferesnet101']
111114
```
112115

113116
- To print the available pretrained settings for a chosen model:
@@ -215,6 +218,8 @@ FBResNet152 | [Torch7](https://github.com/facebook/fb.resnet.torch) | 77.84 | 93
215218
[InceptionV3](https://github.com/Cadene/pretrained-models.pytorch#inception) | [Pytorch](https://github.com/pytorch/vision#models) | 77.294 | 93.454
216219
[DenseNet201](https://github.com/Cadene/pretrained-models.pytorch#torchvision) | [Pytorch](https://github.com/pytorch/vision#models) | 77.152 | 93.548
217220
[DualPathNet68b_5k](https://github.com/Cadene/pretrained-models.pytorch#dualpathnetworks) | Our porting | 77.034 | 93.590
221+
[CaffeResnet101](https://github.com/Cadene/pretrained-models.pytorch#caffe-resnet) | [Caffe](https://github.com/KaimingHe/deep-residual-networks) | 76.400 | 92.900
222+
[CaffeResnet101](https://github.com/Cadene/pretrained-models.pytorch#caffe-resnet) | Our porting | 76.200 | 92.766
218223
[DenseNet169](https://github.com/Cadene/pretrained-models.pytorch#torchvision) | [Pytorch](https://github.com/pytorch/vision#models) | 76.026 | 92.992
219224
[ResNet50](https://github.com/Cadene/pretrained-models.pytorch#torchvision) | [Pytorch](https://github.com/pytorch/vision#models) | 76.002 | 92.980
220225
[DualPathNet68](https://github.com/Cadene/pretrained-models.pytorch#dualpathnetworks) | Our porting | 75.868 | 92.774
@@ -265,6 +270,13 @@ There are a bit different from the ResNet* of torchvision. ResNet152 is currentl
265270

266271
- `fbresnet152(num_classes=1000, pretrained='imagenet')`
267272

273+
#### Caffe ResNet*
274+
275+
Source: [Caffe repo of KaimingHe](https://github.com/KaimingHe/deep-residual-networks)
276+
277+
- `cafferesnet101(num_classes=1000, pretrained='imagenet')`
278+
279+
268280
#### Inception*
269281

270282
Source: [TensorFlow Slim repo](https://github.com/tensorflow/models/tree/master/slim) and [Pytorch/Vision repo](https://github.com/pytorch/vision/tree/master/torchvision) for `inceptionv3`

pretrainedmodels/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# to support pretrainedmodels.__dict__['nasnetalarge']
99
# but depreciated
1010
from .models.fbresnet import fbresnet152
11+
from .models.cafferesnet import cafferesnet101
1112
from .models.bninception import bninception
1213
from .models.resnext import resnext101_32x4d
1314
from .models.resnext import resnext101_64x4d

pretrainedmodels/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from .fbresnet import fbresnet152
22

3+
from .cafferesnet import cafferesnet101
4+
35
from .bninception import bninception
46

57
from .resnext import resnext101_32x4d
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
import math
2+
import torch
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
import torch.utils.model_zoo as model_zoo
6+
7+
pretrained_settings = {
8+
'cafferesnet101': {
9+
'imagenet': {
10+
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/cafferesnet101-9cf32c75.pth',
11+
'input_space': 'BGR',
12+
'input_size': [3, 224, 224],
13+
'input_range': [0, 255],
14+
'mean': [102.9801, 115.9465, 122.7717],
15+
'std': [1, 1, 1],
16+
'num_classes': 1000
17+
}
18+
}
19+
}
20+
21+
22+
def conv3x3(in_planes, out_planes, stride=1):
23+
"3x3 convolution with padding"
24+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
25+
padding=1, bias=False)
26+
27+
28+
class BasicBlock(nn.Module):
29+
expansion = 1
30+
31+
def __init__(self, inplanes, planes, stride=1, downsample=None):
32+
super(BasicBlock, self).__init__()
33+
self.conv1 = conv3x3(inplanes, planes, stride)
34+
self.bn1 = nn.BatchNorm2d(planes)
35+
self.relu = nn.ReLU(inplace=True)
36+
self.conv2 = conv3x3(planes, planes)
37+
self.bn2 = nn.BatchNorm2d(planes)
38+
self.downsample = downsample
39+
self.stride = stride
40+
41+
def forward(self, x):
42+
residual = x
43+
44+
out = self.conv1(x)
45+
out = self.bn1(out)
46+
out = self.relu(out)
47+
48+
out = self.conv2(out)
49+
out = self.bn2(out)
50+
51+
if self.downsample is not None:
52+
residual = self.downsample(x)
53+
54+
out += residual
55+
out = self.relu(out)
56+
57+
return out
58+
59+
60+
class Bottleneck(nn.Module):
61+
expansion = 4
62+
63+
def __init__(self, inplanes, planes, stride=1, downsample=None):
64+
super(Bottleneck, self).__init__()
65+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change
66+
self.bn1 = nn.BatchNorm2d(planes)
67+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change
68+
padding=1, bias=False)
69+
self.bn2 = nn.BatchNorm2d(planes)
70+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
71+
self.bn3 = nn.BatchNorm2d(planes * 4)
72+
self.relu = nn.ReLU(inplace=True)
73+
self.downsample = downsample
74+
self.stride = stride
75+
76+
def forward(self, x):
77+
residual = x
78+
79+
out = self.conv1(x)
80+
out = self.bn1(out)
81+
out = self.relu(out)
82+
83+
out = self.conv2(out)
84+
out = self.bn2(out)
85+
out = self.relu(out)
86+
87+
out = self.conv3(out)
88+
out = self.bn3(out)
89+
90+
if self.downsample is not None:
91+
residual = self.downsample(x)
92+
93+
out += residual
94+
out = self.relu(out)
95+
96+
return out
97+
98+
99+
class ResNet(nn.Module):
100+
101+
def __init__(self, block, layers, num_classes=1000):
102+
self.inplanes = 64
103+
super(ResNet, self).__init__()
104+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
105+
bias=False)
106+
self.bn1 = nn.BatchNorm2d(64)
107+
self.relu = nn.ReLU(inplace=True)
108+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
109+
self.layer1 = self._make_layer(block, 64, layers[0])
110+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
111+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
112+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
113+
# it is slightly better whereas slower to set stride = 1
114+
# self.layer4 = self._make_layer(block, 512, layers[3], stride=1)
115+
self.avgpool = nn.AvgPool2d(7)
116+
self.fc = nn.Linear(512 * block.expansion, num_classes)
117+
118+
for m in self.modules():
119+
if isinstance(m, nn.Conv2d):
120+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
121+
m.weight.data.normal_(0, math.sqrt(2. / n))
122+
elif isinstance(m, nn.BatchNorm2d):
123+
m.weight.data.fill_(1)
124+
m.bias.data.zero_()
125+
126+
def _make_layer(self, block, planes, blocks, stride=1):
127+
downsample = None
128+
if stride != 1 or self.inplanes != planes * block.expansion:
129+
downsample = nn.Sequential(
130+
nn.Conv2d(self.inplanes, planes * block.expansion,
131+
kernel_size=1, stride=stride, bias=False),
132+
nn.BatchNorm2d(planes * block.expansion),
133+
)
134+
135+
layers = []
136+
layers.append(block(self.inplanes, planes, stride, downsample))
137+
self.inplanes = planes * block.expansion
138+
for i in range(1, blocks):
139+
layers.append(block(self.inplanes, planes))
140+
141+
return nn.Sequential(*layers)
142+
143+
def forward(self, x):
144+
x = self.conv1(x)
145+
x = self.bn1(x)
146+
x = self.relu(x)
147+
x = self.maxpool(x)
148+
149+
x = self.layer1(x)
150+
x = self.layer2(x)
151+
x = self.layer3(x)
152+
x = self.layer4(x)
153+
154+
x = self.avgpool(x)
155+
x = x.view(x.size(0), -1)
156+
x = self.fc(x)
157+
158+
return x
159+
160+
161+
def cafferesnet101(num_classes=1000, pretrained='imagenet'):
162+
"""Constructs a ResNet-101 model.
163+
Args:
164+
pretrained (bool): If True, returns a model pre-trained on ImageNet
165+
"""
166+
model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes)
167+
if pretrained is not None:
168+
settings = pretrained_settings['cafferesnet101'][pretrained]
169+
assert num_classes == settings['num_classes'], \
170+
"num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
171+
model.load_state_dict(model_zoo.load_url(settings['url']))
172+
model.input_space = settings['input_space']
173+
model.input_size = settings['input_size']
174+
model.input_range = settings['input_range']
175+
model.mean = settings['mean']
176+
model.std = settings['std']
177+
return model

pretrainedmodels/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.3.0'
1+
__version__ = '0.4.0'

0 commit comments

Comments
 (0)