Skip to content

Commit 0eb7b70

Browse files
committed
feat: Add GoogLeNet model implementation
1 parent b5a8c28 commit 0eb7b70

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed
1.24 MB
Loading
214 KB
Loading

projects/models/googlenet/main.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
5+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6+
7+
8+
class GoogLeNet(nn.Module):
9+
def __init__(self, in_channels=3, num_classes=1000):
10+
super(GoogLeNet, self).__init__()
11+
self.conv1 = Conv(in_channels, 64, kernel_size=7, stride=2, padding=3)
12+
self.conv2 = Conv(64, 192, kernel_size=3, stride=1, padding=1)
13+
self.incp3 = nn.Sequential(
14+
Inception(192, 64, 96, 128, 16, 32, 32),
15+
Inception(256, 128, 128, 192, 32, 96, 64),
16+
)
17+
self.incp4 = nn.Sequential(
18+
Inception(480, 192, 96, 208, 16, 48, 64),
19+
Inception(512, 160, 112, 224, 24, 64, 64),
20+
Inception(512, 128, 128, 256, 24, 64, 64),
21+
Inception(512, 112, 144, 288, 32, 64, 64),
22+
Inception(528, 256, 160, 320, 32, 128, 128),
23+
)
24+
self.incp5 = nn.Sequential(
25+
Inception(832, 256, 160, 320, 32, 128, 128),
26+
Inception(832, 384, 192, 384, 48, 128, 128),
27+
)
28+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
29+
self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
30+
self.dropout = nn.Dropout(p=0.4)
31+
self.fc = nn.Linear(1024, 1000)
32+
33+
def forward(self, x):
34+
out = self.maxpool(self.conv1(x))
35+
out = self.maxpool(self.conv2(out))
36+
out = self.maxpool(self.incp3(out))
37+
out = self.maxpool(self.incp4(out))
38+
out = self.avgpool(self.incp5(out))
39+
out = self.dropout(out)
40+
out = out.reshape(out.shape[0], -1)
41+
out = self.fc(out)
42+
return out
43+
44+
45+
class Inception(nn.Module):
46+
def __init__(
47+
self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5,
48+
out_poolproj,
49+
):
50+
super(Inception, self).__init__()
51+
self.branch1 = Conv(
52+
in_channels, out_1x1, kernel_size=1, stride=1, padding=0,
53+
)
54+
self.branch2 = nn.Sequential(
55+
Conv(in_channels, red_3x3, kernel_size=1, stride=1, padding=0),
56+
Conv(red_3x3, out_3x3, kernel_size=3, stride=1, padding=1),
57+
)
58+
self.branch3 = nn.Sequential(
59+
Conv(in_channels, red_5x5, kernel_size=1, stride=1, padding=0),
60+
Conv(red_5x5, out_5x5, kernel_size=5, stride=1, padding=2),
61+
)
62+
self.branch4 = nn.Sequential(
63+
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
64+
Conv(in_channels, out_poolproj, kernel_size=1, stride=1, padding=0)
65+
)
66+
67+
def forward(self, x):
68+
out = torch.cat([
69+
self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)
70+
], dim=1)
71+
return out
72+
73+
74+
class Conv(nn.Module):
75+
def __init__(self, in_channels, out_channels, **kwargs):
76+
super(Conv, self).__init__()
77+
self.conv = nn.Sequential(
78+
nn.Conv2d(in_channels, out_channels, **kwargs),
79+
nn.BatchNorm2d(out_channels),
80+
nn.ReLU(),
81+
)
82+
83+
def forward(self, x):
84+
out = self.conv(x)
85+
return out
86+
87+
88+
x = torch.randn((4, 3, 224, 224), dtype=torch.float32).to(device)
89+
model = GoogLeNet().to(device)
90+
out = model(x)
91+
print(f"out.shape = {out.shape}")

0 commit comments

Comments
 (0)