Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Add KeypointEndPointError #92

Merged
merged 7 commits into from
Feb 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/en/api/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ Metrics
ConnectivityError
DOTAMeanAP
ROUGE
KeypointEndPointError
1 change: 1 addition & 0 deletions docs/zh_cn/api/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ Metrics
ConnectivityError
DOTAMeanAP
ROUGE
KeypointEndPointError
3 changes: 2 additions & 1 deletion mmeval/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .f1_score import F1Score
from .gradient_error import GradientError
from .hmean_iou import HmeanIoU
from .keypoint_epe import KeypointEndPointError
from .mae import MeanAbsoluteError
from .matting_mse import MattingMeanSquaredError
from .mean_iou import MeanIoU
Expand All @@ -36,7 +37,7 @@
'StructuralSimilarity', 'SignalNoiseRatio', 'MultiLabelMetric',
'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP',
'SumAbsoluteDifferences', 'GradientError', 'MattingMeanSquaredError',
'ConnectivityError', 'ROUGE'
'ConnectivityError', 'ROUGE', 'KeypointEndPointError'
]

_deprecated_msg = (
Expand Down
115 changes: 115 additions & 0 deletions mmeval/metrics/keypoint_epe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import numpy as np
from typing import Dict, Sequence

from mmeval.core.base_metric import BaseMetric
from .utils import calc_distances

logger = logging.getLogger(__name__)


def keypoint_epe_accuracy(pred: np.ndarray, gt: np.ndarray,
mask: np.ndarray) -> float:
"""Calculate the end-point error.

Note:
- instance number: N
- keypoint number: K

Args:
pred (np.ndarray[N, K, 2]): Predicted keypoint location.
gt (np.ndarray[N, K, 2]): Groundtruth keypoint location.
mask (np.ndarray[N, K]): Visibility of the target. False for invisible
joints, and True for visible. Invisible joints will be ignored for
accuracy calculation.

Returns:
float: Average end-point error.
"""

distances = calc_distances(
pred, gt, mask,
np.ones((pred.shape[0], pred.shape[2]), dtype=np.float32))
distance_valid = distances[distances != -1]
return distance_valid.sum() / max(1, len(distance_valid))


class KeypointEndPointError(BaseMetric):
"""EPE evaluation metric.

Calculate the end-point error (EPE) of keypoints.

Note:
- length of dataset: N
- num_keypoints: K
- number of keypoint dimensions: D (typically D = 2)

Examples:

>>> from mmeval.metrics import KeypointEndPointError
>>> import numpy as np
>>> output = np.array([[[10., 4.],
... [10., 18.],
... [ 0., 0.],
... [40., 40.],
... [20., 10.]]])
>>> target = np.array([[[10., 0.],
... [10., 10.],
... [ 0., -1.],
... [30., 30.],
... [ 0., 10.]]])
>>> keypoints_visible = np.array([[True, True, False, True, True]])
>>> predictions = [{'coords': output}]
>>> groundtruths = [{'coords': target, 'mask': keypoints_visible}]
>>> epe_metric = KeypointEndPointError()
>>> epe_metric(predictions, groundtruths)
{'EPE': 11.535533905029297}
"""

def add(self, predictions: Sequence[Dict], groundtruths: Sequence[Dict]) -> None: # type: ignore # yapf: disable # noqa: E501
"""Process one batch of predictions and groundtruths and add the
intermediate results to `self._results`.

Args:
predictions (Sequence[dict]): Predictions from the model.
Each prediction dict has the following keys:

- coords (np.ndarray, [1, K, D]): predicted keypoints
coordinates
C1rN09 marked this conversation as resolved.
Show resolved Hide resolved

groundtruths (Sequence[dict]): The ground truth labels.
Each groundtruth dict has the following keys:

- coords (np.ndarray, [1, K, D]): ground truth keypoints
coordinates
- mask (np.ndarray, [1, K]): ground truth keypoints_visible
"""
for prediction, groundtruth in zip(predictions, groundtruths):
self._results.append((prediction, groundtruth))

def compute_metric(self, results: list) -> Dict[str, float]:
"""Compute the metrics from processed results.

Args:
results (list): The processed results of each batch.

Returns:
Dict[str, float]: The computed metrics. The keys are the names of
the metrics, and the values are corresponding results.
"""
# split gt and prediction list
preds, gts = zip(*results)

# pred_coords: [N, K, D]
pred_coords = np.concatenate([pred['coords'] for pred in preds])
# gt_coords: [N, K, D]
gt_coords = np.concatenate([gt['coords'] for gt in gts])
# mask: [N, K]
mask = np.concatenate([gt['mask'] for gt in gts])

logger.info(f'Evaluating {self.__class__.__name__}...')

epe = keypoint_epe_accuracy(pred_coords, gt_coords, mask)

return {'EPE': epe}
66 changes: 66 additions & 0 deletions tests/test_metrics/test_keypoint_epe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from unittest import TestCase

from mmeval.metrics import KeypointEndPointError


class TestKeypointEndPointError(TestCase):

def setUp(self):
"""Setup some variables which are used in every test method.

TestCase calls functions in this order: setUp() -> testMethod() ->
tearDown() -> cleanUp()
"""
self.output = np.zeros((1, 5, 2))
self.target = np.zeros((1, 5, 2))
# first channel
self.output[0, 0] = [10, 4]
self.target[0, 0] = [10, 0]
# second channel
self.output[0, 1] = [10, 18]
self.target[0, 1] = [10, 10]
# third channel
self.output[0, 2] = [0, 0]
self.target[0, 2] = [0, -1]
# fourth channel
self.output[0, 3] = [40, 40]
self.target[0, 3] = [30, 30]
# fifth channel
self.output[0, 4] = [20, 10]
self.target[0, 4] = [0, 10]

self.keypoints_visible = np.array([[True, True, False, True, True]])

def test_epe_evaluate(self):
"""test EPE evaluation metric."""
# case 1: test normal use case
epe_metric = KeypointEndPointError()

prediction = {'coords': self.output}
groundtruth = {'coords': self.target, 'mask': self.keypoints_visible}
predictions = [prediction]
groundtruths = [groundtruth]

epe_results = epe_metric(predictions, groundtruths)
self.assertAlmostEqual(epe_results['EPE'], 11.5355339)

# case 2: use ``add`` multiple times then ``compute``
epe_metric._results = []
preds1 = [{'coords': self.output[:3]}]
preds2 = [{'coords': self.output[3:]}]
gts1 = [{
'coords': self.target[:3],
'mask': self.keypoints_visible[:3]
}]
gts2 = [{
'coords': self.target[3:],
'mask': self.keypoints_visible[3:]
}]

epe_metric.add(preds1, gts1)
epe_metric.add(preds2, gts2)

epe_results = epe_metric.compute_metric(epe_metric._results)
self.assertAlmostEqual(epe_results['EPE'], 11.5355339)