from utilities import convert_YOLO_to_center_coords, iou, _iou, im2PIL, draw_detections, build_class_names
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
_DEVICE_ = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_EPS_ = 1e-8 #keeps sqrt inputs strictly positive: the gradient of torch.sqrt at 0 is infinite, which returned NaN in older versions of torch
def normalised_to_global(A, width=448, height=448):
"""
    Converts bounding boxes from centre-normalised coordinates (values between 0 and 1)
    to global centre coordinates (in pixels, with respect to the image width and height).
    A has size N*5, where N is the number of bounding boxes and each bounding
    box is in the format <x> <y> <w> <h> <class>.
    - returns A, modified in place, the same size as the input
"""
A[:,0] = A[:,0] * width
A[:,1] = A[:,1] * height
A[:,2] = A[:,2] * width
A[:,3] = A[:,3] * height
return A
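# Worked example (hypothetical values): a box centred on a 448x448 image with a
# quarter-width, half-height extent. Note that A is modified in place.
# >>> A = torch.tensor([[0.5, 0.5, 0.25, 0.5, 11.0]])
# >>> normalised_to_global(A)
# tensor([[224., 224., 112., 224.,  11.]])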
def cell_to_global(A, im_size=448, stride=64, B=2, S=7):
"""
    Receives a tensor of size N*(B*5) holding each cell's bounding box
    predictions, where each box is in the format <x> <y> <w> <h> <conf> encoded
    in the YOLO format (normalised relative to the grid cell), N is the total
    number of grid cells i.e. S*S, and B is the number of boxes per cell.
    It converts each cell's boxes into coordinates with respect to the global image.
    - returns a tensor of size N*(B*5), the same size as the input, where the
      <x> <y> <w> <h> <conf> in each cell are measured in pixels of the global
      image. The boxes remain in centre (not corner) format.
"""
rng = np.arange(S) # the range of possible grid coords
cols, rows = np.meshgrid(rng, rng)
#create a grid with each cell containing the (x,y) location multiplied by stride
rows = torch.FloatTensor(rows).view(-1,1)
cols = torch.FloatTensor(cols).view(-1,1)
grid = torch.cat((rows,cols),1) * stride #now the grid system is measured relative to the image
grid = grid.to(_DEVICE_)
bboxes = torch.split(A, A.size(1)//B, 1) #split the N*10 bboxes into a tuple of (Nx5,Nx5) assuming B=2
res = []
for v in bboxes: # v would be a chunk from the tuple `bboxes`, of size N*5
# convert the <x> <y> and <w> <h> cell coordinates to global image coordinates
        # multiplying by `stride` and adding `grid` takes predictions relative to a cell
        # and transforms them into predictions in global image coordinates
v[:,:2] = (v[:,:2] * stride).round() + grid
v[:,2:4] = (torch.pow(v[:,2:4].clone().detach(),2) * im_size).round()
res.append(v)
res = torch.cat(res,1).to(_DEVICE_)
return res
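# Worked example (a minimal sketch, hypothetical values; S=7, stride=64): a box
# predicted by the top-left cell at its centre (x=y=0.5) with sqrt-encoded extent
# 0.25 maps to centre (32, 32) and width/height 0.25^2 * 448 = 28 image pixels.
# >>> A = torch.zeros(49, 10).to(_DEVICE_)
# >>> A[0, :5] = torch.tensor([0.5, 0.5, 0.25, 0.25, 0.9])
# >>> cell_to_global(A)[0, :5]
# tensor([32.0000, 32.0000, 28.0000, 28.0000,  0.9000])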
def box(output, target, size=448, B=2):
"""
    Selects the bounding box to use for the loss calculation at each grid cell: the box with the
    highest intersection over union with the ground truth if the cell contains an object, otherwise
    the box with the highest confidence.
    Receives `output`, a prediction of size SxSx(B*5+C) where each cell is in the format
    <x> <y> <w> <h> <conf> | <x> <y> <w> <h> <conf> | <cls probs>
    (assuming B=2 bounding boxes), and `target`, a ground truth of size SxSx5 in the format
    <x> <y> <w> <h> <cls>.
    - return bbox: Tensor of size SxSx(5+C) where each selected box is in the
      format <x> <y> <w> <h> <confidence> <cls probs>.
    Note the returned boxes keep the cell-normalised YOLO encoding of `output`;
    `cell_to_global` and `normalised_to_global` are only applied to detached copies
    so that the IoU comparison can be done in global image coordinates.
"""
#Reshape the output tensor into (S*S)x(B*5+C) to make it easier to work with
sz = output.size()
output = output.view(sz[0] * sz[1], -1) #e.g 49x30
pred_bboxes = output[:,:B*5] #slice out only the bounding boxes e.g 49x10
pred_classes = output[:,B*5:] #slice out the pred classes e.g 49x20
target = target.view(sz[0] * sz[1], -1) #e.g 49x5
# The `*_global` variables are needed for IoU calculations
pred_bboxes_global = cell_to_global(pred_bboxes.clone().detach(), B=B) #e.g 49x10
target_global = normalised_to_global(target.clone().detach()) #e.g 49*5
num_classes = output.size(1) - (B*5)
R = torch.zeros(output.size(0),5+num_classes) #result to return. e.g it is of size 49x25
for i in range(output.size(0)): #loop over each cell coordinate
# `bboxes` will be a tuple of size B (e.g 2), where each elem is 1*5
bboxes = torch.split(pred_bboxes[i,:], pred_bboxes.size(1)//B)
bboxes = torch.stack(bboxes)
bboxes_global = torch.split(pred_bboxes_global[i,:], pred_bboxes.size(1)//B)
bboxes_global = torch.stack(bboxes_global)
"""
In the case where there is a ground truth tensor at the current grid cell,
the predicted bounding box with the highest intersection over union to the
ground truth is chosen.
If there is no ground truth prediction at the current cell, just pick the
bounding box with the highest confidence
"""
#case 1: There is a ground truth prediction at this cell i
if target[i].sum() > 0:#select the box with the highest intersection over union
repeated_target = target_global[i].clone().detach().repeat(bboxes.size(0),1)
jac_idx = _iou(bboxes_global, repeated_target)
max_iou_idx = torch.argmax(jac_idx)
R[i,:5] = bboxes[max_iou_idx,:]
else: #select the box with the highest confidence
highest_conf_idx = torch.argmax(bboxes[:,4])
R[i,:5] = bboxes[highest_conf_idx,:]
#Add the predicted class confidence to the results
R[i,5:] = pred_classes[i]
return R.view(sz[0], sz[1], -1)
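# Minimal usage sketch (hypothetical shapes; S=7, B=2, C=20): one prediction grid
# against one target grid yields the 7x7x25 tensor consumed by `criterion`.
# >>> out = torch.rand(7, 7, 30).to(_DEVICE_)
# >>> tgt = torch.zeros(7, 7, 5).to(_DEVICE_)
# >>> box(out, tgt).shape
# torch.Size([7, 7, 25])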
def criterion(output, target, lambda_coord=5, lambda_noobj=0.5):
"""
    Computes the average YOLO loss between the output and target batch tensors.
    - `output` is of size NxSxSx(B*5+C), where B is the number of bounding boxes,
      encoded in the YOLO format
    - `target` is of size NxSxSx5 (generated in `batch_collate_fn`) in the format
      <x> <y> <w> <h> <class>,
      not encoded in the YOLO format but normalised with respect to the image
"""
batch_loss = torch.tensor(0).float().to(_DEVICE_)
for idx, out_tensor in enumerate(output):
best_boxes = box(out_tensor, target[idx], B=2) #e.g 7x7x(5+20)
sz = best_boxes.size()
P = best_boxes.view(sz[0] * sz[1], -1).to(_DEVICE_) #e.g 49x25
G = target[idx].view(sz[0] * sz[1], -1).to(_DEVICE_) #e.g 49x5
image_loss = torch.tensor(0).float().to(_DEVICE_)
for i in range(P.size(0)): #loop over each cell coordinate
if G[i].sum() > 0: #there is a ground truth prediction at this cell
pred_cls = P[i,5:]
true_cls = torch.zeros(pred_cls.size()).to(_DEVICE_)
true_cls[int(G[i,4])] = 1
# grid cell regression loss
grid_loss = lambda_coord * torch.pow(P[i,0:2] - G[i,0:2], 2).sum() \
+ lambda_coord * torch.pow(torch.sqrt(P[i,2:4] + _EPS_) - torch.sqrt(G[i,2:4] + _EPS_),2).sum() \
+ torch.pow(P[i,4] - 1,2) \
+ torch.pow(pred_cls - true_cls, 2).sum() # class probability loss
else:
grid_loss = lambda_noobj * torch.pow(P[i,4] - 0,2) #confidence should be zero
            image_loss += grid_loss
batch_loss += image_loss
avg_batch_loss = batch_loss/output.size(0)
return avg_batch_loss
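# Minimal usage sketch (hypothetical shapes; N=3, S=7, B=2, C=20):
# >>> preds = torch.rand(3, 7, 7, 30).to(_DEVICE_)
# >>> targets = torch.zeros(3, 7, 7, 5).to(_DEVICE_)
# >>> criterion(preds, targets)   # scalar tensor holding the batch-averaged loss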
class TestNet(nn.Module):
def __init__(self):
super(TestNet,self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(3, 64, 3, 1, 0),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 128, 3, 1, 0,bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, 3, 1, 0),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 512, 5, 1, 0 ),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1,1))
)
self.linear = nn.Sequential(
nn.Linear(512, 1024, True),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(1024, 1470),
nn.Sigmoid()
)
def forward(self, x):
x = self.conv(x)
x = x.view(x.size()[0], -1)
x = self.linear(x)
x = x.view(-1, 7, 7, 30)
return x
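# Shape sanity check (a minimal sketch): the head emits 1470 = 7*7*(2*5+20) values
# per image, reshaped to the SxSx(B*5+C) grid that `criterion` expects.
# >>> net = TestNet()
# >>> net(torch.randn(1, 3, 448, 448)).shape
# torch.Size([1, 7, 7, 30])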
if __name__ == "__main__":
# Test the loss function
# class_names = build_class_names("./voc.names")
# images = torch.randn(1, 3, 448, 448)
# detections = torch.rand(1, 1, 5)
# detections[:,:,0] *= 20
# x = im2PIL(images[0])
# x = draw_detections(x, detections[0], class_names)
# x.show()
torch.autograd.set_detect_anomaly(True)
"""
    This test checks that a simple model can overfit to a random batch of
    images within about 20 epochs, i.e. the loss should decrease steadily
"""
net = TestNet()
optimiser = torch.optim.Adam(net.parameters(), lr=1e-5)
X = torch.randn(3, 3, 448, 448)
"""
Ground truth prediction format
- It is not in global image coordinates
- The structure of each cell is <x> <y> <w> <h> <class>
- It is in the center normalised form where x,y,w,h have been divided by the image
width and height (not encoded in the YOLO format)
"""
Y = torch.rand(3,7,7,5)
Y[:,:,:,:4] = torch.clamp(Y[:,:,:,:4], 0, 1) #already in [0,1) from torch.rand; not encoded in the YOLO format
Y[:,:,:,4] = (Y[:,:,:,4] * 20).floor()
#zero out some cells to represent no ground truth data at those locations
mask = torch.empty(7,7).uniform_(0,1)
mask = torch.bernoulli(mask).bool()
Y[:,mask,:] = 0
losses = []
for i in range(20):
optimiser.zero_grad()
Y_ = net(X)
loss = criterion(Y_, Y)
print(f"Test Epoch {i}: Loss = {loss.item()}")
loss.backward()
optimiser.step()
losses.append(loss.item())
print(f"Finished testing model and Loss function")
plt.plot(losses)
plt.show()