# Last modification: May 24th 2018 #
###########################################################################################

+import sys
+from collections import Counter
+
+import matplotlib.pyplot as plt
+import numpy as np
+
from BoundingBox import *
from BoundingBoxes import *
-import matplotlib.pyplot as plt
-from collections import Counter
from utils import *
-import numpy as np
-import sys


class Evaluator:
-    def GetPascalVOCMetrics(self, boundingboxes, IOUThreshold=0.5):
+    def GetPascalVOCMetrics(self,
+                            boundingboxes,
+                            IOUThreshold=0.5,
+                            method=MethodAveragePrecision.EveryPointInterpolation):
"""Get the metrics used by the VOC Pascal 2012 challenge.
23
28
Get
24
29
Args:
25
30
boundingboxes: Object of the class BoundingBoxes representing ground truth and detected
26
31
bounding boxes;
27
32
IOUThreshold: IOU threshold indicating which detections will be considered TP or FP
28
- (default value = 0.5).
33
+ (default value = 0.5);
34
+ method (default = EveryPointInterpolation): It can be calculated as the implementation
35
+ in the official PASCAL VOC toolkit (EveryPointInterpolation), or applying the 11-point
36
+ interpolatio as described in the paper "The PASCAL Visual Object Classes(VOC) Challenge"
37
+ or EveryPointInterpolation" (ElevenPointInterpolation);
29
38
Returns:
30
39
A list of dictionaries. Each dictionary contains information and metrics of each class.
31
40
The keys of each dictionary are:
@@ -112,7 +121,11 @@ def GetPascalVOCMetrics(self, boundingboxes, IOUThreshold=0.5):
            acc_TP = np.cumsum(TP)
            rec = acc_TP / npos
            prec = np.divide(acc_TP, (acc_FP + acc_TP))
-            [ap, mpre, mrec, ii] = Evaluator.CalculateAveragePrecision(rec, prec)
+            # Depending on the method, call the right implementation
+            if method == MethodAveragePrecision.EveryPointInterpolation:
+                [ap, mpre, mrec, ii] = Evaluator.CalculateAveragePrecision(rec, prec)
+            else:
+                [ap, mpre, mrec, _] = Evaluator.ElevenPointInterpolatedAP(rec, prec)
            # add class result in the dictionary to be returned
            r = {
                'class': c,
@@ -132,6 +145,7 @@ def PlotPrecisionRecallCurve(self,
                                 classId,
                                 boundingBoxes,
                                 IOUThreshold=0.5,
+                                 method=MethodAveragePrecision.EveryPointInterpolation,
                                 showAP=False,
                                 showInterpolatedPrecision=False,
                                 savePath=None,
@@ -144,6 +158,10 @@ def PlotPrecisionRecallCurve(self,
            bounding boxes;
            IOUThreshold (optional): IOU threshold indicating which detections will be considered
            TP or FP (default value = 0.5);
+            method (default = EveryPointInterpolation): It can be calculated as the implementation
+            in the official PASCAL VOC toolkit (EveryPointInterpolation), or applying the 11-point
+            interpolation as described in the paper "The PASCAL Visual Object Classes (VOC)
+            Challenge" (ElevenPointInterpolation);
            showAP (optional): if True, the average precision value will be shown in the title of
            the graph (default = False);
            showInterpolatedPrecision (optional): if True, it will show in the plot the interpolated
@@ -164,7 +182,7 @@ def PlotPrecisionRecallCurve(self,
            dict['total TP']: total number of True Positive detections;
            dict['total FP']: total number of False Positive detections;
        """
-        results = self.GetPascalVOCMetrics(boundingBoxes, IOUThreshold)
+        results = self.GetPascalVOCMetrics(boundingBoxes, IOUThreshold, method)
        result = None
        for res in results:
            if res['class'] == classId:
@@ -178,10 +196,64 @@ def PlotPrecisionRecallCurve(self,
        average_precision = result['AP']
        mpre = result['interpolated precision']
        mrec = result['interpolated recall']
+        # npos = result['total positives']
+        # total_tp = result['total TP']
+        # total_fp = result['total FP']
+
+        if showInterpolatedPrecision:
+            if method == MethodAveragePrecision.EveryPointInterpolation:
+                plt.plot(mrec, mpre, '--r', label='Interpolated precision (every point)')
+            elif method == MethodAveragePrecision.ElevenPointInterpolation:
+                # Uncomment the line below if you want to plot the area
+                # plt.plot(mrec, mpre, 'or', label='11-point interpolated precision')
+                # Remove duplicates, getting only the highest precision of each recall value
+                nrec = []
+                nprec = []
+                for idx in range(len(mrec)):
+                    r = mrec[idx]
+                    if r not in nrec:
+                        idxEq = np.argwhere(np.array(mrec) == r)
+                        nrec.append(r)
+                        nprec.append(max([mpre[int(id)] for id in idxEq]))
+                plt.plot(nrec, nprec, 'or', label='11-point interpolated precision')
+        plt.plot(recall, precision, label='Precision')
+        plt.xlabel('recall')
+        plt.ylabel('precision')
+        if showAP:
+            ap_str = "{0:.2f}%".format(average_precision * 100)
+            plt.title('Precision x Recall curve \nClass: %s, AP: %s' % (str(classId), ap_str))
+            # plt.title('Precision x Recall curve \nClass: %s, AP: %.4f' % (str(classId),
+            # average_precision))
+        else:
+            plt.title('Precision x Recall curve \nClass: %s' % str(classId))
+        plt.legend(shadow=True)
+        plt.grid()
+        plt.show()
+
+    def PlotPrecisionRecallCurve2(self,
+                                  classId,
+                                  boundingBoxes,
+                                  IOUThreshold=0.5,
+                                  showAP=False,
+                                  showInterpolatedPrecision=False,
+                                  savePath=None,
+                                  showGraphic=True):
+        results = self.GetPascalVOCMetrics(boundingBoxes, IOUThreshold)
+        result = None
+        for res in results:
+            if res['class'] == classId:
+                result = res
+                break
+        if result is None:
+            raise IOError('Error: Class %s could not be found.' % str(classId))
+        precision = result['precision']
+        recall = result['recall']
+        average_precision = result['AP']
+        mpre = result['interpolated precision']
+        mrec = result['interpolated recall']
        npos = result['total positives']
        total_tp = result['total TP']
        total_fp = result['total FP']
-
        if showInterpolatedPrecision:
            plt.plot(mrec, mpre, '--r', label='Interpolated precision')
        plt.plot(recall, precision, label='Precision')
@@ -288,6 +360,56 @@ def CalculateAveragePrecision(rec, prec):
        # return [ap, mpre[1:len(mpre)-1], mrec[1:len(mpre)-1], ii]
        return [ap, mpre[0:len(mpre) - 1], mrec[0:len(mpre) - 1], ii]

+    @staticmethod
+    # 11-point interpolated average precision
+    def ElevenPointInterpolatedAP(rec, prec):
+        # def CalculateAveragePrecision2(rec, prec):
+        mrec = []
+        # mrec.append(0)
+        [mrec.append(e) for e in rec]
+        # mrec.append(1)
+        mpre = []
+        # mpre.append(0)
+        [mpre.append(e) for e in prec]
+        # mpre.append(0)
+        recallValues = np.linspace(0, 1, 11)
+        recallValues = list(recallValues[::-1])
+        rhoInterp = []
+        recallValid = []
+        # For each recallValue (0, 0.1, 0.2, ..., 1)
+        for r in recallValues:
+            # Obtain all recall values higher or equal than r
+            argGreaterRecalls = np.argwhere(np.array(mrec) >= r)
+            pmax = 0
+            # If there are recalls above r
+            if argGreaterRecalls.size != 0:
+                pmax = max(mpre[argGreaterRecalls.min():])
+            recallValid.append(r)
+            rhoInterp.append(pmax)
+        # By definition AP = sum(max(precision whose recall is above r))/11
+        ap = sum(rhoInterp) / 11
+        # Generating values for the plot
+        rvals = []
+        rvals.append(recallValid[0])
+        [rvals.append(e) for e in recallValid]
+        rvals.append(0)
+        pvals = []
+        pvals.append(0)
+        [pvals.append(e) for e in rhoInterp]
+        pvals.append(0)
+        # rhoInterp = rhoInterp[::-1]
+        # Pair each recall with both the previous and the current precision so the
+        # plotted curve is drawn as horizontal steps, skipping duplicate points
+        cc = []
+        for i in range(len(rvals)):
+            p = (rvals[i], pvals[i - 1])
+            if p not in cc:
+                cc.append(p)
+            p = (rvals[i], pvals[i])
+            if p not in cc:
+                cc.append(p)
+        recallValues = [i[0] for i in cc]
+        rhoInterp = [i[1] for i in cc]
+        return [ap, rhoInterp, recallValues, None]
+
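+    # A quick worked illustration of the 11-point AP above (toy values, not from
+    # the toolkit): with rec = [0.5, 1.0] and prec = [1.0, 0.5], the sampled recall
+    # points 0, 0.1, ..., 0.5 take interpolated precision 1.0 and 0.6, ..., 1.0
+    # take 0.5, so AP = (6 * 1.0 + 5 * 0.5) / 11 ≈ 0.7727.
+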
    # For each detection, calculate IOU with the reference
    @staticmethod
    def _getAllIOUs(reference, detections):
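
A minimal usage sketch of the API extended by this commit, assuming a populated
BoundingBoxes collection built by the caller (here called allBoundingBoxes, a
hypothetical name) and the MethodAveragePrecision enum made available through
"from utils import *":

    evaluator = Evaluator()
    # Per-class metrics with the newly added 11-point interpolation option
    metricsPerClass = evaluator.GetPascalVOCMetrics(
        allBoundingBoxes,
        IOUThreshold=0.5,
        method=MethodAveragePrecision.ElevenPointInterpolation)
    for mc in metricsPerClass:
        print('%s: AP = %.4f' % (mc['class'], mc['AP']))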