# Feel free to change / extend / adapt this source code as needed to complete the homework, based on its requirements.
# This code is given as a starting point.
#
# REFERENCES
# The code is partly adapted from PyTorch tutorials, including https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# ---- hyper-parameters ----
# You should tune these hyper-parameters using:
# (i) your reasoning and observations,
# (ii) by tuning them on the validation set, using the techniques discussed in class.
# You can definitely add more hyper-parameters here.
batch_size = 16
max_num_epoch = 100
hps = {'lr':0.001}
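# Example (illustrative only): further hyper-parameters you might add and tune on the
# validation set. The names below are hypothetical choices, not given by the assignment:
# hps = {'lr': 0.001, 'weight_decay': 1e-5, 'num_kernels': 8}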
# ---- options ----
DEVICE_ID = 'cpu' # set to 'cpu' for cpu, 'cuda' / 'cuda:0' or similar for gpu.
LOG_DIR = 'checkpoints'
VISUALIZE = False # set True to visualize the input, prediction and target from the last batch
LOAD_CHKPT = False
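# Optional sketch: instead of hard-coding the device, it can be picked automatically
# (after the imports below), e.g.:
# DEVICE_ID = 'cuda' if torch.cuda.is_available() else 'cpu'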
# --- imports ---
import torch
import os
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import hw3utils
torch.multiprocessing.set_start_method('spawn', force=True)
# ---- utility functions -----
def get_loaders(batch_size,device):
    data_root = 'ceng483-hw3-dataset'
    train_set = hw3utils.HW3ImageFolder(root=os.path.join(data_root,'train'),device=device)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    val_set = hw3utils.HW3ImageFolder(root=os.path.join(data_root,'val'),device=device)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)
    # Note: you may later add test_loader to here.
    return train_loader, val_loader
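
# Optional sketch (not part of the given template): a test loader could be built the same
# way as the loaders above. This assumes the dataset also provides a 'test' split under
# data_root, which is an assumption; adjust the folder name to match your data.
def get_test_loader(batch_size, device):
    data_root = 'ceng483-hw3-dataset'
    test_set = hw3utils.HW3ImageFolder(root=os.path.join(data_root,'test'), device=device)
    return torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)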
# ---- ConvNet -----
class Net(nn.Module):
    def __init__(self, num_kernels=3):
        super(Net, self).__init__()
        # Number of kernels for convolutional layers (except the last one)
        self.num_kernels = num_kernels
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, num_kernels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_kernels, num_kernels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(num_kernels, num_kernels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(num_kernels, num_kernels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_kernels, 3, kernel_size=3, stride=1, padding=1),
            # No activation for the last conv layer
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
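
# Optional sanity check (a sketch, commented out): the stride-2 convolution halves the
# spatial size and the bilinear upsample doubles it back, so for an even-sized grayscale
# input the output is a 3-channel image of the same size. The 80x80 size below is an
# assumption about the dataset; adjust it if your images differ.
# _net = Net()
# print(_net(torch.randn(1, 1, 80, 80)).shape)  # expected: torch.Size([1, 3, 80, 80])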
# ---- training code -----
device = torch.device(DEVICE_ID)
print('device: ' + str(device))
net = Net().to(device=device)
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=hps['lr'])
train_loader, val_loader = get_loaders(batch_size,device)
if LOAD_CHKPT:
    print('loading the model from the checkpoint')
    net.load_state_dict(torch.load(os.path.join(LOG_DIR,'checkpoint.pt'), map_location=device))
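# Optional sketch: to resume training exactly, you could also checkpoint the optimizer
# state, e.g. torch.save({'model': net.state_dict(), 'optim': optimizer.state_dict()}, path)
# and later restore both from torch.load(path). The dictionary keys here are illustrative.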
print('training begins')
for epoch in range(max_num_epoch):
    running_loss = 0.0 # training loss of the network
    for iteri, data in enumerate(train_loader, 0):
        inputs, targets = data # inputs: grayscale images, targets: the corresponding color images.

        optimizer.zero_grad() # zero the parameter gradients

        # do forward, backward, SGD step
        preds = net(inputs)
        loss = criterion(preds, targets)
        loss.backward()
        optimizer.step()

        # print loss
        running_loss += loss.item()
        print_n = 100 # feel free to change this constant
        if iteri % print_n == (print_n-1): # print every print_n mini-batches
            print('[%d, %5d] network-loss: %.3f' %
                  (epoch + 1, iteri + 1, running_loss / print_n))
            running_loss = 0.0

        # note: you most probably want to track the progress on the validation set as well
        # (needs to be implemented; see the sketch at the end of this epoch loop below)
        if (iteri==0) and VISUALIZE:
            hw3utils.visualize_batch(inputs,preds,targets)

    print('Saving the model, end of epoch %d' % (epoch+1))
    if not os.path.exists(LOG_DIR):
        os.makedirs(LOG_DIR)
    torch.save(net.state_dict(), os.path.join(LOG_DIR,'checkpoint.pt'))
    hw3utils.visualize_batch(inputs,preds,targets,os.path.join(LOG_DIR,'example.png'))
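
    # Optional sketch of the validation tracking mentioned above: average MSE over the
    # validation set at the end of each epoch. This is one straightforward way to do it,
    # not necessarily the required one.
    net.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_inputs, val_targets in val_loader:
            val_loss += criterion(net(val_inputs), val_targets).item()
    net.train()
    print('[%d] validation-loss: %.3f' % (epoch + 1, val_loss / len(val_loader)))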
print('Finished Training')