@@ -18,7 +18,7 @@ def fitness(x):
18
18
return (x [:, :4 ] * w ).sum (1 )
19
19
20
20
21
- def ap_per_class (tp , conf , pred_cls , target_cls , plot = False , save_dir = '.' , names = ()):
21
+ def ap_per_class (tp , conf , pred_cls , target_cls , plot = False , save_dir = '.' , names = (), eps = 1e-16 ):
22
22
""" Compute the average precision, given the recall and precision curves.
23
23
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
24
24
# Arguments
@@ -37,15 +37,15 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
37
37
tp , conf , pred_cls = tp [i ], conf [i ], pred_cls [i ]
38
38
39
39
# Find unique classes
40
- unique_classes = np .unique (target_cls )
40
+ unique_classes , nt = np .unique (target_cls , return_counts = True )
41
41
nc = unique_classes .shape [0 ] # number of classes, number of detections
42
42
43
43
# Create Precision-Recall curve and compute AP for each class
44
44
px , py = np .linspace (0 , 1 , 1000 ), [] # for plotting
45
45
ap , p , r = np .zeros ((nc , tp .shape [1 ])), np .zeros ((nc , 1000 )), np .zeros ((nc , 1000 ))
46
46
for ci , c in enumerate (unique_classes ):
47
47
i = pred_cls == c
48
- n_l = ( target_cls == c ). sum () # number of labels
48
+ n_l = nt [ ci ] # number of labels
49
49
n_p = i .sum () # number of predictions
50
50
51
51
if n_p == 0 or n_l == 0 :
@@ -56,7 +56,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
56
56
tpc = tp [i ].cumsum (0 )
57
57
58
58
# Recall
59
- recall = tpc / (n_l + 1e-16 ) # recall curve
59
+ recall = tpc / (n_l + eps ) # recall curve
60
60
r [ci ] = np .interp (- px , - conf [i ], recall [:, 0 ], left = 0 ) # negative x, xp because xp decreases
61
61
62
62
# Precision
@@ -70,7 +70,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
70
70
py .append (np .interp (px , mrec , mpre )) # precision at mAP@0.5
71
71
72
72
# Compute F1 (harmonic mean of precision and recall)
73
- f1 = 2 * p * r / (p + r + 1e-16 )
73
+ f1 = 2 * p * r / (p + r + eps )
74
74
names = [v for k , v in names .items () if k in unique_classes ] # list: only classes that have data
75
75
names = {i : v for i , v in enumerate (names )} # to dict
76
76
if plot :
@@ -80,7 +80,10 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
80
80
plot_mc_curve (px , r , Path (save_dir ) / 'R_curve.png' , names , ylabel = 'Recall' )
81
81
82
82
i = f1 .mean (0 ).argmax () # max F1 index
83
- return p [:, i ], r [:, i ], ap , f1 [:, i ], unique_classes .astype ('int32' )
83
+ p , r , f1 = p [:, i ], r [:, i ], f1 [:, i ]
84
+ tp = (r * nt ).round () # true positives
85
+ fp = (tp / (p + eps ) - tp ).round () # false positives
86
+ return tp , fp , p , r , f1 , ap , unique_classes .astype ('int32' )
84
87
85
88
86
89
def compute_ap (recall , precision ):
@@ -162,6 +165,12 @@ def process_batch(self, detections, labels):
162
165
def matrix (self ):
163
166
return self .matrix
164
167
168
+ def tp_fp (self ):
169
+ tp = self .matrix .diagonal () # true positives
170
+ fp = self .matrix .sum (1 ) - tp # false positives
171
+ # fn = self.matrix.sum(0) - tp # false negatives (missed detections)
172
+ return tp [:- 1 ], fp [:- 1 ] # remove background class
173
+
165
174
def plot (self , normalize = True , save_dir = '' , names = ()):
166
175
try :
167
176
import seaborn as sn
0 commit comments