
Commit e2f9ac0

kzkadc and vfdev-5 authored
Add rank correlation metrics (#3276)
* add SpearmanRankCorrelation metric
* add KendallRankCorrelation metric
* add import check of scipy
* fix type hints
* fix formatting error
* minor modification to docstring

Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent e21baf9 commit e2f9ac0

6 files changed: +617 additions, -0 deletions

docs/source/metrics.rst

Lines changed: 2 additions & 0 deletions

@@ -377,6 +377,8 @@ Complete list of metrics
 regression.MedianAbsolutePercentageError
 regression.MedianRelativeAbsoluteError
 regression.PearsonCorrelation
+regression.SpearmanRankCorrelation
+regression.KendallRankCorrelation
 regression.R2Score
 regression.WaveHedgesDistance

ignite/metrics/regression/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -3,6 +3,7 @@
 from ignite.metrics.regression.fractional_bias import FractionalBias
 from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError
 from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError
+from ignite.metrics.regression.kendall_correlation import KendallRankCorrelation
 from ignite.metrics.regression.manhattan_distance import ManhattanDistance
 from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError
 from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError
@@ -13,4 +14,5 @@
 from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError
 from ignite.metrics.regression.pearson_correlation import PearsonCorrelation
 from ignite.metrics.regression.r2_score import R2Score
+from ignite.metrics.regression.spearman_correlation import SpearmanRankCorrelation
 from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance
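With these re-exports in place, both new metrics are importable directly from the regression package. A minimal sketch (not part of the commit, and assuming scipy is installed; without scipy both constructors raise ``ModuleNotFoundError`` per the import check added in the new modules below):

    from ignite.metrics.regression import KendallRankCorrelation, SpearmanRankCorrelation

    # "b" (the default) selects Kendall's tau-b; "c" selects tau-c
    kendall = KendallRankCorrelation(variant="b")
    spearman = SpearmanRankCorrelation()
    print(type(kendall).__name__, type(spearman).__name__)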
ignite/metrics/regression/kendall_correlation.py (new file)

Lines changed: 115 additions & 0 deletions
from typing import Any, Callable, Tuple, Union

import torch

from torch import Tensor

from ignite.exceptions import NotComputableError
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.regression._base import _check_output_shapes, _check_output_types


def _get_kendall_tau(variant: str = "b") -> Callable[[Tensor, Tensor], float]:
    from scipy.stats import kendalltau

    if variant not in ("b", "c"):
        raise ValueError(f"variant accepts 'b' or 'c', got {variant!r}.")

    def _tau(predictions: Tensor, targets: Tensor) -> float:
        np_preds = predictions.flatten().numpy()
        np_targets = targets.flatten().numpy()
        r = kendalltau(np_preds, np_targets, variant=variant).statistic
        return r

    return _tau


class KendallRankCorrelation(EpochMetric):
    r"""Calculates the
    `Kendall rank correlation coefficient <https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient>`_.

    .. math::
        \tau = 1-\frac{2(\text{number of discordant pairs})}{\left( \begin{array}{c}n\\2\end{array} \right)}

    Two prediction-target pairs :math:`(P_i, A_i)` and :math:`(P_j, A_j)`, where :math:`i<j`,
    are said to be concordant when both :math:`P_i<P_j` and :math:`A_i<A_j` hold,
    or both :math:`P_i>P_j` and :math:`A_i>A_j` hold.

    The `number of discordant pairs` counts the number of pairs that are not concordant.

    The computation of this metric is implemented with
    `scipy.stats.kendalltau <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.kendalltau.html>`_.

    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
    - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.

    Parameters are inherited from ``Metric.__init__``.

    Args:
        variant: variant of the Kendall rank correlation coefficient; ``b`` or ``c`` is accepted.
            Details can be found
            `here <https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient#Accounting_for_ties>`_.
            Default: ``b``
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = KendallRankCorrelation()
            metric.attach(default_evaluator, 'kendall_tau')
            y_true = torch.tensor([0., 1., 2., 3., 4., 5.])
            y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['kendall_tau'])

        .. testoutput::

            0.4666666666666666
    """

    def __init__(
        self,
        variant: str = "b",
        output_transform: Callable[..., Any] = lambda x: x,
        check_compute_fn: bool = True,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = False,
    ) -> None:
        try:
            from scipy.stats import kendalltau  # noqa: F401
        except ImportError:
            raise ModuleNotFoundError("This module requires scipy to be installed.")

        super().__init__(_get_kendall_tau(variant), output_transform, check_compute_fn, device, skip_unrolling)

    def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        y_pred, y = output[0].detach(), output[1].detach()
        if y_pred.ndim == 1:
            y_pred = y_pred.unsqueeze(1)
        if y.ndim == 1:
            y = y.unsqueeze(1)

        _check_output_shapes(output)
        _check_output_types(output)

        super().update(output)

    def compute(self) -> float:
        if len(self._predictions) < 1 or len(self._targets) < 1:
            raise NotComputableError(
                "KendallRankCorrelation must have at least one example before it can be computed."
            )

        return super().compute()
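As a sanity check on the docstring's formula and expected output, the tau value above can be reproduced outside of ignite by counting discordant pairs directly and comparing against ``scipy.stats.kendalltau``, the function the metric wraps. A standalone sketch, not part of the commit; with no ties, tau-b reduces to the formula in the docstring:

    from itertools import combinations

    from scipy.stats import kendalltau

    y_true = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
    y_pred = [0.5, 2.8, 1.9, 1.3, 6.0, 4.1]

    # A pair (i, j) is discordant when predictions and targets are ordered in opposite directions.
    n = len(y_true)
    discordant = sum(
        1
        for i, j in combinations(range(n), 2)
        if (y_pred[i] - y_pred[j]) * (y_true[i] - y_true[j]) < 0
    )
    tau_manual = 1.0 - 2.0 * discordant / (n * (n - 1) / 2)  # 1 - 2*4/15 = 7/15

    tau_scipy = kendalltau(y_pred, y_true, variant="b").statistic

    print(tau_manual, tau_scipy)  # both ~0.4667, matching the testoutput above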
ignite/metrics/regression/spearman_correlation.py (new file)

Lines changed: 107 additions & 0 deletions
from typing import Any, Callable, Tuple, Union

import torch

from torch import Tensor

from ignite.exceptions import NotComputableError
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.regression._base import _check_output_shapes, _check_output_types


def _get_spearman_r() -> Callable[[Tensor, Tensor], float]:
    from scipy.stats import spearmanr

    def _compute_spearman_r(predictions: Tensor, targets: Tensor) -> float:
        np_preds = predictions.flatten().numpy()
        np_targets = targets.flatten().numpy()
        r = spearmanr(np_preds, np_targets).statistic
        return r

    return _compute_spearman_r


class SpearmanRankCorrelation(EpochMetric):
    r"""Calculates the
    `Spearman's rank correlation coefficient
    <https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient>`_.

    .. math::
        r_\text{s} = \text{Corr}[R[P], R[A]] = \frac{\text{Cov}[R[P], R[A]]}{\sigma_{R[P]} \sigma_{R[A]}}

    where :math:`A` and :math:`P` are the ground truth and predicted value,
    and :math:`R[X]` is the ranking value of :math:`X`.

    The computation of this metric is implemented with
    `scipy.stats.spearmanr <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html>`_.

    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
    - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.

    Parameters are inherited from ``Metric.__init__``.

    Args:
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = SpearmanRankCorrelation()
            metric.attach(default_evaluator, 'spearman_corr')
            y_true = torch.tensor([0., 1., 2., 3., 4., 5.])
            y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1])
            state = default_evaluator.run([[y_pred, y_true]])
            print(state.metrics['spearman_corr'])

        .. testoutput::

            0.7142857142857143
    """

    def __init__(
        self,
        output_transform: Callable[..., Any] = lambda x: x,
        check_compute_fn: bool = True,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = False,
    ) -> None:
        try:
            from scipy.stats import spearmanr  # noqa: F401
        except ImportError:
            raise ModuleNotFoundError("This module requires scipy to be installed.")

        super().__init__(_get_spearman_r(), output_transform, check_compute_fn, device, skip_unrolling)

    def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        y_pred, y = output[0].detach(), output[1].detach()
        if y_pred.ndim == 1:
            y_pred = y_pred.unsqueeze(1)
        if y.ndim == 1:
            y = y.unsqueeze(1)

        _check_output_shapes(output)
        _check_output_types(output)

        super().update(output)

    def compute(self) -> float:
        if len(self._predictions) < 1 or len(self._targets) < 1:
            raise NotComputableError(
                "SpearmanRankCorrelation must have at least one example before it can be computed."
            )

        return super().compute()
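The Spearman value in the testoutput can likewise be reproduced by ranking both sequences and applying the no-ties closed form ``r_s = 1 - 6 * sum(d_i**2) / (n * (n**2 - 1))``, then comparing against ``scipy.stats.spearmanr``, the function the metric wraps. A standalone sketch, not part of the commit:

    import numpy as np

    from scipy.stats import rankdata, spearmanr

    y_true = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
    y_pred = np.array([0.5, 2.8, 1.9, 1.3, 6.0, 4.1])

    # Spearman's correlation is the Pearson correlation of the ranks.
    d = rankdata(y_pred) - rankdata(y_true)  # per-sample rank differences
    n = len(y_true)
    rs_manual = 1.0 - 6.0 * np.sum(d**2) / (n * (n**2 - 1))  # 1 - 60/210 = 5/7

    rs_scipy = spearmanr(y_pred, y_true).statistic

    print(rs_manual, rs_scipy)  # both ~0.7143, matching the testoutput above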
