Add average to curve metrics #2084

Merged · 20 commits · Oct 4, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added `average` argument to multiclass versions of `PrecisionRecallCurve` and `ROC` ([#2084](https://github.com/Lightning-AI/torchmetrics/pull/2084))

### Changed

1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -164,3 +164,4 @@
.. _Completeness Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.completeness_score.html
.. _Davies-Bouldin Score: https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index
.. _Fowlkes-Mallows Index: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fowlkes_mallows_score.html#sklearn.metrics.fowlkes_mallows_score
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
6 changes: 4 additions & 2 deletions src/torchmetrics/classification/auroc.py
@@ -266,13 +266,15 @@ def __init__(
)
if validate_args:
_multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index)
self.average = average
self.average = average # type: ignore[assignment]
self.validate_args = validate_args

def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_auroc_compute(state, self.num_classes, self.average, self.thresholds)
return _multiclass_auroc_compute(
state, self.num_classes, self.average, self.thresholds # type: ignore[arg-type]
)

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
6 changes: 4 additions & 2 deletions src/torchmetrics/classification/average_precision.py
@@ -264,13 +264,15 @@ def __init__(
)
if validate_args:
_multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index)
self.average = average
self.average = average # type: ignore[assignment]
self.validate_args = validate_args

def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_average_precision_compute(state, self.num_classes, self.average, self.thresholds)
return _multiclass_average_precision_compute(
state, self.num_classes, self.average, self.thresholds # type: ignore[arg-type]
)

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
25 changes: 21 additions & 4 deletions src/torchmetrics/classification/precision_recall_curve.py
@@ -103,6 +103,8 @@ class BinaryPrecisionRecallCurve(Metric):
- If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -266,6 +268,15 @@ class MulticlassPrecisionRecallCurve(Metric):
- If set to a 1D `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
average:
If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for
each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot
encoding the targets and flattening the predictions, considering all classes jointly as a binary problem.
If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves
from each class at a combined set of thresholds and then averaging over the classwise interpolated curves.
See `averaging curve objects`_ for more info on the different averaging methods.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
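
A minimal usage sketch of the new `average` argument on the class-based API (shapes and values are illustrative only):

```python
import torch
from torchmetrics.classification import MulticlassPrecisionRecallCurve

preds = torch.softmax(torch.randn(20, 3), dim=-1)  # probabilities for 3 classes
target = torch.randint(0, 3, (20,))

# Default (average=None): one precision/recall/thresholds triplet per class.
per_class_curve = MulticlassPrecisionRecallCurve(num_classes=3, thresholds=None)
precisions, recalls, thresholds = per_class_curve(preds, target)

# "micro": targets are one-hot encoded, predictions flattened, and a single
# binary curve is computed over all class/sample pairs jointly.
micro_curve = MulticlassPrecisionRecallCurve(num_classes=3, thresholds=None, average="micro")
precision, recall, thresholds = micro_curve(preds, target)
```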
@@ -314,15 +325,17 @@ def __init__(
self,
num_classes: int,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
average: Optional[Literal["micro", "macro"]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if validate_args:
_multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index)
_multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average)

self.num_classes = num_classes
self.average = average
self.ignore_index = ignore_index
self.validate_args = validate_args

@@ -344,9 +357,11 @@ def update(self, preds: Tensor, target: Tensor) -> None:
if self.validate_args:
_multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index)
preds, target, _ = _multiclass_precision_recall_curve_format(
preds, target, self.num_classes, self.thresholds, self.ignore_index
preds, target, self.num_classes, self.thresholds, self.ignore_index, self.average
)
state = _multiclass_precision_recall_curve_update(
preds, target, self.num_classes, self.thresholds, self.average
)
state = _multiclass_precision_recall_curve_update(preds, target, self.num_classes, self.thresholds)
if isinstance(state, Tensor):
self.confmat += state
else:
@@ -356,7 +371,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds)
return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds, self.average)

def plot(
self,
@@ -456,6 +471,8 @@ class MultilabelPrecisionRecallCurve(Metric):
- If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
15 changes: 14 additions & 1 deletion src/torchmetrics/classification/roc.py
@@ -87,6 +87,8 @@ class BinaryROC(BinaryPrecisionRecallCurve):
- If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -229,6 +231,15 @@ class MulticlassROC(MulticlassPrecisionRecallCurve):
- If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
average:
If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for
each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot
encoding the targets and flattening the predictions, considering all classes jointly as a binary problem.
If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves
from each class at a combined set of thresholds and then averaging over the classwise interpolated curves.
See `averaging curve objects`_ for more info on the different averaging methods.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
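
A similar hedged sketch for `MulticlassROC`; the number of thresholds is arbitrary:

```python
import torch
from torchmetrics.classification import MulticlassROC

preds = torch.softmax(torch.randn(32, 4), dim=-1)
target = torch.randint(0, 4, (32,))

# "macro": per-class ROC curves are interpolated onto a shared grid of
# thresholds and then averaged into a single fpr/tpr/thresholds triplet.
roc = MulticlassROC(num_classes=4, thresholds=100, average="macro")
fpr, tpr, thresholds = roc(preds, target)
```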
@@ -276,7 +287,7 @@ class MulticlassROC(MulticlassPrecisionRecallCurve):
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multiclass_roc_compute(state, self.num_classes, self.thresholds)
return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average)

def plot(
self,
@@ -381,6 +392,8 @@ class MultilabelROC(MultilabelPrecisionRecallCurve):
- If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -20,7 +20,7 @@
from typing_extensions import Literal

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.compute import _safe_divide
from torchmetrics.utilities.compute import _safe_divide, interp
from torchmetrics.utilities.data import _bincount, _cumsum
from torchmetrics.utilities.enums import ClassificationTask

@@ -363,6 +363,7 @@ def _multiclass_precision_recall_curve_arg_validation(
num_classes: int,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
ignore_index: Optional[int] = None,
average: Optional[Literal["micro", "macro"]] = None,
) -> None:
"""Validate non tensor input.
@@ -373,6 +374,8 @@
"""
if not isinstance(num_classes, int) or num_classes < 2:
raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}")
if average not in (None, "micro", "macro"):
raise ValueError(f"Expected argument `average` to be one of None, 'micro' or 'macro', but got {average}")
_binary_precision_recall_curve_arg_validation(thresholds, ignore_index)


@@ -423,6 +426,7 @@ def _multiclass_precision_recall_curve_format(
num_classes: int,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
ignore_index: Optional[int] = None,
average: Optional[Literal["micro", "macro"]] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
"""Convert all input to the right format.
@@ -443,6 +447,10 @@
if not torch.all((preds >= 0) * (preds <= 1)):
preds = preds.softmax(1)

if average == "micro":
preds = preds.flatten()
target = torch.nn.functional.one_hot(target, num_classes=num_classes).flatten()

thresholds = _adjust_threshold_arg(thresholds, preds.device)
return preds, target, thresholds
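
For intuition, a small standalone sketch of what the ``"micro"`` branch above does to the shapes (illustrative tensors only, not the library's internals):

```python
import torch

preds = torch.rand(5, 3)                # (N=5 samples, C=3 classes)
target = torch.tensor([0, 2, 1, 1, 0])  # (N,)

micro_preds = preds.flatten()  # shape (N * C,) == (15,)
micro_target = torch.nn.functional.one_hot(target, num_classes=3).flatten()  # (15,) of 0/1

# The multiclass task is now a single binary problem over 15 class/sample pairs.
print(micro_preds.shape, micro_target.shape)
```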

@@ -452,6 +460,7 @@ def _multiclass_precision_recall_curve_update(
target: Tensor,
num_classes: int,
thresholds: Optional[Tensor],
average: Optional[Literal["micro", "macro"]] = None,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Return the state to calculate the pr-curve with.
Expand All @@ -461,6 +470,8 @@ def _multiclass_precision_recall_curve_update(
"""
if thresholds is None:
return preds, target
if average == "micro":
return _binary_precision_recall_curve_update(preds, target, thresholds)
if preds.numel() * num_classes <= 1_000_000:
update_fn = _multiclass_precision_recall_curve_update_vectorized
else:
@@ -520,13 +531,17 @@ def _multiclass_precision_recall_curve_compute(
state: Union[Tensor, Tuple[Tensor, Tensor]],
num_classes: int,
thresholds: Optional[Tensor],
average: Optional[Literal["micro", "macro"]] = None,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute the final pr-curve.
If state is a single tensor, then we calculate the pr-curve from a multi threshold confusion matrix. If state is
original input, then we dynamically compute the binary classification curve in an iterative way.
"""
if average == "micro":
return _binary_precision_recall_curve_compute(state, thresholds)

if isinstance(state, Tensor) and thresholds is not None:
tps = state[:, :, 1, 1]
fps = state[:, :, 0, 1]
@@ -535,22 +550,45 @@
recall = _safe_divide(tps, tps + fns)
precision = torch.cat([precision, torch.ones(1, num_classes, dtype=precision.dtype, device=precision.device)])
recall = torch.cat([recall, torch.zeros(1, num_classes, dtype=recall.dtype, device=recall.device)])
return precision.T, recall.T, thresholds

precision_list, recall_list, threshold_list = [], [], []
for i in range(num_classes):
res = _binary_precision_recall_curve_compute((state[0][:, i], state[1]), thresholds=None, pos_label=i)
precision_list.append(res[0])
recall_list.append(res[1])
threshold_list.append(res[2])
return precision_list, recall_list, threshold_list
precision = precision.T
recall = recall.T
thres = thresholds
tensor_state = True
else:
precision_list, recall_list, thres_list = [], [], []
for i in range(num_classes):
res = _binary_precision_recall_curve_compute((state[0][:, i], state[1]), thresholds=None, pos_label=i)
precision_list.append(res[0])
recall_list.append(res[1])
thres_list.append(res[2])
tensor_state = False

if average == "macro":
thres = thres.repeat(num_classes) if tensor_state else torch.cat(thres_list, 0)
thres = thres.sort().values
mean_precision = precision.flatten() if tensor_state else torch.cat(precision_list, 0)
mean_precision = mean_precision.sort().values
mean_recall = torch.zeros_like(mean_precision)
for i in range(num_classes):
mean_recall += interp(
mean_precision,
precision[i] if tensor_state else precision_list[i],
recall[i] if tensor_state else recall_list[i],
)
mean_recall /= num_classes
return mean_precision, mean_recall, thres

if tensor_state:
return precision, recall, thres
return precision_list, recall_list, thres_list
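
The ``"macro"`` branch puts every per-class curve on a common grid via linear interpolation before averaging. A conceptual sketch of that idea using `np.interp` with made-up curves (the library code uses its own `interp` helper instead):

```python
import numpy as np

# Two per-class curves sampled at different x positions (e.g. recall or fpr).
xs = [np.array([0.0, 0.4, 1.0]), np.array([0.0, 0.7, 1.0])]
ys = [np.array([1.0, 0.8, 0.3]), np.array([1.0, 0.6, 0.2])]

# Combined, sorted grid of x values from every class.
grid = np.sort(np.concatenate(xs))

# Interpolate each class curve onto the shared grid, then average point-wise.
mean_y = np.mean([np.interp(grid, x, y) for x, y in zip(xs, ys)], axis=0)
print(grid, mean_y)
```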


def multiclass_precision_recall_curve(
preds: Tensor,
target: Tensor,
num_classes: int,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
average: Optional[Literal["micro", "macro"]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
@@ -590,6 +628,13 @@ def multiclass_precision_recall_curve(
- If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
average:
If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for
each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot
encoding the targets and flattening the predictions, considering all classes jointly as a binary problem.
If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves
from each class at a combined set of thresholds and then averaging over the classwise interpolated curves.
See `averaging curve objects`_ for more info on the different averaging methods.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Expand Down Expand Up @@ -643,13 +688,18 @@ def multiclass_precision_recall_curve(
"""
if validate_args:
_multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index)
_multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average)
_multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index)
preds, target, thresholds = _multiclass_precision_recall_curve_format(
preds, target, num_classes, thresholds, ignore_index
preds,
target,
num_classes,
thresholds,
ignore_index,
average,
)
state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds)
return _multiclass_precision_recall_curve_compute(state, num_classes, thresholds)
state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds, average)
return _multiclass_precision_recall_curve_compute(state, num_classes, thresholds, average)


def _multilabel_precision_recall_curve_arg_validation(
@@ -892,6 +942,7 @@ def precision_recall_curve(
thresholds: Optional[Union[int, List[float], Tensor]] = None,
num_classes: Optional[int] = None,
num_labels: Optional[int] = None,
average: Optional[Literal["micro", "macro"]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
Expand Down Expand Up @@ -940,7 +991,9 @@ def precision_recall_curve(
if task == ClassificationTask.MULTICLASS:
if not isinstance(num_classes, int):
raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
return multiclass_precision_recall_curve(preds, target, num_classes, thresholds, ignore_index, validate_args)
return multiclass_precision_recall_curve(
preds, target, num_classes, thresholds, average, ignore_index, validate_args
)
if task == ClassificationTask.MULTILABEL:
if not isinstance(num_labels, int):
raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`")