-
Notifications
You must be signed in to change notification settings - Fork 646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added PAT metric #659
Added PAT metric #659
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,6 +52,10 @@ def __init__(self, | |
self.intersects = {} | ||
self.intersects_ovr = {} | ||
|
||
# PAT Tracking stuff. | ||
self.instance_preds = {} | ||
self.instance_gts = {} | ||
|
||
# Per-class association quality stuff. | ||
self.pan_aq = np.zeros(self.n_classes, dtype=np.double) | ||
self.pan_aq_ovr = 0.0 | ||
|
@@ -129,6 +133,48 @@ def get_panoptic_track_stats(self, | |
unique_combo_, counts_combo_ = np.unique(offset_combo_, return_counts=True) | ||
self.update_dict_stat(cl_intersects, unique_combo_, counts_combo_) | ||
|
||
# Computation for PAT score | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you help to correct the return typing to |
||
# Computes unique gt instances and its number of points > self.min_points | ||
unique_gt_, counts_gt_ = np.unique(y_inst_in_cl[y_inst_in_cl > 0], return_counts=True) | ||
id2idx_gt_ = {inst_id: idx for idx, inst_id in enumerate(unique_gt_)} | ||
# Computes unique pred instances (class-agnotstic) and its number of points | ||
unique_pred_, counts_pred_ = np.unique(x_inst_row[x_inst_row > 0], return_counts=True) | ||
id2idx_pred_ = {inst_id: idx for idx, inst_id in enumerate(unique_pred_)} | ||
# Actually unique_combo_ = pred_labels_ + self.offset * gt_labels_ | ||
gt_labels_ = unique_combo_ // self.offset | ||
pred_labels_ = unique_combo_ % self.offset | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps some comments on what these two lines are doing?:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added the comment. Let me know if it is intuitive. |
||
gt_areas_ = np.array([counts_gt_[id2idx_gt_[g_id]] for g_id in gt_labels_]) | ||
pred_areas_ = np.array([counts_pred_[id2idx_pred_[p_id]] for p_id in pred_labels_]) | ||
# Here counts_combo_ : TP (point-level) | ||
intersections_ = counts_combo_ | ||
# Here gt_areas_ : TP + FN, pred_areas_ : TP + FP (point-level) | ||
# Overall unions_ : TP + FP + FN (point-level) | ||
unions_ = gt_areas_ + pred_areas_ - intersections_ | ||
# IoU : TP / (TP + FP + FN) | ||
ious_agnostic = intersections_.astype(np.float32) / unions_.astype(np.float32) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could unions be 0? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No. gt_areas_ can never be 0. |
||
# tp_indexes_agnostic : TP (instance-level, IoU > 0.5) | ||
tp_indexes_agnostic = ious_agnostic > 0.5 | ||
matched_gt_ = np.array([False] * len(id2idx_gt_)) | ||
matched_gt_[[id2idx_gt_[g_id] for g_id in gt_labels_[tp_indexes_agnostic]]] = True | ||
|
||
# Stores matched tracks (the corresponding class-agnostic predicted instance) for the unique gt instances: | ||
for idx, value in enumerate(tp_indexes_agnostic): | ||
if value: | ||
g_label = gt_labels_[idx] | ||
p_label = pred_labels_[idx] | ||
if g_label not in self.instance_gts[scene][cl]: | ||
self.instance_gts[scene][cl][g_label] = [p_label,] | ||
else: | ||
self.instance_gts[scene][cl][g_label].append(p_label) | ||
|
||
# Stores unmatched tracks for the unique gt instances: assigns 1 for no match | ||
for g_label in unique_gt_: | ||
if not matched_gt_[id2idx_gt_[g_label]]: | ||
if g_label not in self.instance_gts[scene][cl]: | ||
self.instance_gts[scene][cl][g_label] = [1,] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you double check the indentations, here has 5 spaces, i see many places have wrong indentations. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
else: | ||
self.instance_gts[scene][cl][g_label].append(1) | ||
|
||
# Generate an intersection map, count the intersections with over 0.5 IoU as TP. | ||
gt_labels = unique_combo // self.offset | ||
pred_labels = unique_combo % self.offset | ||
|
@@ -160,7 +206,9 @@ def add_batch_panoptic(self, | |
self.gts[scene] = [{} for _ in range(self.n_classes)] | ||
self.intersects[scene] = [{} for _ in range(self.n_classes)] | ||
self.intersects_ovr[scene] = [{} for _ in range(self.n_classes)] | ||
# Make sure instance IDs are non-zeros. Otherwise, they will be ignored. Note in nuScenes-panoptic, | ||
self.instance_preds[scene] = {} | ||
self.instance_gts[scene] = [{} for _ in range(self.n_classes)] | ||
# Make sure instance IDs are non-zeros. Otherwise, they will be ignored. Note in Panoptic nuScenes, | ||
# instance IDs start from 1 already, so the following 2 lines of code are actually not necessary, but to be | ||
# consistent with the PanopticEval class in panoptic_seg_evaluator.py from 3rd party. We keep these 2 lines. It | ||
# means the actual instance IDs will start from 2 during metrics evaluation. | ||
|
@@ -186,6 +234,14 @@ def add_batch_panoptic(self, | |
x_inst_row[0] = x_inst_row[0][gt_not_in_excl_mask] | ||
y_inst_row[0] = y_inst_row[0][gt_not_in_excl_mask] | ||
|
||
# Accumulate class-agnostic predictions | ||
unique_pred_, counts_pred_ = np.unique(x_inst_row[1][x_inst_row[1] > 0], return_counts=True) | ||
for p_id in unique_pred_[counts_pred_ > self.min_points]: | ||
if p_id not in self.instance_preds[scene]: | ||
self.instance_preds[scene][p_id] = 1 | ||
else: | ||
self.instance_preds[scene][p_id] += 1 | ||
|
||
# First step is to count intersections > 0.5 IoU for each class (except the ignored ones). | ||
for cl in self.include: | ||
# Previous Frame. | ||
|
@@ -336,6 +392,82 @@ def get_lstq(self) -> Tuple[np.ndarray, np.ndarray]: | |
lstq = np.sqrt(s_assoc * s_cls) | ||
return lstq, s_assoc | ||
|
||
def get_pat(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | ||
""" | ||
Calculate Panoptic Tracking (PAT) metric. https://arxiv.org/pdf/2109.03805.pdf | ||
:return: (PAT, mean_PQ, mean_TQ). | ||
PAT: <float64, 1>, PAT score over all classes. | ||
mean_PQ: <float64, 1>, mean PQ scores over all classes. | ||
mean_TQ: <float64, 1>, mean TQ score over all classes. | ||
""" | ||
# First calculate for all classes | ||
sq_all = self.pan_iou.astype(np.double) / np.maximum(self.pan_tp.astype(np.double), self.eps) | ||
rq_all = self.pan_tp.astype(np.double) / np.maximum( | ||
self.pan_tp.astype(np.double) + 0.5 * self.pan_fp.astype(np.double) + 0.5 * self.pan_fn.astype(np.double), | ||
self.eps) | ||
pq_all = sq_all * rq_all | ||
|
||
# Then do the REAL mean (no ignored classes) | ||
pq = pq_all[self.include].mean() | ||
|
||
accumulate_tq = 0.0 | ||
accumlate_norm = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo: accumulate_norm |
||
|
||
for seq in self.sequences: | ||
preds = self.instance_preds[seq] | ||
for cl in self.include: | ||
cls_gts = self.instance_gts[seq][cl] | ||
for gt_id, pr_ids in cls_gts.items(): | ||
unique_pr_id, counts_pr_id = np.unique(pr_ids, return_counts=True) | ||
|
||
track_length = len(pr_ids) | ||
# void/stuff have instance value 1 due to the +1 in ln205 as well as unmatched gt is denoted by 1 | ||
# Thus we remove 1 from the prediction id list | ||
unique_pr_id, counts_pr_id = unique_pr_id[unique_pr_id != 1], counts_pr_id[unique_pr_id != 1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tks for the inline comment, sigh, this looks like a hack on top of another hack.. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yup |
||
fp_pr_id = [] | ||
|
||
# Computes the total false positve for each prediction id: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo: positive |
||
# preds[uid]: TPA + FPA (class-agnostic) | ||
# counts_pr_id[idx]: TPA (class-agnostic) | ||
# If prediction id is not in preds it means it has number of points < self.min_points. | ||
# Similar to PQ computation we consider pred with number of points < self.min_points with IoU overlap greater than 0.5 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. change line if it exceeds 120 line width |
||
# with gt as TPA but not for FPA (the else part). | ||
for idx, uid in enumerate(unique_pr_id): | ||
if uid in preds: | ||
fp_pr_id.append(preds[uid] - counts_pr_id[idx]) | ||
else: | ||
fp_pr_id.append(0) | ||
|
||
fp_pr_id = np.array(fp_pr_id) | ||
# AQ component of TQ where counts_pr_id = TPA, track_length = TPA + FNA, fp_pr_id = FPA. | ||
gt_id_aq = np.sum(counts_pr_id ** 2 / np.double(track_length + fp_pr_id)) / np.double(track_length) | ||
# Assigns ID switch component of TQ as 1.0 if the gt instance occurs only once. | ||
gt_id_is = 1.0 | ||
|
||
if track_length > 1: | ||
# Compute the ID switch component | ||
s_id = -1 | ||
ids = 0 | ||
# Total possible id switches | ||
total_ids = track_length - 1 | ||
# Gt tracks with no corresponding prediction match are assigned 1. | ||
# We consider an id switch occurs if previous predicted id and the current one don't match for the given gt track | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. exceed 120 line width There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
# or if there is no matching prediction for the given gt track | ||
for pr_id in pr_ids: | ||
if s_id != -1: | ||
if pr_id != s_id or s_id == 1: | ||
ids += 1 | ||
s_id = pr_id | ||
gt_id_is = 1-(ids/np.double(total_ids)) | ||
# Accumulate TQ over all the possible unique gt instances | ||
accumulate_tq += np.sqrt(gt_id_aq * gt_id_is) | ||
# Count the total number of unique gt instances | ||
accumlate_norm +=1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. minor: space after +=, similarly for line 461. |
||
# Normalization | ||
tq = np.array(accumulate_tq/accumlate_norm) | ||
pat = (2 * pq * tq) / (pq + tq) | ||
return pat, pq, tq | ||
|
||
def add_batch(self, scene: str, x_sem: List[np.ndarray], x_inst: List[np.ndarray], y_sem: List[np.ndarray], | ||
y_inst: List[np.ndarray]) -> None: | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
update the function doc-string part at line 203-210 to add the added PAT, PQ, TQ fields