Skip to content

Commit d009416

Browse files
committed
Update 14_cnn.py
1 parent 99d4e6a commit d009416

File tree

1 file changed

+85
-77
lines changed

1 file changed

+85
-77
lines changed

14_cnn.py

Lines changed: 85 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,24 @@
55
import torchvision.transforms as transforms
66
import matplotlib.pyplot as plt
77
import numpy as np
8+
import ssl
89

9-
# Device configuration
10-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10+
ssl._create_default_https_context = ssl._create_unverified_context
11+
12+
# GPU device configuration
13+
if torch.cuda.is_available():
14+
device = torch.device('cuda')
15+
print('Using GPU')
16+
elif torch.backends.mps.is_available():
17+
device = torch.device('mps')
18+
print('Using MPS')
19+
else:
20+
device = torch.device('cpu')
21+
print('Using CPU')
1122

1223
# Hyper-parameters
13-
num_epochs = 5
14-
batch_size = 4
24+
num_epochs = 20
25+
batch_size = 8
1526
learning_rate = 0.001
1627

1728
# dataset has PILImage images of range [0, 1].
@@ -22,109 +33,106 @@
2233

2334
# CIFAR10: 60000 32x32 color images in 10 classes, with 6000 images per class
2435
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
25-
download=True, transform=transform)
36+
download=True, transform=transform)
2637

2738
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
28-
download=True, transform=transform)
39+
download=True, transform=transform)
2940

3041
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
31-
shuffle=True)
42+
shuffle=True)
3243

3344
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
34-
shuffle=False)
45+
shuffle=False)
3546

36-
classes = ('plane', 'car', 'bird', 'cat',
37-
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
47+
classes = ('plane', 'car', 'bird', 'cat', 'deer',
48+
'dog', 'frog', 'horse', 'ship', 'truck')
3849

3950
def imshow(img):
40-
img = img / 2 + 0.5 # unnormalize
41-
npimg = img.numpy()
42-
plt.imshow(np.transpose(npimg, (1, 2, 0)))
43-
plt.show()
44-
51+
img = img / 2 + 0.5 # un-normalize
52+
npimg = img.numpy()
53+
plt.imshow(np.transpose(npimg, (1, 2, 0)))
54+
plt.show()
4555

4656
# get some random training images
4757
dataiter = iter(train_loader)
4858
images, labels = next(dataiter)
4959

5060
# show images
51-
imshow(torchvision.utils.make_grid(images))
61+
# imshow(torchvision.utils.make_grid(images))
5262

5363
class ConvNet(nn.Module):
54-
def __init__(self):
55-
super(ConvNet, self).__init__()
56-
self.conv1 = nn.Conv2d(3, 6, 5)
57-
self.pool = nn.MaxPool2d(2, 2)
58-
self.conv2 = nn.Conv2d(6, 16, 5)
59-
self.fc1 = nn.Linear(16 * 5 * 5, 120)
60-
self.fc2 = nn.Linear(120, 84)
61-
self.fc3 = nn.Linear(84, 10)
62-
63-
def forward(self, x):
64-
# -> n, 3, 32, 32
65-
x = self.pool(F.relu(self.conv1(x))) # -> n, 6, 14, 14
66-
x = self.pool(F.relu(self.conv2(x))) # -> n, 16, 5, 5
67-
x = x.view(-1, 16 * 5 * 5) # -> n, 400
68-
x = F.relu(self.fc1(x)) # -> n, 120
69-
x = F.relu(self.fc2(x)) # -> n, 84
70-
x = self.fc3(x) # -> n, 10
71-
return x
72-
64+
def __init__(self):
65+
super(ConvNet, self).__init__()
66+
self.conv1 = nn.Conv2d(3, 6, 5)
67+
self.pool = nn.MaxPool2d(2, 2)
68+
self.conv2 = nn.Conv2d(6, 16, 5)
69+
self.fc1 = nn.Linear(16 * 5 * 5, 120)
70+
self.fc2 = nn.Linear(120, 84)
71+
self.fc3 = nn.Linear(84, 10)
72+
73+
def forward(self, x):
74+
# -> n, 3, 32, 32
75+
x = self.pool(F.leaky_relu(self.conv1(x))) # -> n, 6, 14, 14
76+
x = self.pool(F.leaky_relu(self.conv2(x))) # -> n, 16, 5, 5
77+
x = x.view(-1, 16 * 5 * 5) # -> n, 400
78+
x = F.leaky_relu(self.fc1(x)) # -> n, 120
79+
x = F.leaky_relu(self.fc2(x)) # -> n, 84
80+
x = self.fc3(x) # -> n, 10
81+
return x
7382

7483
model = ConvNet().to(device)
7584

7685
criterion = nn.CrossEntropyLoss()
77-
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
86+
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
7887

7988
n_total_steps = len(train_loader)
8089
for epoch in range(num_epochs):
81-
for i, (images, labels) in enumerate(train_loader):
82-
# origin shape: [4, 3, 32, 32] = 4, 3, 1024
83-
# input_layer: 3 input channels, 6 output channels, 5 kernel size
84-
images = images.to(device)
85-
labels = labels.to(device)
90+
for i, (images, labels) in enumerate(train_loader):
91+
# origin shape: [4, 3, 32, 32] = 4, 3, 1024
92+
# input_layer: 3 input channels, 6 output channels, 5 kernel size
93+
images = images.to(device)
94+
labels = labels.to(device)
8695

87-
# Forward pass
88-
outputs = model(images)
89-
loss = criterion(outputs, labels)
96+
# Forward pass
97+
outputs = model(images)
98+
loss = criterion(outputs, labels)
9099

91-
# Backward and optimize
92-
optimizer.zero_grad()
93-
loss.backward()
94-
optimizer.step()
100+
# Backward and optimize
101+
optimizer.zero_grad()
102+
loss.backward()
103+
optimizer.step()
95104

96-
if (i+1) % 2000 == 0:
97-
print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
105+
if (i+1) % 2000 == 0:
106+
print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
98107

99108
print('Finished Training')
100109
PATH = './cnn.pth'
101110
torch.save(model.state_dict(), PATH)
102111

103112
with torch.no_grad():
104-
n_correct = 0
105-
n_samples = 0
106-
n_class_correct = [0 for i in range(10)]
107-
n_class_samples = [0 for i in range(10)]
108-
for images, labels in test_loader:
109-
images = images.to(device)
110-
labels = labels.to(device)
111-
outputs = model(images)
112-
# max returns (value ,index)
113-
_, predicted = torch.max(outputs, 1)
114-
n_samples += labels.size(0)
115-
n_correct += (predicted == labels).sum().item()
116-
117-
for i in range(batch_size):
118-
label = labels[i]
119-
pred = predicted[i]
120-
if (label == pred):
121-
n_class_correct[label] += 1
122-
n_class_samples[label] += 1
123-
124-
acc = 100.0 * n_correct / n_samples
125-
print(f'Accuracy of the network: {acc} %')
126-
127-
for i in range(10):
128-
acc = 100.0 * n_class_correct[i] / n_class_samples[i]
129-
print(f'Accuracy of {classes[i]}: {acc} %')
130-
113+
n_correct = 0
114+
n_samples = 0
115+
n_class_correct = [0 for i in range(10)]
116+
n_class_samples = [0 for i in range(10)]
117+
for images, labels in test_loader:
118+
images = images.to(device)
119+
labels = labels.to(device)
120+
outputs = model(images)
121+
# max returns (value ,index)
122+
_, predicted = torch.max(outputs, 1)
123+
n_samples += labels.size(0)
124+
n_correct += (predicted == labels).sum().item()
125+
126+
for i in range(batch_size):
127+
label = labels[i]
128+
pred = predicted[i]
129+
if (label == pred):
130+
n_class_correct[label] += 1
131+
n_class_samples[label] += 1
132+
133+
acc = 100.0 * n_correct / n_samples
134+
print(f'Accuracy of the network: {acc} %')
135+
136+
for i in range(10):
137+
acc = 100.0 * n_class_correct[i] / n_class_samples[i]
138+
print(f'Accuracy of {classes[i]}: {acc} %')

0 commit comments

Comments
 (0)