Skip to content

Commit 4d615c3

Browse files
authored
Explicitly compute TP, FP in val.py (ultralytics#5727)
1 parent cb4673c commit 4d615c3

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

utils/metrics.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def fitness(x):
1818
return (x[:, :4] * w).sum(1)
1919

2020

21-
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
21+
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16):
2222
""" Compute the average precision, given the recall and precision curves.
2323
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
2424
# Arguments
@@ -37,15 +37,15 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
3737
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
3838

3939
# Find unique classes
40-
unique_classes = np.unique(target_cls)
40+
unique_classes, nt = np.unique(target_cls, return_counts=True)
4141
nc = unique_classes.shape[0] # number of classes, number of detections
4242

4343
# Create Precision-Recall curve and compute AP for each class
4444
px, py = np.linspace(0, 1, 1000), [] # for plotting
4545
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
4646
for ci, c in enumerate(unique_classes):
4747
i = pred_cls == c
48-
n_l = (target_cls == c).sum() # number of labels
48+
n_l = nt[ci] # number of labels
4949
n_p = i.sum() # number of predictions
5050

5151
if n_p == 0 or n_l == 0:
@@ -56,7 +56,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
5656
tpc = tp[i].cumsum(0)
5757

5858
# Recall
59-
recall = tpc / (n_l + 1e-16) # recall curve
59+
recall = tpc / (n_l + eps) # recall curve
6060
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
6161

6262
# Precision
@@ -70,7 +70,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
7070
py.append(np.interp(px, mrec, mpre)) # precision at mAP@0.5
7171

7272
# Compute F1 (harmonic mean of precision and recall)
73-
f1 = 2 * p * r / (p + r + 1e-16)
73+
f1 = 2 * p * r / (p + r + eps)
7474
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
7575
names = {i: v for i, v in enumerate(names)} # to dict
7676
if plot:
@@ -80,7 +80,10 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
8080
plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')
8181

8282
i = f1.mean(0).argmax() # max F1 index
83-
return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
83+
p, r, f1 = p[:, i], r[:, i], f1[:, i]
84+
tp = (r * nt).round() # true positives
85+
fp = (tp / (p + eps) - tp).round() # false positives
86+
return tp, fp, p, r, f1, ap, unique_classes.astype('int32')
8487

8588

8689
def compute_ap(recall, precision):
@@ -162,6 +165,12 @@ def process_batch(self, detections, labels):
162165
def matrix(self):
163166
return self.matrix
164167

168+
def tp_fp(self):
169+
tp = self.matrix.diagonal() # true positives
170+
fp = self.matrix.sum(1) - tp # false positives
171+
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
172+
return tp[:-1], fp[:-1] # remove background class
173+
165174
def plot(self, normalize=True, save_dir='', names=()):
166175
try:
167176
import seaborn as sn

val.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def run(data,
237237
# Compute metrics
238238
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
239239
if len(stats) and stats[0].any():
240-
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
240+
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
241241
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
242242
mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
243243
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class

0 commit comments

Comments
 (0)