From 6965603b5d2cdb48bfe017040a2593659b6f87c2 Mon Sep 17 00:00:00 2001 From: Franklin <41602287+fcogidi@users.noreply.github.com> Date: Tue, 9 Jan 2024 18:23:04 -0500 Subject: [PATCH 1/6] Minor patches (#542) * Fix capitalization of Python * Change ValueError to TypeError and update docstrings --- .../evaluate/metrics/experimental/__init__.py | 2 +- .../metrics/experimental/functional/accuracy.py | 6 +++--- .../metrics/experimental/functional/f_score.py | 12 ++++++------ .../functional/negative_predictive_value.py | 6 +++--- .../experimental/functional/precision_recall.py | 12 ++++++------ .../experimental/functional/specificity.py | 6 +++--- cyclops/evaluate/metrics/experimental/metric.py | 5 +++-- .../evaluate/metrics/experimental/metric_dict.py | 2 +- .../metrics/experimental/utils/validation.py | 4 ++-- cyclops/models/wrappers/utils.py | 2 +- cyclops/report/model_card/base.py | 4 ++-- .../evaluate/metrics/experimental/test_metric.py | 2 +- .../metrics/experimental/test_operator_metric.py | 16 ++++++++-------- 13 files changed, 40 insertions(+), 39 deletions(-) diff --git a/cyclops/evaluate/metrics/experimental/__init__.py b/cyclops/evaluate/metrics/experimental/__init__.py index 0eb77c1ed..ed04db9fa 100644 --- a/cyclops/evaluate/metrics/experimental/__init__.py +++ b/cyclops/evaluate/metrics/experimental/__init__.py @@ -1,4 +1,4 @@ -"""Metrics for arrays that conform to the python array API standard.""" +"""Metrics for arrays that conform to the Python array API standard.""" from cyclops.evaluate.metrics.experimental.accuracy import ( BinaryAccuracy, MulticlassAccuracy, diff --git a/cyclops/evaluate/metrics/experimental/functional/accuracy.py b/cyclops/evaluate/metrics/experimental/functional/accuracy.py index f4617e919..f7d009115 100644 --- a/cyclops/evaluate/metrics/experimental/functional/accuracy.py +++ b/cyclops/evaluate/metrics/experimental/functional/accuracy.py @@ -68,7 +68,7 @@ def binary_accuracy( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -216,7 +216,7 @@ def multiclass_accuracy( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -378,7 +378,7 @@ def multilabel_accuracy( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError diff --git a/cyclops/evaluate/metrics/experimental/functional/f_score.py b/cyclops/evaluate/metrics/experimental/functional/f_score.py index 401d25f5b..6cb74db4d 100644 --- a/cyclops/evaluate/metrics/experimental/functional/f_score.py +++ b/cyclops/evaluate/metrics/experimental/functional/f_score.py @@ -144,7 +144,7 @@ def binary_fbeta_score( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -276,7 +276,7 @@ def multiclass_fbeta_score( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -468,7 +468,7 @@ def multilabel_fbeta_score( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -593,7 +593,7 @@ def binary_f1_score( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. 
ValueError @@ -692,7 +692,7 @@ def multiclass_f1_score( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -824,7 +824,7 @@ def multilabel_f1_score( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError diff --git a/cyclops/evaluate/metrics/experimental/functional/negative_predictive_value.py b/cyclops/evaluate/metrics/experimental/functional/negative_predictive_value.py index 81a1fd5ed..2ca4411cc 100644 --- a/cyclops/evaluate/metrics/experimental/functional/negative_predictive_value.py +++ b/cyclops/evaluate/metrics/experimental/functional/negative_predictive_value.py @@ -87,7 +87,7 @@ def binary_npv( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -200,7 +200,7 @@ def multiclass_npv( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -360,7 +360,7 @@ def multilabel_npv( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError diff --git a/cyclops/evaluate/metrics/experimental/functional/precision_recall.py b/cyclops/evaluate/metrics/experimental/functional/precision_recall.py index b4411db31..283b485cc 100644 --- a/cyclops/evaluate/metrics/experimental/functional/precision_recall.py +++ b/cyclops/evaluate/metrics/experimental/functional/precision_recall.py @@ -95,7 +95,7 @@ def binary_precision( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -206,7 +206,7 @@ def multiclass_precision( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -364,7 +364,7 @@ def multilabel_precision( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -478,7 +478,7 @@ def binary_recall( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -589,7 +589,7 @@ def multiclass_recall( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -750,7 +750,7 @@ def multilabel_recall( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError diff --git a/cyclops/evaluate/metrics/experimental/functional/specificity.py b/cyclops/evaluate/metrics/experimental/functional/specificity.py index 328deebc0..da27e0957 100644 --- a/cyclops/evaluate/metrics/experimental/functional/specificity.py +++ b/cyclops/evaluate/metrics/experimental/functional/specificity.py @@ -87,7 +87,7 @@ def binary_specificity( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError @@ -199,7 +199,7 @@ def multiclass_specificity( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. 
ValueError @@ -359,7 +359,7 @@ def multilabel_specificity( Raises ------ - ValueError + TypeError If the arrays `target` and `preds` are not compatible with the Python array API standard. ValueError diff --git a/cyclops/evaluate/metrics/experimental/metric.py b/cyclops/evaluate/metrics/experimental/metric.py index d3e69de05..1712fabd3 100644 --- a/cyclops/evaluate/metrics/experimental/metric.py +++ b/cyclops/evaluate/metrics/experimental/metric.py @@ -180,7 +180,7 @@ def add_state_default_factory( """ if not name.isidentifier(): raise ValueError( - f"Argument `name` must be a valid python identifier. Got `{name}`.", + f"Argument `name` must be a valid Python identifier. Got `{name}`.", ) if not callable(default_factory): raise TypeError( @@ -286,12 +286,13 @@ def update(self, *args: Any, **kwargs: Any) -> None: "not yet be defined.", ) xp = apc.get_namespace(*arrays) - self._add_states(xp) # move state variables to device of first array device = apc.device(arrays[0]) self.to_device(device) + self._add_states(xp) + self._computed = None self._update_count += 1 diff --git a/cyclops/evaluate/metrics/experimental/metric_dict.py b/cyclops/evaluate/metrics/experimental/metric_dict.py index 25dfca259..08a833522 100644 --- a/cyclops/evaluate/metrics/experimental/metric_dict.py +++ b/cyclops/evaluate/metrics/experimental/metric_dict.py @@ -49,7 +49,7 @@ class ArrayEncoder(json.JSONEncoder): def default(self, obj: Any) -> Any: """Return a JSON-serializable representation of the object. - Objects conforming to the array API standard are converted to python lists + Objects conforming to the array API standard are converted to Python lists via numpy. Arrays are moved to the CPU before converting to numpy. """ if apc.is_array_api_obj(obj): diff --git a/cyclops/evaluate/metrics/experimental/utils/validation.py b/cyclops/evaluate/metrics/experimental/utils/validation.py index db6122feb..5f02c090a 100644 --- a/cyclops/evaluate/metrics/experimental/utils/validation.py +++ b/cyclops/evaluate/metrics/experimental/utils/validation.py @@ -62,13 +62,13 @@ def _basic_input_array_checks( ) -> None: """Perform basic validation of `target` and `preds`.""" if not apc.is_array_api_obj(target): - raise ValueError( + raise TypeError( "Expected `target` to be an array-API-compatible object, but got " f"{type(target)}.", ) if not apc.is_array_api_obj(preds): - raise ValueError( + raise TypeError( "Expected `preds` to be an array-API-compatible object, but got " f"{type(preds)}.", ) diff --git a/cyclops/models/wrappers/utils.py b/cyclops/models/wrappers/utils.py index 355b1f8f6..f77747333 100644 --- a/cyclops/models/wrappers/utils.py +++ b/cyclops/models/wrappers/utils.py @@ -280,7 +280,7 @@ def set_params(cls, **params): def set_random_seed(seed: int, deterministic: bool = False) -> None: - """Set a random seed for python, numpy and PyTorch globally. + """Set a random seed for Python, numpy and PyTorch globally. Parameters ---------- diff --git a/cyclops/report/model_card/base.py b/cyclops/report/model_card/base.py index 82774f05d..235f37fba 100644 --- a/cyclops/report/model_card/base.py +++ b/cyclops/report/model_card/base.py @@ -141,12 +141,12 @@ def add_field(self, name: str, value: Any) -> None: Raises ------ ValueError - If the field name is not a valid python identifier. + If the field name is not a valid Python identifier. """ if not name.isidentifier() or keyword.iskeyword(name): raise ValueError( - f"Expected `field_name` to be a valid python identifier." 
+ f"Expected `field_name` to be a valid Python identifier." f" Got {name} instead.", ) diff --git a/tests/cyclops/evaluate/metrics/experimental/test_metric.py b/tests/cyclops/evaluate/metrics/experimental/test_metric.py index 39c80fb75..72279f58e 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_metric.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_metric.py @@ -194,7 +194,7 @@ def custom_fn(xp, _): with pytest.raises( ValueError, - match="Argument `name` must be a valid python identifier. Got `h6!`.", + match="Argument `name` must be a valid Python identifier. Got `h6!`.", ): metric.add_state_default_factory("h6!", list) # type: ignore diff --git a/tests/cyclops/evaluate/metrics/experimental/test_operator_metric.py b/tests/cyclops/evaluate/metrics/experimental/test_operator_metric.py index 0daed27b6..268cf31f8 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_operator_metric.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_operator_metric.py @@ -139,7 +139,7 @@ def test_metrics_floordiv(second_operand, expected_result): """Test that `floordiv` operator works and returns an operator metric.""" first_metric = DummyMetric( anp.asarray( - 5, # python scalars can only be promoted with floating-point arrays + 5, # Python scalars can only be promoted with floating-point arrays dtype=anp.float32 if isinstance(second_operand, float) else None, ), ) @@ -164,7 +164,7 @@ def test_metrics_ge(second_operand, expected_result): """Test that `ge` operator works and returns an operator metric.""" first_metric = DummyMetric( anp.asarray( - 5, # python scalars can only be promoted with floating-point arrays + 5, # Python scalars can only be promoted with floating-point arrays dtype=anp.float32 if isinstance(second_operand, float) else None, ), ) @@ -189,7 +189,7 @@ def test_metrics_gt(second_operand, expected_result): """Test that `gt` operator works and returns an operator metric.""" first_metric = DummyMetric( anp.asarray( - 5, # python scalars can only be promoted with floating-point arrays + 5, # Python scalars can only be promoted with floating-point arrays dtype=anp.float32 if isinstance(second_operand, float) else None, ), ) @@ -224,7 +224,7 @@ def test_metrics_le(second_operand, expected_result): """Test that `le` operator works and returns an operator metric.""" first_metric = DummyMetric( anp.asarray( - 5, # python scalars can only be promoted with floating-point arrays + 5, # Python scalars can only be promoted with floating-point arrays dtype=anp.float32 if isinstance(second_operand, float) else None, ), ) @@ -249,7 +249,7 @@ def test_metrics_lt(second_operand, expected_result): """Test that `lt` operator works and returns an operator metric.""" first_metric = DummyMetric( anp.asarray( - 5, # python scalars can only be promoted with floating-point arrays + 5, # Python scalars can only be promoted with floating-point arrays dtype=anp.float32 if isinstance(second_operand, float) else None, ), ) @@ -292,7 +292,7 @@ def test_metrics_mod(second_operand, expected_result): """Test that `mod` operator works and returns an operator metric.""" first_metric = DummyMetric( anp.asarray( - 5, # python scalars can only be promoted with floating-point arrays + 5, # Python scalars can only be promoted with floating-point arrays dtype=anp.float32 if isinstance(second_operand, float) else None, ), ) @@ -317,7 +317,7 @@ def test_metrics_mul(second_operand, expected_result): """Test that `mul` operator works and returns an operator metric.""" first_metric = DummyMetric( 
anp.asarray( - 2, # python scalars can only be promoted with floating-point arrays + 2, # Python scalars can only be promoted with floating-point arrays dtype=anp.float32 if isinstance(second_operand, float) else None, ), ) @@ -350,7 +350,7 @@ def test_metrics_ne(second_operand, expected_result): """Test that `!=` operator works and returns an operator metric.""" first_metric = DummyMetric( anp.asarray( - 2, # python scalars can only be promoted with floating-point arrays + 2, # Python scalars can only be promoted with floating-point arrays dtype=anp.float32 if isinstance(second_operand, float) else None, ), ) From b62bc73e1c497034174648ea1eebb10f314a07f6 Mon Sep 17 00:00:00 2001 From: Franklin <41602287+fcogidi@users.noreply.github.com> Date: Tue, 9 Jan 2024 20:02:14 -0500 Subject: [PATCH 2/6] Add Precision-Recall Curve to experimental metrics (#544) --- .../evaluate/metrics/experimental/__init__.py | 5 + .../experimental/functional/__init__.py | 5 + .../functional/precision_recall_curve.py | 1182 +++++++++++++++++ .../experimental/precision_recall_curve.py | 486 +++++++ .../metrics/experimental/utils/ops.py | 70 + .../evaluate/metrics/experimental/inputs.py | 102 +- .../test_precision_recall_curve.py | 492 +++++++ 7 files changed, 2291 insertions(+), 51 deletions(-) create mode 100644 cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py create mode 100644 cyclops/evaluate/metrics/experimental/precision_recall_curve.py create mode 100644 tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py diff --git a/cyclops/evaluate/metrics/experimental/__init__.py b/cyclops/evaluate/metrics/experimental/__init__.py index ed04db9fa..05fea662a 100644 --- a/cyclops/evaluate/metrics/experimental/__init__.py +++ b/cyclops/evaluate/metrics/experimental/__init__.py @@ -40,6 +40,11 @@ MultilabelSensitivity, MultilabelTPR, ) +from cyclops.evaluate.metrics.experimental.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) from cyclops.evaluate.metrics.experimental.specificity import ( BinarySpecificity, BinaryTNR, diff --git a/cyclops/evaluate/metrics/experimental/functional/__init__.py b/cyclops/evaluate/metrics/experimental/functional/__init__.py index 63492c010..23b91cc37 100644 --- a/cyclops/evaluate/metrics/experimental/functional/__init__.py +++ b/cyclops/evaluate/metrics/experimental/functional/__init__.py @@ -36,6 +36,11 @@ multilabel_recall, multilabel_tpr, ) +from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( + binary_precision_recall_curve, + multiclass_precision_recall_curve, + multilabel_precision_recall_curve, +) from cyclops.evaluate.metrics.experimental.functional.specificity import ( binary_specificity, binary_tnr, diff --git a/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py b/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py new file mode 100644 index 000000000..609548cf8 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py @@ -0,0 +1,1182 @@ +"""Functions for computing the precision and recall for different unique thresholds.""" +from types import ModuleType +from typing import Any, List, Literal, Optional, Sequence, Tuple, Union + +import array_api_compat as apc +import numpy as np + +from cyclops.evaluate.metrics.experimental.utils.ops import ( + _array_indexing, + _cumsum, + _interp, + _to_one_hot, + bincount, + clone, + flatten, + moveaxis, 
+ remove_ignore_index, + safe_divide, + sigmoid, + softmax, + to_int, +) +from cyclops.evaluate.metrics.experimental.utils.types import Array +from cyclops.evaluate.metrics.experimental.utils.validation import ( + _basic_input_array_checks, + _check_same_shape, + is_floating_point, +) + + +def _validate_thresholds(thresholds: Optional[Union[int, List[float], Array]]) -> None: + """Validate the `thresholds` argument.""" + if thresholds is not None and not ( + isinstance(thresholds, (int, list)) or apc.is_array_api_obj(thresholds) + ): + raise TypeError( + "Expected argument `thresholds` to either be an integer, a list of floats or " + f"an Array of floats, but got {thresholds}", + ) + if isinstance(thresholds, int) and thresholds < 2: + raise ValueError( + "Expected argument `thresholds` to be an integer greater than 1, " + f"but got {thresholds}", + ) + if isinstance(thresholds, list): + if not all(isinstance(t, float) and 0 <= t <= 1 for t in thresholds): + raise ValueError( + "Expected argument `thresholds` to be a list of floats in the [0,1] range, " + f"but got {thresholds}", + ) + if not all(np.diff(thresholds) > 0): + raise ValueError( + "Expected argument `thresholds` to be monotonically increasing," + f" but got {thresholds}", + ) + + if apc.is_array_api_obj(thresholds): + xp = apc.array_namespace(thresholds) + if not xp.all((thresholds >= 0) & (thresholds <= 1)): # type: ignore + raise ValueError( + "Expected argument `thresholds` to be an Array of floats in the [0,1] " + f"range, but got {thresholds}", + ) + if not thresholds.ndim == 1: # type: ignore + raise ValueError( + "Expected argument `thresholds` to be a 1D Array, but got an Array with " + f"{thresholds.ndim} dimensions", # type: ignore + ) + + +def _binary_precision_recall_curve_validate_args( + thresholds: Optional[Union[int, List[float], Array]], + ignore_index: Optional[int], +) -> None: + """Validate the arguments for the `binary_precision_recall_curve` function.""" + _validate_thresholds(thresholds) + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError( + "Expected argument `ignore_index` to either be `None` or an integer, " + f"but got {ignore_index}", + ) + + +def _binary_precision_recall_curve_validate_arrays( + target: Array, + preds: Array, + thresholds: Optional[Union[int, List[float], Array]], + ignore_index: Optional[int], +) -> ModuleType: + """Validate the arrays for the `binary_precision_recall_curve` function.""" + _basic_input_array_checks(target, preds) + _check_same_shape(target, preds) + + if is_floating_point(target): + raise ValueError( + "Expected argument `target` to be an Array of integers representing " + f"binary groundtruth labels, but got tensor with dtype {target.dtype}", + ) + + if not is_floating_point(preds): + raise ValueError( + "Expected argument `preds` to be an floating tensor with probability/logit scores," + f" but got tensor with dtype {preds.dtype}", + ) + + xp: ModuleType = apc.array_namespace(target, preds) + # check that target only contains {0,1} values or value in ignore_index + unique_values = xp.unique_values(target) + if ignore_index is None: + check = xp.any((unique_values != 0) & (unique_values != 1)) + else: + check = xp.any( + (unique_values != 0) + & (unique_values != 1) + & (unique_values != ignore_index), + ) + if check: + raise RuntimeError( + "Expected only the following values " + f"{[0, 1] if ignore_index is None else [ignore_index]} in `target`. 
" + f"But found the following values: {unique_values}", + ) + + if apc.is_array_api_obj(thresholds) and xp != apc.array_namespace(thresholds): + raise ValueError( + "Expected the array API namespace of `target` and `preds` to be the same as " + f"the array API namespace of `thresholds`, but got {xp} and " + f"{apc.array_namespace(thresholds)}", + ) + + return xp + + +def _format_thresholds( + thresholds: Optional[Union[int, List[float], Array]] = None, + device: Optional[Any] = None, + *, + xp: ModuleType, +) -> Optional[Array]: + """Convert the `thresholds` argument to an Array.""" + if isinstance(thresholds, int): + return xp.linspace( # type: ignore[no-any-return] + 0, + 1, + thresholds, + dtype=xp.float32, + device=device, + ) + if isinstance(thresholds, list): + return xp.asarray( # type: ignore[no-any-return] + thresholds, + dtype=xp.float32, + device=device, + ) + return thresholds + + +def _binary_precision_recall_curve_format_arrays( + target: Array, + preds: Array, + thresholds: Optional[Union[int, List[float], Array]], + ignore_index: Optional[int], + *, + xp: ModuleType, +) -> Tuple[Array, Array, Optional[Array]]: + """Format the arrays for the `binary_precision_recall_curve` function.""" + preds = flatten(preds) + target = flatten(target) + + if ignore_index is not None: + target, preds = remove_ignore_index(target, preds, ignore_index=ignore_index) + + if not xp.all(to_int((preds >= 0)) * to_int((preds <= 1))): # preds are logits + preds = sigmoid(preds) + + thresholds = _format_thresholds(thresholds, device=apc.device(preds), xp=xp) + return target, preds, thresholds + + +def _binary_precision_recall_curve_update( + target: Array, + preds: Array, + thresholds: Optional[Array], + *, + xp: ModuleType, +) -> Union[Array, Tuple[Array, Array]]: + """Update the state for the `binary_precision_recall_curve` function.""" + if thresholds is None: + return target, preds + + len_t = int(apc.size(thresholds) or 0) + target = target == 1 + confmat = xp.empty((len_t, 2, 2), dtype=xp.int32, device=apc.device(preds)) + + for i in range(len_t): + preds_t = preds >= thresholds[i] + confmat[i, 1, 1] = xp.sum(to_int(target & preds_t)) + confmat[i, 0, 1] = xp.sum(to_int(((~target) & preds_t))) + confmat[i, 1, 0] = xp.sum(to_int((target & (~preds_t)))) + confmat[:, 0, 0] = ( + preds_t.shape[0] - confmat[:, 0, 1] - confmat[:, 1, 0] - confmat[:, 1, 1] + ) + return confmat # type: ignore[no-any-return] + + +def _binary_clf_curve( + target: Array, + preds: Array, + sample_weights: Optional[Union[Sequence[float], Array]] = None, + pos_label: int = 1, +) -> Tuple[Array, Array, Array]: + """Compute the TPs and FPs for all unique thresholds in the `preds` Array. + + Adapted from + https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/functional/classification/precision_recall_curve.py#L28. 
+ """ + xp = apc.array_namespace(target, preds) + if sample_weights is not None and not apc.is_array_api_obj(sample_weights): + sample_weights = xp.asarray( + sample_weights, + device=apc.device(preds), + dtype=xp.float32, + ) + + # remove class dimension if necessary + if preds.ndim > target.ndim: + preds = preds[:, 0] + + # sort preds in descending order + sort_index = xp.argsort(preds, descending=True) + preds = _array_indexing(preds, sort_index) + target = _array_indexing(target, sort_index) + weight = ( + _array_indexing(sample_weights, sort_index) # type: ignore[arg-type] + if sample_weights is not None + else xp.asarray(1, device=apc.device(preds), dtype=xp.float32) + ) + + # extract indices of distinct values in preds to avoid ties + distinct_value_indices = ( + xp.nonzero(preds[1:] - preds[:-1])[0] + if int(apc.size(preds) or 0) > 1 + else xp.empty(0, dtype=xp.int32, device=apc.device(preds)) + ) + + # concatenate a value for the end of the curve + threshold_idxs = xp.concat( + [ + distinct_value_indices, + xp.asarray( + [int(apc.size(target) or 0) - 1], + dtype=xp.int32, + device=apc.device(preds), + ), + ], + ) + + target = xp.astype(target == pos_label, xp.float32, copy=False) + tps = _array_indexing(_cumsum(target * weight, axis=0), threshold_idxs) + if sample_weights is not None: + # express fps as a cumsum to ensure fps is increasing even in + # the presence of floating point errors + fps = _array_indexing(_cumsum((1 - target) * weight, axis=0), threshold_idxs) + else: + fps = 1 + xp.astype(threshold_idxs, xp.float32) - tps + + return fps, tps, _array_indexing(preds, threshold_idxs) + + +def _binary_precision_recall_curve_compute( + state: Union[Array, Tuple[Array, Array]], + thresholds: Optional[Array], + pos_label: int = 1, +) -> Tuple[Array, Array, Array]: + """Compute the precision and recall for all unique thresholds.""" + if apc.is_array_api_obj(state) and thresholds is not None: + xp = apc.array_namespace(state, thresholds) + tps = state[:, 1, 1] # type: ignore[call-overload] + fps = state[:, 0, 1] # type: ignore[call-overload] + fns = state[:, 1, 0] # type: ignore[call-overload] + precision = safe_divide(tps, tps + fps) + recall = safe_divide(tps, tps + fns) + precision = xp.concat( + [ + precision, + xp.ones(1, dtype=precision.dtype, device=apc.device(precision)), + ], + ) + recall = xp.concat( + [recall, xp.zeros(1, dtype=recall.dtype, device=apc.device(recall))], + ) + return precision, recall, thresholds + + fps, tps, thresholds = _binary_clf_curve(state[0], state[1], pos_label=pos_label) + precision = tps / (tps + fps) + recall = tps / tps[-1] + + xp = apc.array_namespace(precision, recall) + + # need to call reversed explicitly, since including that to slice would + # introduce negative strides that are not yet supported in pytorch + precision = xp.concat( + [ + xp.flip(precision, axis=0), + xp.ones(1, dtype=precision.dtype, device=apc.device(precision)), + ], + ) + recall = xp.concat( + [ + xp.flip(recall, axis=0), + xp.zeros(1, dtype=recall.dtype, device=apc.device(recall)), + ], + ) + thresholds = xp.flip(thresholds, axis=0) + if hasattr(thresholds, "detach"): + thresholds = clone(thresholds.detach()) # type: ignore + return precision, recall, thresholds # type: ignore[return-value] + + +def binary_precision_recall_curve( + target: Array, + preds: Array, + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[int] = None, +) -> Tuple[Array, Array, Array]: + """Compute the precision and recall for all unique thresholds. 
+ + Parameters + ---------- + target : Array + An array object that is compatible with the Python array API standard + and contains the ground truth labels in the range [0, 1]. The expected + shape of the array is `(N, ...)`, where `N` is the number of samples. + preds : Array + An array object that is compatible with the Python array API standard and + contains the probability/logit scores for the positive class. The expected + shape of the array is `(N, ...)` where `N` is the number of samples. If + `preds` contains floating point values that are not in the range `[0, 1]`, + a sigmoid function will be applied to each value before thresholding. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the precision and recall. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the precision + and recall. If `None`, all values in `target` are used. + + Returns + ------- + precision : Array + The precision values for all unique thresholds. The shape of the array is + `(num_thresholds + 1,)`. + recall : Array + The recall values for all unique thresholds. The shape of the array is + `(num_thresholds + 1,)`. + thresholds : Array + The thresholds used for computing the precision and recall values, in + ascending order. The shape of the array is `(num_thresholds,)`. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `ignore_index` is not `None` or an integer. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `target` and `preds` do not have the same shape. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + RuntimeError + If `target` contains values outside the range [0, 1] or does not contain + `ignore_index` if `ignore_index` is not `None`. + ValueError + If the array API namespace of `target` and `preds` are not the same as the + array API namespace of `thresholds`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import ( + ... binary_precision_recall_curve + ... ) + >>> import numpy.array_api as anp + >>> target = anp.asarray([0, 1, 0, 1, 0, 1]) + >>> preds = anp.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> precision, recall, thresholds = binary_precision_recall_curve( + ... target, preds, thresholds=None, + ... ) + >>> precision + Array([0.5 , 0.6 , 0.5 , 0.6666667, + 0.5 , 1. , 1. ], dtype=float32) + >>> recall + Array([1. , 1. , 0.6666667 , 0.6666667 , + 0.33333334, 0.33333334, 0. 
], dtype=float32) + >>> thresholds + Array([0.11, 0.22, 0.33, 0.73, 0.84, 0.92], dtype=float64) + >>> precision, recall, thresholds = binary_precision_recall_curve( + ... target, preds, thresholds=5, + ... ) + >>> precision + Array([0.5 , 0.5 , 0.6666667, 0.5 , + 0. , 1. ], dtype=float32) + >>> recall + Array([1. , 0.6666667 , 0.6666667 , 0.33333334, + 0. , 0. ], dtype=float32) + >>> thresholds + Array([0. , 0.25, 0.5 , 0.75, 1. ], dtype=float32) + + """ + _binary_precision_recall_curve_validate_args(thresholds, ignore_index) + xp = _binary_precision_recall_curve_validate_arrays( + target, + preds, + thresholds, + ignore_index, + ) + target, preds, thresholds = _binary_precision_recall_curve_format_arrays( + target, + preds, + thresholds, + ignore_index, + xp=xp, + ) + state = _binary_precision_recall_curve_update(target, preds, thresholds, xp=xp) + return _binary_precision_recall_curve_compute(state, thresholds) + + +def _multiclass_precision_recall_curve_validate_args( + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[Union[int, Tuple[int]]] = None, + average: Optional[Literal["macro", "micro", "none"]] = None, +) -> None: + """Validate the arguments for the `multiclass_precision_recall_curve` function.""" + _validate_thresholds(thresholds) + if not isinstance(num_classes, int) or num_classes < 2: + raise ValueError( + "Expected argument `num_classes` to be an integer larger than 1, " + f"but got {num_classes}.", + ) + if ignore_index is not None and not ( + isinstance(ignore_index, int) + or ( + isinstance(ignore_index, tuple) + and all(isinstance(i, int) for i in ignore_index) + ) + ): + raise ValueError( + "Expected argument `ignore_index` to either be `None`, an integer, " + f"or a tuple of integers but got {ignore_index}", + ) + allowed_average = ("micro", "macro", "none", None) + if average not in allowed_average: + raise ValueError( + f"Expected argument `average` to be one of {allowed_average}, " + f"but got {average}", + ) + + +def _multiclass_precision_recall_curve_validate_arrays( + target: Array, + preds: Array, + num_classes: int, + ignore_index: Optional[Union[int, Tuple[int]]], +) -> ModuleType: + """Validate the arrays for the `multiclass_precision_recall_curve` function.""" + _basic_input_array_checks(target, preds) + if not preds.ndim == target.ndim + 1: + raise ValueError( + f"Expected `preds` to have one more dimension than `target` but got {preds.ndim} and {target.ndim}", + ) + if is_floating_point(target): + raise ValueError( + "Expected argument `target` to be an integer array, but got array " + f"with dtype {target.dtype}", + ) + if not is_floating_point(preds): + raise ValueError( + f"Expected `preds` to be an array with floating point values, but got " + f"array with dtype {preds.dtype}", + ) + if preds.shape[1] != num_classes: + raise ValueError( + f"Expected `preds.shape[1]={preds.shape[1]}` to be equal to " + f"`num_classes={num_classes}`", + ) + if preds.shape[0] != target.shape[0] or preds.shape[2:] != target.shape[1:]: + raise ValueError( + "Expected the shape of `preds` should be (N, C, ...) and the shape of " + f"`target` should be (N, ...) 
but got {preds.shape} and {target.shape}", + ) + + xp = apc.array_namespace(target, preds) + num_unique_values = xp.unique_values(target).shape[0] + num_allowed_extra_values = 0 + if ignore_index is not None: + num_allowed_extra_values = ( + 1 if isinstance(ignore_index, int) else len(ignore_index) + ) + check = ( + num_unique_values > num_classes + if ignore_index is None + else num_unique_values > num_classes + num_allowed_extra_values + ) + if check: + raise RuntimeError( + f"Expected only {num_classes if ignore_index is None else num_classes + num_allowed_extra_values} " + f"values in `target` but found {num_unique_values} values.", + ) + + return xp # type: ignore[no-any-return] + + +def _multiclass_precision_recall_curve_format_arrays( + target: Array, + preds: Array, + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]], + ignore_index: Optional[Union[int, Tuple[int]]], + average: Optional[Literal["macro", "micro", "none"]], + *, + xp: ModuleType, +) -> Tuple[Array, Array, Optional[Array]]: + """Format the arrays for the `multiclass_precision_recall_curve` function.""" + preds = xp.reshape(moveaxis(preds, 0, 1), (num_classes, -1)).T + target = flatten(target) + + if ignore_index is not None: + target, preds = remove_ignore_index(target, preds, ignore_index=ignore_index) + + if not xp.all(to_int(preds >= 0) * to_int(preds <= 1)): + preds = softmax(preds, axis=1) + + if average == "micro": + preds = flatten(preds) + target = flatten(_to_one_hot(target, num_classes=num_classes)) + + thresholds = _format_thresholds(thresholds, device=apc.device(preds), xp=xp) + return target, preds, thresholds + + +def _multiclass_precision_recall_curve_update( + target: Array, + preds: Array, + num_classes: int, + thresholds: Optional[Array], + average: Optional[Literal["macro", "micro", "none"]] = None, + *, + xp: ModuleType, +) -> Union[Array, Tuple[Array, Array]]: + """Update the state for the `multiclass_precision_recall_curve` function.""" + if thresholds is None: + return target, preds + + if average == "micro": + return _binary_precision_recall_curve_update(target, preds, thresholds, xp=xp) + + len_t = thresholds.shape[0] if thresholds.ndim > 0 else 1 + preds_t = to_int( + ( + xp.expand_dims(preds, axis=-1) + >= xp.expand_dims(xp.expand_dims(thresholds, axis=0), axis=0) + ), + ) + target_t = _to_one_hot(target, num_classes=num_classes) + unique_mapping = preds_t + 2 * xp.expand_dims(to_int(target_t), axis=-1) + unique_mapping += 4 * xp.expand_dims( + xp.expand_dims(xp.arange(num_classes, device=apc.device(preds)), axis=0), + axis=-1, + ) + unique_mapping += 4 * num_classes * xp.arange(len_t, device=apc.device(preds)) + bins = bincount(flatten(unique_mapping), minlength=4 * num_classes * len_t) + return xp.reshape(xp.astype(bins, xp.int32, copy=False), (len_t, num_classes, 2, 2)) # type: ignore[no-any-return] + + +def _multiclass_precision_recall_curve_compute( + state: Union[Array, Tuple[Array, Array]], + num_classes: int, + thresholds: Optional[Array], + average: Optional[Literal["macro", "micro", "none"]], +) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: + """Compute the precision and recall for all unique thresholds.""" + if average == "micro": + return _binary_precision_recall_curve_compute(state, thresholds) + + if apc.is_array_api_obj(state) and thresholds is not None: + xp = apc.array_namespace(state, thresholds) + tps = state[:, :, 1, 1] # type: ignore[call-overload] + fps = state[:, :, 0, 1] # type: ignore[call-overload] + fns = 
state[:, :, 1, 0] # type: ignore[call-overload] + precision = safe_divide(tps, tps + fps) + recall = safe_divide(tps, tps + fns) + precision = xp.concat( + [ + precision, + xp.ones( + (1, num_classes), + dtype=precision.dtype, + device=apc.device(precision), + ), + ], + ) + recall = xp.concat( + [ + recall, + xp.zeros( + (1, num_classes), + dtype=recall.dtype, + device=apc.device(recall), + ), + ], + ) + precision = precision.T + recall = recall.T + thres = thresholds + array_state = True + else: + xp = apc.array_namespace(state[0], state[1]) + precision_list, recall_list, thres_list = [], [], [] + for i in range(num_classes): + res = _binary_precision_recall_curve_compute( + (state[0], state[1][:, i]), + thresholds=None, + pos_label=i, + ) + precision_list.append(res[0]) + recall_list.append(res[1]) + thres_list.append(res[2]) + array_state = False + + if average == "macro": + thres = ( + xp.concat([xp.expand_dims(thres, axis=0)] * num_classes, axis=0) # repeat + if array_state + else xp.concat(xp.asarray(thres_list), 0) + ) + thres = xp.sort(thres) + mean_precision = ( + flatten(precision) + if array_state + else xp.concat(xp.asarray(precision_list), 0) + ) + mean_precision = xp.sort(mean_precision) + mean_recall = xp.zeros_like(mean_precision) + for i in range(num_classes): + mean_recall += _interp( + mean_precision, + precision[i] if array_state else precision_list[i], + recall[i] if array_state else recall_list[i], + ) + mean_recall /= num_classes + return mean_precision, mean_recall, thres + + if array_state: + return precision, recall, thres + return precision_list, recall_list, thres_list + + +def multiclass_precision_recall_curve( + target: Array, + preds: Array, + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "micro", "none"]] = None, + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: + """Compute the precision and recall for all unique thresholds. + + Parameters + ---------- + target : Array + An array object that is compatible with the Python array API standard + and contains the ground truth labels in the range [0, `num_classes`] + (except if `ignore_index` is specified). The expected shape of the array + is `(N, ...)`, where `N` is the number of samples. + preds : Array + An array object that is compatible with the Python array API standard and + contains the probability/logit scores for each sample. The expected shape + of the array is `(N, C, ...)` where `N` is the number of samples and `C` + is the number of classes. If `preds` contains floating point values that + are not in the range `[0, 1]`, a softmax function will be applied to each + value before thresholding. + num_classes : int + The number of classes in the classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the precision and recall. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"macro", "micro", "none"}, optional, default=None + The type of averaging to use for computing the precision and recall. 
Can + be one of the following: + - `"macro"`: interpolates the curves from each class at a combined set of + thresholds and then average over the classwise interpolated curves. + - `"micro"`: one-hot encodes the targets and flattens the predictions, + considering all classes jointly as a binary problem. + - `"none"`: do not average over the classwise curves. + ignore_index : int or Tuple[int], optional, default=None + The value(s) in `target` that should be ignored when computing the + precision and recall. If `None`, all values in `target` are used. + + Returns + ------- + precision : Array or List[Array] + The precision values for all unique thresholds. If `thresholds` is `"none"` + or `None`, a list for each class is returned with 1-D Arrays of shape + `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape + `(num_thresholds + 1, num_classes)` is returned. + recall : Array or List[Array] + The recall values for all unique thresholds. If `thresholds` is `"none"` + or `None`, a list for each class is returned with 1-D Arrays of shape + `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape + `(num_thresholds + 1, num_classes)` is returned. + thresholds : Array or List[Array] + The thresholds used for computing the precision and recall values, in + ascending order. If `thresholds` is `"none"` or `None`, a list for each + class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, + a 1-D Array of shape `(num_thresholds,)` is returned. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `num_classes` is not an integer larger than 1. + ValueError + If `ignore_index` is not `None`, an integer or a tuple of integers. + ValueError + If `average` is not `"macro"`, `"micro"`, `"none"` or `None`. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `preds` does not have one more dimension than `target`. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + ValueError + If the second dimension of `preds` is not equal to `num_classes`. + ValueError + If the first dimension of `preds` is not equal to the first dimension of + `target` or the third dimension of `preds` is not equal to the second + dimension of `target`. + RuntimeError + If `target` contains more unique values than `num_classes` or `num_classes` + plus the number of values in `ignore_index` if `ignore_index` is not `None`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import ( + ... multiclass_precision_recall_curve + ... ) + >>> import numpy.array_api as anp + >>> target = anp.asarray([0, 1, 2, 0, 1, 2]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44], + ... [0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44]]) + >>> precision, recall, thresholds = multiclass_precision_recall_curve( + ... target, preds, num_classes=3, thresholds=None, + ... 
) + >>> precision + [Array([0.33333334, 0. , 0. , 1. ], dtype=float32), + Array([0.33333334, 0.5 , 0. , 1. ], dtype=float32), + Array([0.33333334, 0.5 , 0. , 1. ], dtype=float32)] + >>> recall + [Array([1., 0., 0., 0.], dtype=float32), Array([1., 1., 0., 0.], dtype=float32), Array([1., 1., 0., 0.], dtype=float32)] + >>> thresholds + [Array([0.11, 0.33, 0.84], dtype=float64), Array([0.22, 0.73, 0.92], dtype=float64), Array([0.12, 0.44, 0.67], dtype=float64)] + >>> precision, recall, thresholds = multiclass_precision_recall_curve( + ... target, preds, num_classes=3, thresholds=5, + ... ) + >>> precision + Array([[0.33333334, 0. , 0. , 0. , + 0. , 1. ], + [0.33333334, 0.5 , 0.5 , 0. , + 0. , 1. ], + [0.33333334, 0.5 , 0. , 0. , + 0. , 1. ]], dtype=float32) + >>> recall + Array([[1., 0., 0., 0., 0., 0.], + [1., 1., 1., 0., 0., 0.], + [1., 1., 0., 0., 0., 0.]], dtype=float32) + >>> thresholds + Array([0. , 0.25, 0.5 , 0.75, 1. ], dtype=float32) + + """ # noqa: W505 + _multiclass_precision_recall_curve_validate_args( + num_classes, + thresholds, + ignore_index, + average, + ) + xp = _multiclass_precision_recall_curve_validate_arrays( + target, + preds, + num_classes, + ignore_index, + ) + target, preds, thresholds = _multiclass_precision_recall_curve_format_arrays( + target, + preds, + num_classes, + thresholds, + ignore_index, + average, + xp=xp, + ) + state = _multiclass_precision_recall_curve_update( + target, + preds, + num_classes, + thresholds, + average, + xp=xp, + ) + return _multiclass_precision_recall_curve_compute( + state, + num_classes, + thresholds=thresholds, + average=average, + ) + + +def _multilabel_precision_recall_curve_validate_args( + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]], + ignore_index: Optional[int], +) -> None: + """Validate the arguments for the `multilabel_precision_recall_curve` function.""" + if not isinstance(num_labels, int) or num_labels < 2: + raise ValueError( + "Expected argument `num_labels` to be an integer larger than 1, " + f"but got {num_labels}.", + ) + _binary_precision_recall_curve_validate_args(thresholds, ignore_index) + + +def _multilabel_precision_recall_curve_validate_arrays( + target: Array, + preds: Array, + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]], + ignore_index: Optional[int], +) -> ModuleType: + xp = _binary_precision_recall_curve_validate_arrays( + target, + preds, + thresholds, + ignore_index=ignore_index, + ) + + if preds.shape[1] != num_labels: + raise ValueError( + "Expected both `target.shape[1]` and `preds.shape[1]` to be equal to the number of labels " + f"but got {preds.shape[1]} and expected {num_labels}, respectively.", + ) + + return xp + + +def _multilabel_precision_recall_curve_format_arrays( + target: Array, + preds: Array, + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]], + ignore_index: Optional[int], + *, + xp: ModuleType, +) -> Tuple[Array, Array, Optional[Array]]: + """Format the arrays for the `multilabel_precision_recall_curve` function.""" + preds = xp.reshape(moveaxis(preds, 0, 1), (num_labels, -1)).T + target = xp.reshape(moveaxis(target, 0, 1), (num_labels, -1)).T + if not xp.all(to_int(preds >= 0) * to_int(preds <= 1)): + preds = sigmoid(preds) + + thresholds = _format_thresholds(thresholds, device=apc.device(preds), xp=xp) + if ignore_index is not None and thresholds is not None: + preds = clone(preds) + target = clone(target) + # make sure that when we map, it will always result in a negative number + # that we can filter 
away + idx = target == ignore_index + preds[idx] = -4 * num_labels * thresholds.shape[0] if thresholds.ndim > 0 else 1 + target[idx] = ( + -4 * num_labels * thresholds.shape[0] if thresholds.ndim > 0 else 1 + ) + + return target, preds, thresholds + + +def _multilabel_precision_recall_curve_update( + target: Array, + preds: Array, + num_labels: int, + thresholds: Optional[Array], + *, + xp: ModuleType, +) -> Union[Array, Tuple[Array, Array]]: + """Update the state for the `multilabel_precision_recall_curve` function.""" + if thresholds is None: + return target, preds + + len_t = thresholds.shape[0] if thresholds.ndim > 0 else 1 + # num_samples x num_labels x num_thresholds + preds_t = to_int( + xp.expand_dims(xp.astype(preds, xp.float32, copy=False), axis=-1) + >= xp.expand_dims(xp.expand_dims(thresholds, axis=0), axis=0), + ) + unique_mapping = preds_t + 2 * xp.expand_dims(to_int(target), axis=-1) + unique_mapping += 4 * xp.expand_dims( + xp.expand_dims(xp.arange(num_labels, device=apc.device(preds)), axis=0), + axis=-1, + ) + unique_mapping += 4 * num_labels * xp.arange(len_t, device=apc.device(preds)) + unique_mapping = unique_mapping[unique_mapping >= 0] + bins = bincount(unique_mapping, minlength=4 * num_labels * len_t) + return xp.reshape(xp.astype(bins, xp.int32, copy=False), (len_t, num_labels, 2, 2)) # type: ignore[no-any-return] + + +def _multilabel_precision_recall_curve_compute( + state: Union[Array, Tuple[Array, Array]], + num_labels: int, + thresholds: Optional[Array], + ignore_index: Optional[int], +) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: + """Compute the precision and recall for all unique thresholds.""" + if apc.is_array_api_obj(state) and thresholds is not None: + xp = apc.array_namespace(state) + tps = state[:, :, 1, 1] # type: ignore[call-overload] + fps = state[:, :, 0, 1] # type: ignore[call-overload] + fns = state[:, :, 1, 0] # type: ignore[call-overload] + precision = safe_divide(tps, tps + fps) + recall = safe_divide(tps, tps + fns) + precision = xp.concat( + [ + precision, + xp.ones( + (1, num_labels), + dtype=precision.dtype, + device=apc.device(precision), + ), + ], + ) + recall = xp.concat( + [ + recall, + xp.zeros( + (1, num_labels), + dtype=recall.dtype, + device=apc.device(recall), + ), + ], + ) + return precision.T, recall.T, thresholds + + precision_list, recall_list, thres_list = [], [], [] + for i in range(num_labels): + target = state[0][:, i] + preds = state[1][:, i] + if ignore_index is not None: + target, preds = remove_ignore_index( + target, + preds, + ignore_index=ignore_index, + ) + res = _binary_precision_recall_curve_compute( + (target, preds), + thresholds=None, + pos_label=1, + ) + precision_list.append(res[0]) + recall_list.append(res[1]) + thres_list.append(res[2]) + return precision_list, recall_list, thres_list + + +def multilabel_precision_recall_curve( + target: Array, + preds: Array, + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[int] = None, +) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: + """Compute the precision and recall for all unique thresholds. + + Parameters + ---------- + target : Array + The target array of shape `(N, L, ...)` containing the ground truth labels + in the range [0, 1], where `N` is the number of samples and `L` is the + number of labels. 
+ preds : Array + The prediction array of shape `(N, L, ...)` containing the probability/logit + scores for each sample, where `N` is the number of samples and `L` is the + number of labels. If `preds` contains floating point values that are not + in the range [0,1], they will be converted to probabilities using the + sigmoid function. + num_labels : int + The number of labels in the multilabel classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the precision and recall. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the precision + and recall. If `None`, all values in `target` are used. + + Returns + ------- + precision : Array or List[Array] + The precision values for all unique thresholds. If `thresholds` is `None`, + a list for each label is returned with 1-D Arrays of shape + `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape + `(num_thresholds + 1, num_labels)` is returned. + recall : Array or List[Array] + The recall values for all unique thresholds. If `thresholds` is `None`, + a list for each label is returned with 1-D Arrays of shape + `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape + `(num_thresholds + 1, num_labels)` is returned. + thresholds : Array or List[Array] + The thresholds used for computing the precision and recall values, in + ascending order. If `thresholds` is `None`, a list for each label is + returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D + Array of shape `(num_thresholds,)` is returned. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `ignore_index` is not `None` or an integer. + ValueError + If `num_labels` is not an integer larger than 1. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `target` and `preds` do not have the same shape. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + RuntimeError + If `target` contains values outside the range [0, 1] or does not contain + `ignore_index` if `ignore_index` is not `None`. + ValueError + If the array API namespace of `target` and `preds` are not the same as the + array API namespace of `thresholds`. + ValueError + If the second dimension of `preds` is not equal to `num_labels`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import ( + ... multilabel_precision_recall_curve + ... ) + >>> import numpy.array_api as anp + >>> target = anp.asarray([[0, 1, 0], [1, 1, 0], [0, 0, 1]]) + >>> preds = anp.asarray( + ... 
[[0.11, 0.22, 0.67], [0.84, 0.73, 0.12], [0.33, 0.92, 0.44]], + ... ) + >>> precision, recall, thresholds = multilabel_precision_recall_curve( + ... target, preds, num_labels=3, thresholds=None, + ... ) + >>> precision + [Array([0.33333334, 0.5 , 1. , 1. ], dtype=float32), + Array([0.6666667, 0.5 , 0. , 1. ], dtype=float32), + Array([0.33333334, 0.5 , 0. , 1. ], dtype=float32)] + >>> recall + [Array([1., 1., 1., 0.], dtype=float32), Array([1. , 0.5, 0. , 0. ], dtype=float32), Array([1., 1., 0., 0.], dtype=float32)] + >>> thresholds + [Array([0.11, 0.33, 0.84], dtype=float64), Array([0.22, 0.73, 0.92], dtype=float64), Array([0.12, 0.44, 0.67], dtype=float64)] + >>> precision, recall, thresholds = multilabel_precision_recall_curve( + ... target, preds, num_labels=3, thresholds=5, + ... ) + >>> precision + Array([[0.33333334, 0.5 , 1. , 1. , + 0. , 1. ], + [0.6666667 , 0.5 , 0.5 , 0. , + 0. , 1. ], + [0.33333334, 0.5 , 0. , 0. , + 0. , 1. ]], dtype=float32) + >>> recall + Array([[1. , 1. , 1. , 1. , 0. , 0. ], + [1. , 0.5, 0.5, 0. , 0. , 0. ], + [1. , 1. , 0. , 0. , 0. , 0. ]], dtype=float32) + >>> thresholds + Array([0. , 0.25, 0.5 , 0.75, 1. ], dtype=float32) + + """ # noqa: W505 + _multilabel_precision_recall_curve_validate_args( + num_labels, + thresholds, + ignore_index, + ) + xp = _multilabel_precision_recall_curve_validate_arrays( + target, + preds, + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + target, preds, thresholds = _multilabel_precision_recall_curve_format_arrays( + target, + preds, + num_labels, + thresholds, + ignore_index, + xp=xp, + ) + state = _multilabel_precision_recall_curve_update( + target, + preds, + num_labels, + thresholds, + xp=xp, + ) + return _multilabel_precision_recall_curve_compute( + state, + num_labels, + thresholds, + ignore_index, + ) diff --git a/cyclops/evaluate/metrics/experimental/precision_recall_curve.py b/cyclops/evaluate/metrics/experimental/precision_recall_curve.py new file mode 100644 index 000000000..46bfba20e --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/precision_recall_curve.py @@ -0,0 +1,486 @@ +"""Classes for computing the precision-recall curve.""" +from types import ModuleType +from typing import List, Literal, Optional, Tuple, Union + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( + _binary_precision_recall_curve_compute, + _binary_precision_recall_curve_format_arrays, + _binary_precision_recall_curve_update, + _binary_precision_recall_curve_validate_args, + _binary_precision_recall_curve_validate_arrays, + _multiclass_precision_recall_curve_compute, + _multiclass_precision_recall_curve_format_arrays, + _multiclass_precision_recall_curve_update, + _multiclass_precision_recall_curve_validate_args, + _multiclass_precision_recall_curve_validate_arrays, + _multilabel_precision_recall_curve_compute, + _multilabel_precision_recall_curve_format_arrays, + _multilabel_precision_recall_curve_update, + _multilabel_precision_recall_curve_validate_args, + _multilabel_precision_recall_curve_validate_arrays, +) +from cyclops.evaluate.metrics.experimental.metric import Metric +from cyclops.evaluate.metrics.experimental.utils.ops import dim_zero_cat +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +class BinaryPrecisionRecallCurve(Metric, registry_key="binary_precision_recall_curve"): + """The precision and recall values computed at different thresholds. 
+ + Parameters + ---------- + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the precision and recall. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the precision + and recall. If `None`, all values in `target` are used. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import BinaryPrecisionRecallCurve + >>> target = anp.asarray([0, 1, 0, 1, 0, 1]) + >>> preds = anp.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> metric = BinaryPrecisionRecallCurve(thresholds=None) + >>> metric(target, preds) + (Array([0.5 , 0.6 , 0.5 , 0.6666667, + 0.5 , 1. , 1. ], dtype=float32), Array([1. , 1. , 0.6666667 , 0.6666667 , + 0.33333334, 0.33333334, 0. ], dtype=float32), Array([0.11, 0.22, 0.33, 0.73, 0.84, 0.92], dtype=float64)) + >>> metric = BinaryPrecisionRecallCurve(thresholds=5) + >>> metric(target, preds) + (Array([0.5 , 0.5 , 0.6666667, 0.5 , + 0. , 1. ], dtype=float32), Array([1. , 0.6666667 , 0.6666667 , 0.33333334, + 0. , 0. ], dtype=float32), Array([0. , 0.25, 0.5 , 0.75, 1. ], dtype=float32)) + + """ # noqa: W505 + + name: str = "Precision-Recall Curve" + + def __init__( + self, + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[int] = None, + ) -> None: + """Initialize a `BinaryPrecisionRecallCurve` instance.""" + super().__init__() + _binary_precision_recall_curve_validate_args(thresholds, ignore_index) + self.ignore_index = ignore_index + self.thresholds = thresholds + + if thresholds is None: + self.add_state_default_factory( + "preds", + default_factory=list, # type: ignore + dist_reduce_fn="cat", + ) + self.add_state_default_factory( + "target", + default_factory=list, # type: ignore + dist_reduce_fn="cat", + ) + else: + len_thresholds = ( + len(thresholds) + if isinstance(thresholds, list) + else thresholds + if isinstance(thresholds, int) + else thresholds.shape[0] + ) + + def default(xp: ModuleType) -> Array: + return xp.zeros((len_thresholds, 2, 2), dtype=xp.int32, device=self.device) # type: ignore[no-any-return] + + self.add_state_default_factory( + "confmat", + default_factory=default, # type: ignore + dist_reduce_fn="sum", + ) + + def _update_state(self, target: Array, preds: Array) -> None: + """Update the state of the metric.""" + xp = _binary_precision_recall_curve_validate_arrays( + target, + preds, + self.thresholds, + self.ignore_index, + ) + target, preds, self.thresholds = _binary_precision_recall_curve_format_arrays( + target, + preds, + thresholds=self.thresholds, + ignore_index=self.ignore_index, + xp=xp, + ) + state = _binary_precision_recall_curve_update( + target, + preds, + thresholds=self.thresholds, + xp=xp, + ) + + if apc.is_array_api_obj(state): + self.confmat += state # type: ignore[attr-defined] + else: + self.target.append(state[0]) # type: ignore[attr-defined] + self.preds.append(state[1]) # type: ignore[attr-defined] + + def _compute_metric(self) -> Tuple[Array, Array, Array]: + """Compute the metric.""" + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + 
if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + return _binary_precision_recall_curve_compute(state, self.thresholds) # type: ignore[arg-type] + + +class MulticlassPrecisionRecallCurve( + Metric, + registry_key="multiclass_precision_recall_curve", +): + """The precision and recall values computed at different thresholds. + + Parameters + ---------- + num_classes : int + The number of classes in the classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the precision and recall. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"macro", "micro", "none"}, optional, default=None + The type of averaging to use for computing the precision and recall. Can + be one of the following: + - `"macro"`: interpolates the curves from each class at a combined set of + thresholds and then average over the classwise interpolated curves. + - `"micro"`: one-hot encodes the targets and flattens the predictions, + considering all classes jointly as a binary problem. + - `"none"`: do not average over the classwise curves. + ignore_index : int or Tuple[int], optional, default=None + The value(s) in `target` that should be ignored when computing the + precision and recall. If `None`, all values in `target` are used. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import MulticlassPrecisionRecallCurve + >>> target = anp.asarray([0, 1, 2, 0, 1, 2]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44], + ... [0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44]] + ... ) + >>> metric = MulticlassPrecisionRecallCurve(num_classes=3, thresholds=None) + >>> metric(target, preds) + ([Array([0.33333334, 0. , 0. , 1. ], dtype=float32), + Array([0.33333334, 0.5 , 0. , 1. ], dtype=float32), + Array([0.33333334, 0.5 , 0. , 1. ], dtype=float32)], + [Array([1., 0., 0., 0.], dtype=float32), + Array([1., 1., 0., 0.], dtype=float32), + Array([1., 1., 0., 0.], dtype=float32)], + [Array([0.11, 0.33, 0.84], dtype=float64), + Array([0.22, 0.73, 0.92], dtype=float64), + Array([0.12, 0.44, 0.67], dtype=float64)]) + >>> metric = MulticlassPrecisionRecallCurve(num_classes=3, thresholds=5) + >>> metric(target, preds) + (Array([[0.33333334, 0. , 0. , 0. , + 0. , 1. ], + [0.33333334, 0.5 , 0.5 , 0. , + 0. , 1. ], + [0.33333334, 0.5 , 0. , 0. , + 0. , 1. ]], dtype=float32), Array([[1., 0., 0., 0., 0., 0.], + [1., 1., 1., 0., 0., 0.], + [1., 1., 0., 0., 0., 0.]], dtype=float32), Array([0. , 0.25, 0.5 , 0.75, 1. 
], dtype=float32)) + + """ # noqa: W505 + + name: str = "Precision-Recall Curve" + + def __init__( + self, + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "micro", "none"]] = None, + ignore_index: Optional[Union[int, Tuple[int]]] = None, + ) -> None: + """Initialize a `MulticlassPrecisionRecallCurve` instance.""" + super().__init__() + _multiclass_precision_recall_curve_validate_args( + num_classes, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + self.num_classes = num_classes + self.ignore_index = ignore_index + self.average = average + self.thresholds = thresholds + + if thresholds is None: + self.add_state_default_factory( + "preds", + default_factory=list, # type: ignore + dist_reduce_fn="cat", + ) + self.add_state_default_factory( + "target", + default_factory=list, # type: ignore + dist_reduce_fn="cat", + ) + else: + len_thresholds = ( + len(thresholds) + if isinstance(thresholds, list) + else thresholds + if isinstance(thresholds, int) + else thresholds.shape[0] + ) + + def default(xp: ModuleType) -> Array: + return xp.zeros( # type: ignore[no-any-return] + (len_thresholds, num_classes, 2, 2), + dtype=xp.int32, + device=self.device, + ) + + self.add_state_default_factory( + "confmat", + default_factory=default, # type: ignore + dist_reduce_fn="sum", + ) + + def _update_state(self, target: Array, preds: Array) -> None: + """Update the state of the metric.""" + xp = _multiclass_precision_recall_curve_validate_arrays( + target, + preds, + self.num_classes, + ignore_index=self.ignore_index, + ) + + ( + target, + preds, + self.thresholds, + ) = _multiclass_precision_recall_curve_format_arrays( + target, + preds, + self.num_classes, + thresholds=self.thresholds, + ignore_index=self.ignore_index, + average=self.average, + xp=xp, + ) + state = _multiclass_precision_recall_curve_update( + target, + preds, + self.num_classes, + thresholds=self.thresholds, + average=self.average, + xp=xp, + ) + + if apc.is_array_api_obj(state): + self.confmat += state # type: ignore[attr-defined] + else: + self.target.append(state[0]) # type: ignore[attr-defined] + self.preds.append(state[1]) # type: ignore[attr-defined] + + def _compute_metric( + self, + ) -> Union[ + Tuple[Array, Array, Array], + Tuple[List[Array], List[Array], List[Array]], + ]: + """Compute the metric.""" + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + return _multiclass_precision_recall_curve_compute( + state, + self.num_classes, + thresholds=self.thresholds, # type: ignore[arg-type] + average=self.average, + ) + + +class MultilabelPrecisionRecallCurve( + Metric, + registry_key="multilabel_precision_recall_curve", +): + """The precision and recall values computed at different thresholds. + + Parameters + ---------- + num_labels : int + The number of labels in the multilabel classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the precision and recall. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. 
+ ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the precision + and recall. If `None`, all values in `target` are used. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import MultilabelPrecisionRecallCurve + >>> target = anp.asarray([[0, 1, 0], [1, 1, 0], [0, 0, 1]]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], [0.84, 0.73, 0.12], [0.33, 0.92, 0.44]], + ... ) + >>> metric = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=None) + >>> metric(target, preds) + ([Array([0.33333334, 0.5 , 1. , 1. ], dtype=float32), + Array([0.6666667, 0.5 , 0. , 1. ], dtype=float32), + Array([0.33333334, 0.5 , 0. , 1. ], dtype=float32)], + [Array([1., 1., 1., 0.], dtype=float32), + Array([1. , 0.5, 0. , 0. ], dtype=float32), + Array([1., 1., 0., 0.], dtype=float32)], + [Array([0.11, 0.33, 0.84], dtype=float64), + Array([0.22, 0.73, 0.92], dtype=float64), + Array([0.12, 0.44, 0.67], dtype=float64)]) + >>> metric = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=5) + >>> metric(target, preds) + (Array([[0.33333334, 0.5 , 1. , 1. , + 0. , 1. ], + [0.6666667 , 0.5 , 0.5 , 0. , + 0. , 1. ], + [0.33333334, 0.5 , 0. , 0. , + 0. , 1. ]], dtype=float32), Array([[1. , 1. , 1. , 1. , 0. , 0. ], + [1. , 0.5, 0.5, 0. , 0. , 0. ], + [1. , 1. , 0. , 0. , 0. , 0. ]], dtype=float32), Array([0. , 0.25, 0.5 , 0.75, 1. ], dtype=float32)) + + """ # noqa: W505 + + name: str = "Precision-Recall Curve" + + def __init__( + self, + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[int] = None, + ) -> None: + """Initialize a `MultilabelPrecisionRecallCurve` instance.""" + super().__init__() + _multilabel_precision_recall_curve_validate_args( + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + self.num_labels = num_labels + self.ignore_index = ignore_index + self.thresholds = thresholds + + if thresholds is None: + self.add_state_default_factory( + "preds", + default_factory=list, # type: ignore + dist_reduce_fn="cat", + ) + self.add_state_default_factory( + "target", + default_factory=list, # type: ignore + dist_reduce_fn="cat", + ) + else: + len_thresholds = ( + len(thresholds) + if isinstance(thresholds, list) + else thresholds + if isinstance(thresholds, int) + else thresholds.shape[0] + ) + + def default(xp: ModuleType) -> Array: + return xp.zeros( # type: ignore[no-any-return] + (len_thresholds, num_labels, 2, 2), + dtype=xp.int32, + device=self.device, + ) + + self.add_state_default_factory( + "confmat", + default_factory=default, # type: ignore + dist_reduce_fn="sum", + ) + + def _update_state(self, target: Array, preds: Array) -> None: + """Update the state of the metric.""" + xp = _multilabel_precision_recall_curve_validate_arrays( + target, + preds, + self.num_labels, + thresholds=self.thresholds, + ignore_index=self.ignore_index, + ) + + ( + target, + preds, + self.thresholds, + ) = _multilabel_precision_recall_curve_format_arrays( + target, + preds, + self.num_labels, + thresholds=self.thresholds, + ignore_index=self.ignore_index, + xp=xp, + ) + state = _multilabel_precision_recall_curve_update( + target, + preds, + self.num_labels, + thresholds=self.thresholds, + xp=xp, + ) + + if apc.is_array_api_obj(state): + self.confmat += state # type: ignore[attr-defined] + else: + self.target.append(state[0]) # type: ignore[attr-defined] + self.preds.append(state[1]) # type: ignore[attr-defined] + + def 
_compute_metric( + self, + ) -> Union[ + Tuple[Array, Array, Array], + Tuple[List[Array], List[Array], List[Array]], + ]: + """Compute the metric.""" + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + return _multilabel_precision_recall_curve_compute( + state, + self.num_labels, + thresholds=self.thresholds, # type: ignore[arg-type] + ignore_index=self.ignore_index, + ) diff --git a/cyclops/evaluate/metrics/experimental/utils/ops.py b/cyclops/evaluate/metrics/experimental/utils/ops.py index e51634787..5d9279e56 100644 --- a/cyclops/evaluate/metrics/experimental/utils/ops.py +++ b/cyclops/evaluate/metrics/experimental/utils/ops.py @@ -685,6 +685,76 @@ def _adjust_weight_apply_average( ) +def _array_indexing(arr: Array, idx: Array) -> Array: + try: + return arr[idx] + except IndexError: + xp = apc.array_namespace(arr, idx) + np_idx = np.from_dlpack(apc.to_device(idx, "cpu")) + np_arr = np.from_dlpack(apc.to_device(arr, "cpu"))[np_idx] + return xp.asarray(np_arr, dtype=arr.dtype, device=apc.device(arr)) + + +def _cumsum(x: Array, axis: Optional[int], dtype: Optional[Any] = None) -> Array: + xp = apc.array_namespace(x) + if hasattr(xp, "cumsum"): + return xp.cumsum(x, axis, dtype=dtype) + + if axis is None: + x = flatten(x) + axis = 0 + + if axis < 0 or axis >= x.ndim: + raise ValueError("Invalid axis value") + + if dtype is None: + dtype = x.dtype + + if axis < 0: + axis += x.ndim + + if int(apc.size(x) or 0) == 0: + return x + + result = xp.empty_like(x, dtype=dtype, device=apc.device(x)) + + # create slice object with `axis` at the appropriate position + curr_indices = [slice(None)] * x.ndim + prev_indices = [slice(None)] * x.ndim + + curr_indices[axis] = 0 # type: ignore[call-overload] + result[tuple(curr_indices)] = x[tuple(curr_indices)] + for i in range(1, x.shape[axis]): + prev_indices[axis] = i - 1 # type: ignore[call-overload] + curr_indices[axis] = i # type: ignore[call-overload] + result[tuple(curr_indices)] = ( + result[tuple(prev_indices)] + x[tuple(curr_indices)] + ) + + return result + + +def _interp(x: Array, xcoords: Array, ycoords: Array) -> Array: + """Perform linear interpolation for 1D arrays.""" + xp = apc.array_namespace(x, xcoords, ycoords) + if hasattr(xp, "interp"): + return xp.interp(x, xcoords, ycoords) + + m = safe_divide(ycoords[1:] - ycoords[:-1], xcoords[1:] - xcoords[:-1]) + b = ycoords[:-1] - (m * xcoords[:-1]) + + indices = xp.sum(x[:, None] >= xcoords[None, :], 1) - 1 + _min_val = xp.asarray(0, dtype=xp.float32, device=apc.device(x)) + _max_val = xp.asarray( + m.shape[0] if m.ndim > 0 else 1 - 1, + dtype=xp.float32, + device=apc.device(x), + ) + indices = xp.min(xp.max(indices, _min_val), _max_val) + + return _array_indexing(m, indices) * x + _array_indexing(b, indices) + + def _select_topk( # noqa: PLR0912 scores: Array, top_k: int = 1, diff --git a/tests/cyclops/evaluate/metrics/experimental/inputs.py b/tests/cyclops/evaluate/metrics/experimental/inputs.py index 1e7fa3ce2..2a0197600 100644 --- a/tests/cyclops/evaluate/metrics/experimental/inputs.py +++ b/tests/cyclops/evaluate/metrics/experimental/inputs.py @@ -63,6 +63,20 @@ def _binary_cases(*, xp: Any): ), id="input[single-element-labels]", ), + pytest.param( + InputSpec( + target=xp.asarray(_binary_labels_1d), + preds=xp.asarray(_binary_preds_1d), + ), + id="input[1d-labels]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_binary_labels_multidim), + 
preds=xp.asarray(_binary_preds_multidim), + ), + id="input[multidim-labels]", + ), pytest.param( InputSpec( target=xp.asarray(_binary_labels_0d), @@ -77,13 +91,6 @@ def _binary_cases(*, xp: Any): ), id="input[single-element-logits]", ), - pytest.param( - InputSpec( - target=xp.asarray(_binary_labels_1d), - preds=xp.asarray(_binary_preds_1d), - ), - id="input[1d-labels]", - ), pytest.param( InputSpec( target=xp.asarray(_binary_labels_1d), @@ -96,14 +103,7 @@ def _binary_cases(*, xp: Any): target=xp.asarray(_binary_labels_1d), preds=xp.asarray(_inv_sigmoid(_binary_probs_1d)), ), - id="input[1d-probs]", - ), - pytest.param( - InputSpec( - target=xp.asarray(_binary_labels_multidim), - preds=xp.asarray(_binary_preds_multidim), - ), - id="input[multidim-labels]", + id="input[1d-logits]", ), pytest.param( InputSpec( @@ -117,7 +117,7 @@ def _binary_cases(*, xp: Any): target=xp.asarray(_binary_labels_multidim), preds=xp.asarray(_inv_sigmoid(_binary_probs_multidim)), ), - id="input[multidim-probs]", + id="input[multidim-logits]", ), ) @@ -182,20 +182,6 @@ def _multiclass_cases(*, xp: Any): ), id="input[single-element-labels]", ), - pytest.param( - InputSpec( - target=xp.asarray(_multiclass_labels_0d), - preds=xp.asarray(_multiclass_probs_0d), - ), - id="input[single-element-probs]", - ), - pytest.param( - InputSpec( - target=xp.asarray(_multiclass_labels_0d), - preds=xp.asarray(log_softmax(_multiclass_probs_0d, axis=-1)), - ), - id="input[single-element-logits]", - ), pytest.param( InputSpec( target=xp.asarray(_multiclass_labels_1d), @@ -203,20 +189,6 @@ def _multiclass_cases(*, xp: Any): ), id="input[1d-labels]", ), - pytest.param( - InputSpec( - target=xp.asarray(_multiclass_labels_1d), - preds=xp.asarray(_multiclass_probs_1d), - ), - id="input[1d-probs]", - ), - pytest.param( - InputSpec( - target=xp.asarray(_multiclass_labels_1d), - preds=xp.asarray(log_softmax(_multiclass_probs_1d, axis=-1)), - ), - id="input[1d-logits]", - ), pytest.param( InputSpec( preds=_multiclass_with_missing_class( @@ -241,6 +213,34 @@ def _multiclass_cases(*, xp: Any): ), id="input[multidim-labels]", ), + pytest.param( + InputSpec( + target=xp.asarray(_multiclass_labels_0d), + preds=xp.asarray(_multiclass_probs_0d), + ), + id="input[single-element-probs]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multiclass_labels_0d), + preds=xp.asarray(log_softmax(_multiclass_probs_0d, axis=-1)), + ), + id="input[single-element-logits]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multiclass_labels_1d), + preds=xp.asarray(_multiclass_probs_1d), + ), + id="input[1d-probs]", + ), + pytest.param( + InputSpec( + target=xp.asarray(_multiclass_labels_1d), + preds=xp.asarray(log_softmax(_multiclass_probs_1d, axis=-1)), + ), + id="input[1d-logits]", + ), pytest.param( InputSpec( target=xp.asarray(_multiclass_labels_multidim), @@ -293,6 +293,13 @@ def _multilabel_cases(*, xp: Any): ), id="input[2d-labels]", ), + pytest.param( + InputSpec( + target=xp.asarray(_multilabel_labels_multidim), + preds=xp.asarray(_multilabel_preds_multidim), + ), + id="input[multidim-labels]", + ), pytest.param( InputSpec( target=xp.asarray(_multilabel_labels), @@ -307,13 +314,6 @@ def _multilabel_cases(*, xp: Any): ), id="input[2d-logits]", ), - pytest.param( - InputSpec( - target=xp.asarray(_multilabel_labels_multidim), - preds=xp.asarray(_multilabel_preds_multidim), - ), - id="input[multidim-labels]", - ), pytest.param( InputSpec( target=xp.asarray(_multilabel_labels_multidim), diff --git 
a/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py new file mode 100644 index 000000000..b19e36822 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py @@ -0,0 +1,492 @@ +"""Test precision-recall curve metric.""" +from functools import partial +from types import ModuleType +from typing import List, Tuple, Union + +import array_api_compat as apc +import array_api_compat.torch +import numpy.array_api as anp +import pytest +import torch.utils.dlpack +from torchmetrics.functional.classification import ( + binary_precision_recall_curve as tm_binary_precision_recall_curve, +) +from torchmetrics.functional.classification import ( + multiclass_precision_recall_curve as tm_multiclass_precision_recall_curve, +) +from torchmetrics.functional.classification import ( + multilabel_precision_recall_curve as tm_multilabel_precision_recall_curve, +) + +from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( + binary_precision_recall_curve, + multiclass_precision_recall_curve, + multilabel_precision_recall_curve, +) +from cyclops.evaluate.metrics.experimental.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from cyclops.evaluate.metrics.experimental.utils.ops import to_int +from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point + +from ..conftest import NUM_CLASSES, NUM_LABELS +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from .testers import MetricTester, _inject_ignore_index + + +def _thresholds_for_prc(*, xp: ModuleType) -> list: + """Return thresholds for precision-recall curve.""" + thresh_list = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0] + return [None, 5, thresh_list, xp.asarray(thresh_list)] + + +def _binary_precision_recall_curve_reference( + target, + preds, + thresholds, + ignore_index, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Return the reference binary precision-recall curve.""" + return tm_binary_precision_recall_curve( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, + ) + + +class TestBinaryPrecisionRecallCurve(MetricTester): + """Test binary precision-recall curve function and class.""" + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) + @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_precision_recall_curve_function_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test function for binary precision-recall curve using array_api arrays.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=binary_precision_recall_curve, + metric_args={ + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + 
_binary_precision_recall_curve_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) + @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_precision_recall_cuvrve_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test class for binary precision-recall curve using array_api arrays.""" + target, preds = inputs + + if ( + preds.shape[1] == 1 + and is_floating_point(preds) + and not anp.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryPrecisionRecallCurve, + metric_args={ + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_precision_recall_curve_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)[3:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds_for_prc(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_precision_recall_curve_with_torch_tensors( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test binary precision-recall curve class with torch tensors.""" + target, preds = inputs + + if ( + preds.shape[1] == 1 + and is_floating_point(preds) + and not torch.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + if target.shape[1] == 1 and torch.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryPrecisionRecallCurve, + metric_args={ + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_precision_recall_curve_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + device=device, + use_device_for_ref=True, + ) + + +def _multiclass_precision_recall_curve_reference( + target, + preds, + num_classes=NUM_CLASSES, + thresholds=None, + ignore_index=None, +) -> Union[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + 
Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]], +]: + """Return the reference multiclass precision-recall curve.""" + if preds.ndim == 1 and is_floating_point(preds): + xp = apc.array_namespace(preds) + preds = xp.argmax(preds, axis=0) + + return tm_multiclass_precision_recall_curve( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_classes, + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, + ) + + +class TestMulticlassPrecisionRecallCurve(MetricTester): + """Test multiclass precision-recall curve function and class.""" + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) + @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("average", [None, "none"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multiclass_precision_recall_curve_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test function for multiclass precision-recall curve.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multiclass_precision_recall_curve, + metric_args={ + "num_classes": NUM_CLASSES, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _multiclass_precision_recall_curve_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) + @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("average", [None, "none"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_precision_recall_curve_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multiclass precision-recall curve.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassPrecisionRecallCurve, + reference_metric=partial( + _multiclass_precision_recall_curve_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "average": average, + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)[4:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds_for_prc(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("average", [None, "none"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def 
test_multiclass_precision_recall_curve_class_with_torch_tensors( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multiclass precision-recall curve.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and torch.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassPrecisionRecallCurve, + reference_metric=partial( + _multiclass_precision_recall_curve_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + device=device, + use_device_for_ref=True, + ) + + +def _multilabel_precision_recall_curve_reference( + preds, + target, + num_labels=NUM_LABELS, + thresholds=None, + ignore_index=None, +) -> Union[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]], +]: + """Return the reference multilabel precision-recall curve.""" + return tm_multilabel_precision_recall_curve( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_labels, + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, + ) + + +class TestMultilabelPrecisionRecallCurve(MetricTester): + """Test multilabel precision-recall curve function and class.""" + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) + @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_precision_recall_curve_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test function for multilabel precision-recall curve.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multilabel_precision_recall_curve, + reference_metric=partial( + _multilabel_precision_recall_curve_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "thresholds": thresholds, + "num_labels": NUM_LABELS, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) + @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_precision_recall_curve_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test class for multilabel precision-recall curve.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelPrecisionRecallCurve, + reference_metric=partial( + _multilabel_precision_recall_curve_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + 
metric_args={ + "thresholds": thresholds, + "num_labels": NUM_LABELS, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=array_api_compat.torch)[2:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds_for_prc(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_precision_recall_curve_class_with_torch_tensors( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test class for multilabel precision-recall curve.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelPrecisionRecallCurve, + reference_metric=partial( + _multilabel_precision_recall_curve_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "thresholds": thresholds, + "num_labels": NUM_LABELS, + "ignore_index": ignore_index, + }, + device=device, + use_device_for_ref=True, + ) From 6a91acf9c7dca706c957f7206497158e23af096d Mon Sep 17 00:00:00 2001 From: Franklin <41602287+fcogidi@users.noreply.github.com> Date: Wed, 10 Jan 2024 20:04:44 -0500 Subject: [PATCH 3/6] Add ROC curve (#545) * Add ROC metrics to experimental module * fix typo * Fix docstrings --- .../evaluate/metrics/experimental/__init__.py | 5 + .../experimental/functional/__init__.py | 5 + .../metrics/experimental/functional/roc.py | 651 ++++++++++++++++++ cyclops/evaluate/metrics/experimental/roc.py | 214 ++++++ .../test_precision_recall_curve.py | 2 +- .../evaluate/metrics/experimental/test_roc.py | 492 +++++++++++++ 6 files changed, 1368 insertions(+), 1 deletion(-) create mode 100644 cyclops/evaluate/metrics/experimental/functional/roc.py create mode 100644 cyclops/evaluate/metrics/experimental/roc.py create mode 100644 tests/cyclops/evaluate/metrics/experimental/test_roc.py diff --git a/cyclops/evaluate/metrics/experimental/__init__.py b/cyclops/evaluate/metrics/experimental/__init__.py index 05fea662a..98afc9892 100644 --- a/cyclops/evaluate/metrics/experimental/__init__.py +++ b/cyclops/evaluate/metrics/experimental/__init__.py @@ -45,6 +45,11 @@ MulticlassPrecisionRecallCurve, MultilabelPrecisionRecallCurve, ) +from cyclops.evaluate.metrics.experimental.roc import ( + BinaryROC, + MulticlassROC, + MultilabelROC, +) from cyclops.evaluate.metrics.experimental.specificity import ( BinarySpecificity, BinaryTNR, diff --git a/cyclops/evaluate/metrics/experimental/functional/__init__.py b/cyclops/evaluate/metrics/experimental/functional/__init__.py index 23b91cc37..429b9bd78 100644 --- a/cyclops/evaluate/metrics/experimental/functional/__init__.py +++ b/cyclops/evaluate/metrics/experimental/functional/__init__.py @@ -41,6 +41,11 @@ multiclass_precision_recall_curve, multilabel_precision_recall_curve, ) +from cyclops.evaluate.metrics.experimental.functional.roc import ( + binary_roc, + multiclass_roc, + multilabel_roc, +) from cyclops.evaluate.metrics.experimental.functional.specificity import ( binary_specificity, binary_tnr, diff --git a/cyclops/evaluate/metrics/experimental/functional/roc.py b/cyclops/evaluate/metrics/experimental/functional/roc.py new file mode 100644 index 000000000..2371a5a83 --- /dev/null +++ 
b/cyclops/evaluate/metrics/experimental/functional/roc.py @@ -0,0 +1,651 @@ +"""Functions for computing Receiver Operating Characteristic (ROC) curves.""" +import warnings +from typing import List, Literal, Optional, Tuple, Union + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( + _binary_clf_curve, + _binary_precision_recall_curve_format_arrays, + _binary_precision_recall_curve_update, + _binary_precision_recall_curve_validate_args, + _binary_precision_recall_curve_validate_arrays, + _multiclass_precision_recall_curve_format_arrays, + _multiclass_precision_recall_curve_update, + _multiclass_precision_recall_curve_validate_args, + _multiclass_precision_recall_curve_validate_arrays, + _multilabel_precision_recall_curve_format_arrays, + _multilabel_precision_recall_curve_update, + _multilabel_precision_recall_curve_validate_args, + _multilabel_precision_recall_curve_validate_arrays, +) +from cyclops.evaluate.metrics.experimental.utils.ops import ( + _interp, + flatten, + remove_ignore_index, + safe_divide, +) +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +def _binary_roc_compute( + state: Union[Array, Tuple[Array, Array]], + thresholds: Optional[Array], + pos_label: int = 1, +) -> Tuple[Array, Array, Array]: + """Compute the binary ROC curve.""" + if apc.is_array_api_obj(state) and thresholds is not None: + xp = apc.array_namespace(state, thresholds) + tps = state[:, 1, 1] # type: ignore[call-overload] + fps = state[:, 0, 1] # type: ignore[call-overload] + fns = state[:, 1, 0] # type: ignore[call-overload] + tns = state[:, 0, 0] # type: ignore[call-overload] + tpr = xp.flip(safe_divide(tps, tps + fns), axis=0) + fpr = xp.flip(safe_divide(fps, fps + tns), axis=0) + thresh = xp.flip(thresholds, axis=0) + else: + xp = apc.array_namespace(state[0], state[1]) + fps, tps, thresh = _binary_clf_curve( + state[0], + state[1], + pos_label=pos_label, + ) + + # add extra threshold position so that the curve starts at (0, 0) + tps = xp.concat([xp.zeros(1, dtype=tps.dtype, device=apc.device(tps)), tps]) + fps = xp.concat([xp.zeros(1, dtype=fps.dtype, device=apc.device(fps)), fps]) + thresh = xp.concat( + [ + xp.ones(1, dtype=thresh.dtype, device=apc.device(thresh)), + thresh, + ], + ) + + if fps[-1] <= 0: + warnings.warn( + "No negative samples in targets false positive value should be " + "meaningless. Returning an array of 0s instead.", + UserWarning, + stacklevel=1, + ) + fpr = xp.zeros_like(thresh) + else: + fpr = fps / fps[-1] + + if tps[-1] <= 0: + warnings.warn( + "No positive samples in targets true positive value should be " + "meaningless. Returning an array of 0s instead.", + UserWarning, + stacklevel=1, + ) + tpr = xp.zeros_like(fpr) + else: + tpr = tps / tps[-1] + + return fpr, tpr, thresh + + +def binary_roc( + target: Array, + preds: Array, + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[int] = None, +) -> Tuple[Array, Array, Array]: + """Compute the receiver operating characteristic (ROC) curve for binary tasks. + + Parameters + ---------- + target : Array + An array object that is compatible with the Python array API standard + and contains the ground truth labels in the range [0, 1]. The expected + shape of the array is `(N, ...)`, where `N` is the number of samples. + preds : Array + An array object that is compatible with the Python array API standard and + contains the probability/logit scores for the positive class. 
The expected + shape of the array is `(N, ...)` where `N` is the number of samples. If + `preds` contains floating point values that are not in the range `[0, 1]`, + a sigmoid function will be applied to each value before thresholding. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the ROC curve. Can be one of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the ROC curve. + If `None`, all values in `target` are used. + + Returns + ------- + fpr : Array + The false positive rates for all unique thresholds. The shape of the array is + `(num_thresholds + 1,)`. + tpr : Array + The true positive rates for all unique thresholds. The shape of the array is + `(num_thresholds + 1,)`. + thresholds : Array + The thresholds used for computing the ROC curve values, in descending order. + The shape of the array is `(num_thresholds,)`. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `ignore_index` is not `None` or an integer. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `target` and `preds` do not have the same shape. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + RuntimeError + If `target` contains values outside the range [0, 1] or does not contain + `ignore_index` if `ignore_index` is not `None`. + ValueError + If the array API namespace of `target` and `preds` are not the same as the + array API namespace of `thresholds`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import binary_roc + >>> import numpy.array_api as anp + >>> target = anp.asarray([0, 1, 0, 1, 0, 1]) + >>> preds = anp.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> fpr, tpr, thresholds = binary_roc(target, preds, thresholds=None) + >>> fpr + Array([0. , 0. , 0.33333334, 0.33333334, + 0.6666667 , 0.6666667 , 1. ], dtype=float32) + >>> tpr + Array([0. , 0.33333334, 0.33333334, 0.6666667 , + 0.6666667 , 1. , 1. ], dtype=float32) + >>> thresholds + Array([1. , 0.92, 0.84, 0.73, 0.33, 0.22, 0.11], dtype=float64) + >>> fpr, tpr, thresholds = binary_roc( + ... target, preds, thresholds=5, + ... ) + >>> fpr + Array([0. , 0.33333334, 0.33333334, 0.6666667 , + 1. ], dtype=float32) + >>> tpr + Array([0. , 0.33333334, 0.6666667 , 0.6666667 , + 1. ], dtype=float32) + >>> thresholds + Array([1. , 0.75, 0.5 , 0.25, 0. 
], dtype=float32) + + """ + _binary_precision_recall_curve_validate_args(thresholds, ignore_index) + xp = _binary_precision_recall_curve_validate_arrays( + target, + preds, + thresholds, + ignore_index, + ) + target, preds, thresholds = _binary_precision_recall_curve_format_arrays( + target, + preds, + thresholds, + ignore_index, + xp=xp, + ) + state = _binary_precision_recall_curve_update(target, preds, thresholds, xp=xp) + return _binary_roc_compute(state, thresholds) + + +def _multiclass_roc_compute( + state: Union[Array, Tuple[Array, Array]], + num_classes: int, + thresholds: Optional[Array], + average: Optional[Literal["macro", "micro", "none"]], +) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: + """Compute the multiclass ROC curve.""" + if average == "micro": + return _binary_roc_compute(state, thresholds) + + if apc.is_array_api_obj(state) and thresholds is not None: + xp = apc.array_namespace(state, thresholds) + tps = state[:, :, 1, 1] # type: ignore[call-overload] + fps = state[:, :, 0, 1] # type: ignore[call-overload] + fns = state[:, :, 1, 0] # type: ignore[call-overload] + tns = state[:, :, 0, 0] # type: ignore[call-overload] + tpr = xp.flip(safe_divide(tps, tps + fns), axis=0).T + fpr = xp.flip(safe_divide(fps, fps + tns), axis=0).T + thresh = xp.flip(thresholds, axis=0) + array_state = True + else: + xp = apc.array_namespace(state[0], state[1]) + fpr_list, tpr_list, thresh_list = [], [], [] + for i in range(num_classes): + res = _binary_roc_compute( + (state[0], state[1][:, i]), + thresholds=None, + pos_label=i, + ) + fpr_list.append(res[0]) + tpr_list.append(res[1]) + thresh_list.append(res[2]) + array_state = False + + if average == "macro": + thresh = ( + xp.concat([xp.expand_dims(thresh, axis=0)] * num_classes, axis=0) # repeat + if array_state + else xp.concat(xp.asarray(thresh_list), 0) + ) + thresh = xp.sort(thresh, descending=True) + mean_fpr = flatten(fpr) if array_state else xp.concat(xp.asarray(fpr_list), 0) + mean_fpr = xp.sort(mean_fpr) + mean_tpr = xp.zeros_like(mean_fpr) + for i in range(num_classes): + mean_tpr += _interp( + mean_fpr, + fpr[i] if array_state else fpr_list[i], + tpr[i] if array_state else tpr_list[i], + ) + mean_tpr /= num_classes + return mean_fpr, mean_tpr, thresh + + if array_state: + return fpr, tpr, thresh + return fpr_list, tpr_list, thresh_list + + +def multiclass_roc( + target: Array, + preds: Array, + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "micro", "none"]] = None, + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: + """Compute the receiver operating characteristic (ROC) curve for multiclass tasks. + + Parameters + ---------- + target : Array + An array object that is compatible with the Python array API standard + and contains the ground truth labels in the range [0, `num_classes`] + (except if `ignore_index` is specified). The expected shape of the array + is `(N, ...)`, where `N` is the number of samples. + preds : Array + An array object that is compatible with the Python array API standard and + contains the probability/logit scores for each sample. The expected shape + of the array is `(N, C, ...)` where `N` is the number of samples and `C` + is the number of classes. If `preds` contains floating point values that + are not in the range `[0, 1]`, a softmax function will be applied to each + value before thresholding. 
+ num_classes : int + The number of classes in the classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the ROC curve. Can be one of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"macro", "micro", "none"}, optional, default=None + The type of averaging to use for computing the ROC curve. Can be one of + the following: + - `"macro"`: interpolates the curves from each class at a combined set of + thresholds and then average over the classwise interpolated curves. + - `"micro"`: one-hot encodes the targets and flattens the predictions, + considering all classes jointly as a binary problem. + - `"none"`: do not average over the classwise curves. + ignore_index : int or Tuple[int], optional, default=None + The value(s) in `target` that should be ignored when computing the ROC curve. + If `None`, all values in `target` are used. + + Returns + ------- + fpr : Array or List[Array] + The false positive rates for all unique thresholds. If `thresholds` is `"none"` + or `None`, a list for each class is returned with 1-D Arrays of shape + `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape + `(num_thresholds + 1, num_classes)` is returned. + tpr : Array or List[Array] + The true positive rates for all unique thresholds. If `thresholds` is `"none"` + or `None`, a list for each class is returned with 1-D Arrays of shape + `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape + `(num_thresholds + 1, num_classes)` is returned. + thresholds : Array or List[Array] + The thresholds used for computing the ROC curve values, in descending order. + If `thresholds` is `"none"` or `None`, a list for each class is returned + with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D Array of + shape `(num_thresholds,)` is returned. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `num_classes` is not an integer larger than 1. + ValueError + If `ignore_index` is not `None`, an integer or a tuple of integers. + ValueError + If `average` is not `"macro"`, `"micro"`, `"none"` or `None`. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `preds` does not have one more dimension than `target`. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + ValueError + If the second dimension of `preds` is not equal to `num_classes`. + ValueError + If the first dimension of `preds` is not equal to the first dimension of + `target` or the third dimension of `preds` is not equal to the second + dimension of `target`. 
+ RuntimeError + If `target` contains more unique values than `num_classes` or `num_classes` + plus the number of values in `ignore_index` if `ignore_index` is not `None`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import multiclass_roc + >>> import numpy.array_api as anp + >>> target = anp.asarray([0, 1, 2, 0, 1, 2]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44], + ... [0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44]]) + >>> fpr, tpr, thresholds = multiclass_roc( + ... target, preds, num_classes=3, thresholds=None, + ... ) + >>> fpr + [Array([0. , 0.5, 1. , 1. ], dtype=float32), + Array([0. , 0.5, 0.5, 1. ], dtype=float32), + Array([0. , 0.5, 0.5, 1. ], dtype=float32)] + >>> tpr + [Array([0., 0., 0., 1.], dtype=float32), + Array([0., 0., 1., 1.], dtype=float32), + Array([0., 0., 1., 1.], dtype=float32)] + >>> thresholds + [Array([1. , 0.84, 0.33, 0.11], dtype=float64), + Array([1. , 0.92, 0.73, 0.22], dtype=float64), + Array([1. , 0.67, 0.44, 0.12], dtype=float64)] + >>> fpr, tpr, thresholds = multiclass_roc( + ... target, preds, num_classes=3, thresholds=5, + ... ) + >>> fpr + Array([[0. , 0.5, 0.5, 1. , 1. ], + [0. , 0.5, 0.5, 0.5, 1. ], + [0. , 0. , 0.5, 0.5, 1. ]], dtype=float32) + >>> tpr + Array([[0., 0., 0., 0., 1.], + [0., 0., 1., 1., 1.], + [0., 0., 0., 1., 1.]], dtype=float32) + >>> thresholds + Array([1. , 0.75, 0.5 , 0.25, 0. ], dtype=float32) + + """ # noqa: W505 + _multiclass_precision_recall_curve_validate_args( + num_classes, + thresholds, + ignore_index, + average, + ) + xp = _multiclass_precision_recall_curve_validate_arrays( + target, + preds, + num_classes, + ignore_index, + ) + target, preds, thresholds = _multiclass_precision_recall_curve_format_arrays( + target, + preds, + num_classes, + thresholds, + ignore_index, + average, + xp=xp, + ) + state = _multiclass_precision_recall_curve_update( + target, + preds, + num_classes, + thresholds, + average, + xp=xp, + ) + return _multiclass_roc_compute( + state, + num_classes, + thresholds=thresholds, + average=average, + ) + + +def _multilabel_roc_compute( + state: Union[Array, Tuple[Array, Array]], + num_labels: int, + thresholds: Optional[Array], + ignore_index: Optional[int], +) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: + """Compute the multilabel ROC curve.""" + if apc.is_array_api_obj(state) and thresholds is not None: + xp = apc.array_namespace(state) + tps = state[:, :, 1, 1] # type: ignore[call-overload] + fps = state[:, :, 0, 1] # type: ignore[call-overload] + fns = state[:, :, 1, 0] # type: ignore[call-overload] + tns = state[:, :, 0, 0] # type: ignore[call-overload] + tpr = xp.flip(safe_divide(tps, tps + fns), axis=0).T + fpr = xp.flip(safe_divide(fps, fps + tns), axis=0).T + thresh = xp.flip(thresholds, axis=0) + return fpr, tpr, thresh + + fpr_list, tpr_list, thresh_list = [], [], [] + for i in range(num_labels): + target = state[0][:, i] + preds = state[1][:, i] + if ignore_index is not None: + target, preds = remove_ignore_index( + target, + preds, + ignore_index=ignore_index, + ) + res = _binary_roc_compute((target, preds), thresholds=None, pos_label=1) + fpr_list.append(res[0]) + tpr_list.append(res[1]) + thresh_list.append(res[2]) + return fpr_list, tpr_list, thresh_list + + +def multilabel_roc( + target: Array, + preds: Array, + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[int] = None, 
+) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: + """Compute the receiver operating characteristic (ROC) curve for multilabel tasks. + + Parameters + ---------- + target : Array + The target array of shape `(N, L, ...)` containing the ground truth labels + in the range [0, 1], where `N` is the number of samples and `L` is the + number of labels. + preds : Array + The prediction array of shape `(N, L, ...)` containing the probability/logit + scores for each sample, where `N` is the number of samples and `L` is the + number of labels. If `preds` contains floating point values that are not + in the range [0,1], they will be converted to probabilities using the + sigmoid function. + num_labels : int + The number of labels in the multilabel classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the ROC curve. Can be one of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the ROC curve. + If `None`, all values in `target` are used. + + Returns + ------- + fpr : Array or List[Array] + The false positive rates for all unique thresholds. If `thresholds` is `None`, + a list for each label is returned with 1-D Arrays of shape + `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape + `(num_thresholds + 1, num_labels)` is returned. + tpr : Array or List[Array] + The true positive rates for all unique thresholds. If `thresholds` is `None`, + a list for each label is returned with 1-D Arrays of shape + `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape + `(num_thresholds + 1, num_labels)` is returned. + thresholds : Array or List[Array] + The thresholds used for computing the ROC curve values, in + descending order. If `thresholds` is `None`, a list for each label is + returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D + Array of shape `(num_thresholds,)` is returned. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `ignore_index` is not `None` or an integer. + ValueError + If `num_labels` is not an integer larger than 1. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `target` and `preds` do not have the same shape. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + RuntimeError + If `target` contains values outside the range [0, 1] or does not contain + `ignore_index` if `ignore_index` is not `None`. 
+ ValueError + If the array API namespace of `target` and `preds` are not the same as the + array API namespace of `thresholds`. + ValueError + If the second dimension of `preds` is not equal to `num_labels`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import multilabel_roc + >>> import numpy.array_api as anp + >>> target = anp.asarray([[0, 1, 0], [1, 1, 0], [0, 0, 1]]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], [0.84, 0.73, 0.12], [0.33, 0.92, 0.44]], + ... ) + >>> fpr, tpr, thresholds = multilabel_roc( + ... target, preds, num_labels=3, thresholds=None, + ... ) + >>> fpr + [Array([0. , 0. , 0.5, 1. ], dtype=float32), + Array([0., 1., 1., 1.], dtype=float32), + Array([0. , 0.5, 0.5, 1. ], dtype=float32)] + >>> tpr + [Array([0., 1., 1., 1.], dtype=float32), + Array([0. , 0. , 0.5, 1. ], dtype=float32), + Array([0., 0., 1., 1.], dtype=float32)] + >>> thresholds + [Array([1. , 0.84, 0.33, 0.11], dtype=float64), + Array([1. , 0.92, 0.73, 0.22], dtype=float64), + Array([1. , 0.67, 0.44, 0.12], dtype=float64)] + >>> fpr, tpr, thresholds = multilabel_roc( + ... target, preds, num_labels=3, thresholds=5, + ... ) + >>> fpr + Array([[0. , 0. , 0. , 0.5, 1. ], + [0. , 1. , 1. , 1. , 1. ], + [0. , 0. , 0.5, 0.5, 1. ]], dtype=float32) + >>> tpr + Array([[0. , 1. , 1. , 1. , 1. ], + [0. , 0. , 0.5, 0.5, 1. ], + [0. , 0. , 0. , 1. , 1. ]], dtype=float32) + >>> thresholds + Array([1. , 0.75, 0.5 , 0.25, 0. ], dtype=float32) + + """ # noqa: W505 + _multilabel_precision_recall_curve_validate_args( + num_labels, + thresholds, + ignore_index, + ) + xp = _multilabel_precision_recall_curve_validate_arrays( + target, + preds, + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + target, preds, thresholds = _multilabel_precision_recall_curve_format_arrays( + target, + preds, + num_labels, + thresholds, + ignore_index, + xp=xp, + ) + state = _multilabel_precision_recall_curve_update( + target, + preds, + num_labels, + thresholds, + xp=xp, + ) + return _multilabel_roc_compute( + state, + num_labels, + thresholds, + ignore_index, + ) diff --git a/cyclops/evaluate/metrics/experimental/roc.py b/cyclops/evaluate/metrics/experimental/roc.py new file mode 100644 index 000000000..942cc4e89 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/roc.py @@ -0,0 +1,214 @@ +"""Classes for computing the Receiver Operating Characteristic (ROC) curve.""" +from typing import List, Tuple, Union + +from cyclops.evaluate.metrics.experimental.functional.roc import ( + _binary_roc_compute, + _multiclass_roc_compute, + _multilabel_roc_compute, +) +from cyclops.evaluate.metrics.experimental.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from cyclops.evaluate.metrics.experimental.utils.ops import dim_zero_cat +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +class BinaryROC(BinaryPrecisionRecallCurve): + """The receiver operating characteristic (ROC) curve. + + Parameters + ---------- + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the ROC curve. Can be one of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. 
+ ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the ROC curve. + If `None`, all values in `target` are used. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import BinaryROC + >>> target = anp.asarray([0, 1, 0, 1, 0, 1]) + >>> preds = anp.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> metric = BinaryROC(thresholds=None) + >>> metric(target, preds) + (Array([0. , 0. , 0.33333334, 0.33333334, + 0.6666667 , 0.6666667 , 1. ], dtype=float32), Array([0. , 0.33333334, 0.33333334, 0.6666667 , + 0.6666667 , 1. , 1. ], dtype=float32), Array([1. , 0.92, 0.84, 0.73, 0.33, 0.22, 0.11], dtype=float64)) + >>> metric = BinaryROC(thresholds=5) + >>> metric(target, preds) + (Array([0. , 0.33333334, 0.33333334, 0.6666667 , + 1. ], dtype=float32), Array([0. , 0.33333334, 0.6666667 , 0.6666667 , + 1. ], dtype=float32), Array([1. , 0.75, 0.5 , 0.25, 0. ], dtype=float32)) + + """ # noqa: W505 + + name: str = "ROC Curve" + + def _compute_metric(self) -> Tuple[Array, Array, Array]: + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + return _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type] + + +class MulticlassROC(MulticlassPrecisionRecallCurve): + """The reciever operator characteristics (ROC) curve. + + Parameters + ---------- + num_classes : int + The number of classes in the classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the ROC curve. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"macro", "micro", "none"}, optional, default=None + The type of averaging to use for computing the ROC curve. Can be one of + the following: + - `"macro"`: interpolates the curves from each class at a combined set of + thresholds and then average over the classwise interpolated curves. + - `"micro"`: one-hot encodes the targets and flattens the predictions, + considering all classes jointly as a binary problem. + - `"none"`: do not average over the classwise curves. + ignore_index : int or Tuple[int], optional, default=None + The value(s) in `target` that should be ignored when computing the ROC curve. + If `None`, all values in `target` are used. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import MulticlassROC + >>> target = anp.asarray([0, 1, 2, 0, 1, 2]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44], + ... [0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44]] + ... ) + >>> metric = MulticlassROC(num_classes=3, thresholds=None) + >>> metric(target, preds) + ([Array([0. , 0.5, 1. , 1. ], dtype=float32), + Array([0. , 0.5, 0.5, 1. ], dtype=float32), + Array([0. , 0.5, 0.5, 1. ], dtype=float32)], + [Array([0., 0., 0., 1.], dtype=float32), + Array([0., 0., 1., 1.], dtype=float32), + Array([0., 0., 1., 1.], dtype=float32)], + [Array([1. , 0.84, 0.33, 0.11], dtype=float64), + Array([1. , 0.92, 0.73, 0.22], dtype=float64), + Array([1. 
, 0.67, 0.44, 0.12], dtype=float64)]) + >>> metric = MulticlassROC(num_classes=3, thresholds=5) + >>> metric(target, preds) + (Array([[0. , 0.5, 0.5, 1. , 1. ], + [0. , 0.5, 0.5, 0.5, 1. ], + [0. , 0. , 0.5, 0.5, 1. ]], dtype=float32), Array([[0., 0., 0., 0., 1.], + [0., 0., 1., 1., 1.], + [0., 0., 0., 1., 1.]], dtype=float32), Array([1. , 0.75, 0.5 , 0.25, 0. ], dtype=float32)) + + """ # noqa: W505 + + name: str = "ROC Curve" + + def _compute_metric( + self, + ) -> Union[ + Tuple[Array, Array, Array], + Tuple[List[Array], List[Array], List[Array]], + ]: + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + return _multiclass_roc_compute( + state, + self.num_classes, + self.thresholds, # type: ignore[arg-type] + self.average, + ) + + +class MultilabelROC(MultilabelPrecisionRecallCurve): + """The reciever operator characteristics (ROC) curve. + + Parameters + ---------- + num_labels : int + The number of labels in the multilabel classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the ROC curve. Can be one of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the ROC Curve. + If `None`, all values in `target` are used. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import MultilabelROC + >>> target = anp.asarray([[0, 1, 0], [1, 1, 0], [0, 0, 1]]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], [0.84, 0.73, 0.12], [0.33, 0.92, 0.44]], + ... ) + >>> metric = MultilabelROC(num_labels=3, thresholds=None) + >>> metric(target, preds) + ([Array([0. , 0. , 0.5, 1. ], dtype=float32), + Array([0., 1., 1., 1.], dtype=float32), + Array([0. , 0.5, 0.5, 1. ], dtype=float32)], + [Array([0., 1., 1., 1.], dtype=float32), + Array([0. , 0. , 0.5, 1. ], dtype=float32), + Array([0., 0., 1., 1.], dtype=float32)], + [Array([1. , 0.84, 0.33, 0.11], dtype=float64), + Array([1. , 0.92, 0.73, 0.22], dtype=float64), + Array([1. , 0.67, 0.44, 0.12], dtype=float64)]) + >>> metric = MultilabelROC(num_labels=3, thresholds=5) + >>> metric(target, preds) + (Array([[0. , 0. , 0. , 0.5, 1. ], + [0. , 1. , 1. , 1. , 1. ], + [0. , 0. , 0.5, 0.5, 1. ]], dtype=float32), Array([[0. , 1. , 1. , 1. , 1. ], + [0. , 0. , 0.5, 0.5, 1. ], + [0. , 0. , 0. , 1. , 1. ]], dtype=float32), Array([1. , 0.75, 0.5 , 0.25, 0. 
], dtype=float32)) + + """ # noqa: W505 + + name: str = "ROC Curve" + + def _compute_metric( + self, + ) -> Union[ + Tuple[Array, Array, Array], + Tuple[List[Array], List[Array], List[Array]], + ]: + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + return _multilabel_roc_compute( + state, + self.num_labels, + self.thresholds, # type: ignore[arg-type] + self.ignore_index, + ) diff --git a/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py index b19e36822..f0608d059 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py @@ -101,7 +101,7 @@ def test_binary_precision_recall_curve_function_with_numpy_array_api_arrays( @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) - def test_binary_precision_recall_cuvrve_class_with_numpy_array_api_arrays( + def test_binary_precision_recall_curve_class_with_numpy_array_api_arrays( self, inputs, thresholds, diff --git a/tests/cyclops/evaluate/metrics/experimental/test_roc.py b/tests/cyclops/evaluate/metrics/experimental/test_roc.py new file mode 100644 index 000000000..c1a977268 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/test_roc.py @@ -0,0 +1,492 @@ +"""Test roc curve metric.""" +from functools import partial +from types import ModuleType +from typing import List, Tuple, Union + +import array_api_compat as apc +import array_api_compat.torch +import numpy.array_api as anp +import pytest +import torch.utils.dlpack +from torchmetrics.functional.classification import ( + binary_roc as tm_binary_roc, +) +from torchmetrics.functional.classification import ( + multiclass_roc as tm_multiclass_roc, +) +from torchmetrics.functional.classification import ( + multilabel_roc as tm_multilabel_roc, +) + +from cyclops.evaluate.metrics.experimental.functional.roc import ( + binary_roc, + multiclass_roc, + multilabel_roc, +) +from cyclops.evaluate.metrics.experimental.roc import ( + BinaryROC, + MulticlassROC, + MultilabelROC, +) +from cyclops.evaluate.metrics.experimental.utils.ops import to_int +from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point + +from ..conftest import NUM_CLASSES, NUM_LABELS +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from .testers import MetricTester, _inject_ignore_index + + +def _thresholds_for_roc(*, xp: ModuleType) -> list: + """Return thresholds for roc curve.""" + thresh_list = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0] + return [None, 5, thresh_list, xp.asarray(thresh_list)] + + +def _binary_roc_reference( + target, + preds, + thresholds, + ignore_index, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Return the reference binary roc curve.""" + return tm_binary_roc( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, + ) + + +class TestBinaryROC(MetricTester): + """Test binary roc curve function and class.""" + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) + @pytest.mark.parametrize("thresholds", 
_thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_roc_function_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test function for binary roc curve using array_api arrays.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=binary_roc, + metric_args={ + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_roc_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) + @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_roc_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test class for binary roc curve using array_api arrays.""" + target, preds = inputs + + if ( + preds.shape[1] == 1 + and is_floating_point(preds) + and not anp.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryROC, + metric_args={ + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_roc_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)[3:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds_for_roc(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_roc_with_torch_tensors( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test binary roc curve class with torch tensors.""" + target, preds = inputs + + if ( + preds.shape[1] == 1 + and is_floating_point(preds) + and not torch.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + if target.shape[1] == 1 and torch.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = 
_inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryROC, + metric_args={ + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_roc_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + device=device, + use_device_for_ref=True, + ) + + +def _multiclass_roc_reference( + target, + preds, + num_classes=NUM_CLASSES, + thresholds=None, + ignore_index=None, +) -> Union[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]], +]: + """Return the reference multiclass roc curve.""" + if preds.ndim == 1 and is_floating_point(preds): + xp = apc.array_namespace(preds) + preds = xp.argmax(preds, axis=0) + + return tm_multiclass_roc( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_classes, + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, + ) + + +class TestMulticlassROC(MetricTester): + """Test multiclass roc curve function and class.""" + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) + @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("average", [None, "none"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multiclass_roc_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test function for multiclass roc curve.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multiclass_roc, + metric_args={ + "num_classes": NUM_CLASSES, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _multiclass_roc_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) + @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("average", [None, "none"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_roc_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multiclass roc curve.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassROC, + reference_metric=partial( + _multiclass_roc_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + 
"num_classes": NUM_CLASSES, + "average": average, + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)[4:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds_for_roc(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("average", [None, "none"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_roc_class_with_torch_tensors( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multiclass roc curve.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and torch.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassROC, + reference_metric=partial( + _multiclass_roc_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + device=device, + use_device_for_ref=True, + ) + + +def _multilabel_roc_reference( + preds, + target, + num_labels=NUM_LABELS, + thresholds=None, + ignore_index=None, +) -> Union[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]], +]: + """Return the reference multilabel roc curve.""" + return tm_multilabel_roc( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_labels, + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, + ) + + +class TestMultilabelROC(MetricTester): + """Test multilabel roc curve function and class.""" + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) + @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_roc_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test function for multilabel roc curve.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multilabel_roc, + reference_metric=partial( + _multilabel_roc_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "thresholds": thresholds, + "num_labels": NUM_LABELS, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) + @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_roc_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test class for multilabel roc curve.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + 
self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelROC, + reference_metric=partial( + _multilabel_roc_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "thresholds": thresholds, + "num_labels": NUM_LABELS, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=array_api_compat.torch)[2:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds_for_roc(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_roc_class_with_torch_tensors( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test class for multilabel roc curve.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelROC, + reference_metric=partial( + _multilabel_roc_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "thresholds": thresholds, + "num_labels": NUM_LABELS, + "ignore_index": ignore_index, + }, + device=device, + use_device_for_ref=True, + ) From df4739d69494ab24f1604070dbd9cdfe3e8604d3 Mon Sep 17 00:00:00 2001 From: Franklin <41602287+fcogidi@users.noreply.github.com> Date: Fri, 12 Jan 2024 17:45:23 -0500 Subject: [PATCH 4/6] Add AUROC (#546) * Add AUROC metric to experimental module * Refactor binary and multiclass ROC functions * Refactor tests to use a common thresholds list * Fix mypy error --- .../evaluate/metrics/experimental/__init__.py | 5 + .../evaluate/metrics/experimental/auroc.py | 258 +++++++ .../experimental/functional/__init__.py | 5 + .../metrics/experimental/functional/auroc.py | 638 ++++++++++++++++++ .../functional/precision_recall_curve.py | 20 +- .../metrics/experimental/functional/roc.py | 16 +- .../metrics/experimental/utils/ops.py | 307 ++++++++- .../evaluate/metrics/experimental/inputs.py | 8 + .../metrics/experimental/test_auroc.py | 517 ++++++++++++++ .../test_precision_recall_curve.py | 27 +- .../evaluate/metrics/experimental/test_roc.py | 27 +- 11 files changed, 1770 insertions(+), 58 deletions(-) create mode 100644 cyclops/evaluate/metrics/experimental/auroc.py create mode 100644 cyclops/evaluate/metrics/experimental/functional/auroc.py create mode 100644 tests/cyclops/evaluate/metrics/experimental/test_auroc.py diff --git a/cyclops/evaluate/metrics/experimental/__init__.py b/cyclops/evaluate/metrics/experimental/__init__.py index 98afc9892..e052ee224 100644 --- a/cyclops/evaluate/metrics/experimental/__init__.py +++ b/cyclops/evaluate/metrics/experimental/__init__.py @@ -4,6 +4,11 @@ MulticlassAccuracy, MultilabelAccuracy, ) +from cyclops.evaluate.metrics.experimental.auroc import ( + BinaryAUROC, + MulticlassAUROC, + MultilabelAUROC, +) from cyclops.evaluate.metrics.experimental.confusion_matrix import ( BinaryConfusionMatrix, MulticlassConfusionMatrix, diff --git a/cyclops/evaluate/metrics/experimental/auroc.py b/cyclops/evaluate/metrics/experimental/auroc.py new file mode 100644 index 000000000..17c6af31f --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/auroc.py @@ -0,0 +1,258 @@ +"""Classes for computing the area under the ROC curve.""" +from typing import List, Literal, Optional, Tuple, 
Union + +from cyclops.evaluate.metrics.experimental.functional.auroc import ( + _binary_auroc_compute, + _binary_auroc_validate_args, + _multiclass_auroc_compute, + _multiclass_auroc_validate_args, + _multilabel_auroc_compute, + _multilabel_auroc_validate_args, +) +from cyclops.evaluate.metrics.experimental.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from cyclops.evaluate.metrics.experimental.utils.ops import dim_zero_cat +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +class BinaryAUROC(BinaryPrecisionRecallCurve): + """Area under the Receiver Operating Characteristic (ROC) curve. + + Parameters + ---------- + max_fpr : float, optional, default=None + If not `None`, computes the maximum area under the curve up to the given + false positive rate value. Must be a float in the range (0, 1]. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the ROC curve. Can be one of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the AUROC. + If `None`, all values in `target` are used. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import BinaryAUROC + >>> target = anp.asarray([0, 1, 1, 0, 1, 0, 0, 1]) + >>> preds = anp.asarray([0.1, 0.4, 0.35, 0.8, 0.2, 0.6, 0.7, 0.3]) + >>> auroc = BinaryAUROC(thresholds=None) + >>> auroc(target, preds) + Array(0.25, dtype=float32) + >>> auroc = BinaryAUROC(thresholds=5) + >>> auroc(target, preds) + Array(0.21875, dtype=float32) + """ + + name: str = "AUC ROC Curve" + + def __init__( + self, + max_fpr: Optional[float] = None, + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[int] = None, + ) -> None: + """Initialize the BinaryAUROC metric.""" + super().__init__(thresholds=thresholds, ignore_index=ignore_index) + _binary_auroc_validate_args( + max_fpr=max_fpr, + thresholds=thresholds, + ignore_index=ignore_index, + ) + self.max_fpr = max_fpr + + def _compute_metric(self) -> Array: # type: ignore[override] + """Compute the AUROC.""" "" + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + return _binary_auroc_compute(state, thresholds=self.thresholds, max_fpr=self.max_fpr) # type: ignore + + +class MulticlassAUROC(MulticlassPrecisionRecallCurve): + """Area under the Receiver Operating Characteristic (ROC) curve. + + Parameters + ---------- + num_classes : int + The number of classes in the classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the ROC curve. Can be one of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. 
+ average : {"macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the AUROC. Can be one of + the following: + - `"macro"`: interpolates the curves from each class at a combined set of + thresholds and then average over the classwise interpolated curves. + - `"weighted"`: average over the classwise curves weighted by the support + (the number of true instances for each class). + - `"none"`: do not average over the classwise curves. + ignore_index : int or Tuple[int], optional, default=None + The value(s) in `target` that should be ignored when computing the AUROC. + If `None`, all values in `target` are used. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import MulticlassAUROC + >>> target = anp.asarray([0, 1, 2, 0, 1, 2]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44], + ... [0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44]]) + >>> auroc = MulticlassAUROC(num_classes=3, average="macro", thresholds=None) + >>> auroc(target, preds) + Array(0.33333334, dtype=float32) + >>> auroc = MulticlassAUROC(num_classes=3, average=None, thresholds=None) + >>> auroc(target, preds) + Array([0. , 0.5, 0.5], dtype=float32) + >>> auroc = MulticlassAUROC(num_classes=3, average="macro", thresholds=5) + >>> auroc(target, preds) + Array(0.33333334, dtype=float32) + >>> auroc = MulticlassAUROC(num_classes=3, average=None, thresholds=5) + >>> auroc(target, preds) + Array([0. , 0.5, 0.5], dtype=float32) + """ + + name: str = "AUC ROC Curve" + + def __init__( + self, + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + ignore_index: Optional[Union[int, Tuple[int]]] = None, + ) -> None: + """Initialize the MulticlassAUROC metric.""" + super().__init__( + num_classes, + thresholds=thresholds, + ignore_index=ignore_index, + ) + _multiclass_auroc_validate_args( + num_classes=num_classes, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + self.average = average # type: ignore[assignment] + + def _compute_metric(self) -> Array: # type: ignore[override] + """Compute the AUROC.""" + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + return _multiclass_auroc_compute( + state, + self.num_classes, + thresholds=self.thresholds, # type: ignore[arg-type] + average=self.average, # type: ignore[arg-type] + ) + + +class MultilabelAUROC(MultilabelPrecisionRecallCurve): + """Area under the Receiver Operating Characteristic (ROC) curve. + + num_labels : int + The number of labels in the multilabel classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the ROC curve. Can be one of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"micro", "macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the AUROC. 
Can be one of + the following: + - `"micro"`: compute the AUROC globally by considering each element of the + label indicator matrix as a label. + - `"macro"`: compute the AUROC for each label and average them. + - `"weighted"`: compute the AUROC for each label and average them weighted + by the support (the number of true instances for each label). + - `"none"`: do not average over the labelwise AUROC. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the AUROC. + If `None`, all values in `target` are used. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import MultilabelAUROC + >>> target = anp.asarray([[0, 1, 0], [1, 1, 0], [0, 0, 1]]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], [0.84, 0.73, 0.12], [0.33, 0.92, 0.44]], + ... ) + >>> auroc = MultilabelAUROC(num_labels=3, average="macro", thresholds=None) + >>> auroc(target, preds) + Array(0.5, dtype=float32) + >>> auroc = MultilabelAUROC(num_labels=3, average=None, thresholds=None) + >>> auroc(target, preds) + Array([1. , 0. , 0.5], dtype=float32) + >>> auroc = MultilabelAUROC(num_labels=3, average="macro", thresholds=5) + >>> auroc(target, preds) + Array(0.5, dtype=float32) + >>> auroc = MultilabelAUROC(num_labels=3, average=None, thresholds=5) + >>> auroc(target, preds) + Array([1. , 0. , 0.5], dtype=float32) + + """ + + name: str = "AUC ROC Curve" + + def __init__( + self, + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, + ) -> None: + """Initialize the MultilabelAUROC metric.""" + super().__init__( + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + _multilabel_auroc_validate_args( + num_labels=num_labels, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + self.average = average + + def _compute_metric(self) -> Array: # type: ignore[override] + """Compute the AUROC.""" + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + return _multilabel_auroc_compute( + state, + self.num_labels, + thresholds=self.thresholds, # type: ignore[arg-type] + average=self.average, + ignore_index=self.ignore_index, + ) diff --git a/cyclops/evaluate/metrics/experimental/functional/__init__.py b/cyclops/evaluate/metrics/experimental/functional/__init__.py index 429b9bd78..14a887191 100644 --- a/cyclops/evaluate/metrics/experimental/functional/__init__.py +++ b/cyclops/evaluate/metrics/experimental/functional/__init__.py @@ -4,6 +4,11 @@ multiclass_accuracy, multilabel_accuracy, ) +from cyclops.evaluate.metrics.experimental.functional.auroc import ( + binary_auroc, + multiclass_auroc, + multilabel_auroc, +) from cyclops.evaluate.metrics.experimental.functional.confusion_matrix import ( binary_confusion_matrix, multiclass_confusion_matrix, diff --git a/cyclops/evaluate/metrics/experimental/functional/auroc.py b/cyclops/evaluate/metrics/experimental/functional/auroc.py new file mode 100644 index 000000000..c6e7c83c5 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/functional/auroc.py @@ -0,0 +1,638 @@ +"""Functions for computing the area under the ROC curve (AUROC).""" +import warnings +from typing import List, Literal, Optional, Tuple, Union + +import array_api_compat as apc + +from 
cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( + _binary_precision_recall_curve_format_arrays, + _binary_precision_recall_curve_update, + _binary_precision_recall_curve_validate_args, + _binary_precision_recall_curve_validate_arrays, + _multiclass_precision_recall_curve_format_arrays, + _multiclass_precision_recall_curve_update, + _multiclass_precision_recall_curve_validate_args, + _multiclass_precision_recall_curve_validate_arrays, + _multilabel_precision_recall_curve_format_arrays, + _multilabel_precision_recall_curve_update, + _multilabel_precision_recall_curve_validate_args, + _multilabel_precision_recall_curve_validate_arrays, +) +from cyclops.evaluate.metrics.experimental.functional.roc import ( + _binary_roc_compute, + _multiclass_roc_compute, + _multilabel_roc_compute, +) +from cyclops.evaluate.metrics.experimental.utils.ops import ( + _auc_compute, + _interp, + _searchsorted, + bincount, + flatten, + remove_ignore_index, + safe_divide, +) +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +def _binary_auroc_validate_args( + max_fpr: Optional[float] = None, + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[int] = None, +) -> None: + """Validate arguments for binary AUROC computation.""" + _binary_precision_recall_curve_validate_args(thresholds, ignore_index) + if max_fpr is not None and not isinstance(max_fpr, float) and 0 < max_fpr <= 1: + raise ValueError( + f"Argument `max_fpr` should be a float in range (0, 1], but got: {max_fpr}", + ) + + +def _binary_auroc_compute( + state: Union[Array, Tuple[Array, Array]], + thresholds: Optional[Array], + max_fpr: Optional[float] = None, + pos_label: int = 1, +) -> Array: + """Compute the area under the ROC curve for binary classification tasks.""" + fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label) + xp = apc.array_namespace(state) + if max_fpr is None or max_fpr == 1 or xp.sum(fpr) == 0 or xp.sum(tpr) == 0: + return _auc_compute(fpr, tpr, 1.0) + + _device = apc.device(fpr) if apc.is_array_api_obj(fpr) else apc.device(fpr[0]) + max_area = xp.asarray(max_fpr, dtype=xp.float32, device=_device) + + # Add a single point at max_fpr and interpolate its tpr value + stop = _searchsorted(fpr, max_area, side="right") + x_interp = xp.asarray([fpr[stop - 1], fpr[stop]], dtype=xp.float32, device=_device) + y_interp = xp.asarray([tpr[stop - 1], tpr[stop]], dtype=xp.float32, device=_device) + interp_tpr = _interp(max_area, x_interp, y_interp) + tpr = xp.concat([tpr[:stop], xp.reshape(interp_tpr, (1,))]) + fpr = xp.concat([fpr[:stop], xp.reshape(max_area, (1,))]) + + # Compute partial AUC + partial_auc = _auc_compute(fpr, tpr, 1.0) + + # McClish correction: standardize result to be 0.5 if non-discriminant and 1 + # if maximal + min_area = 0.5 * max_area**2 + return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area)) # type: ignore[no-any-return] + + +def binary_auroc( + target: Array, + preds: Array, + max_fpr: Optional[float] = None, + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[int] = None, +) -> Array: + """Compute the area under the ROC curve for binary classification tasks. + + Parameters + ---------- + target : Array + An array object that is compatible with the Python array API standard + and contains the ground truth labels in the range [0, 1]. The expected + shape of the array is `(N, ...)`, where `N` is the number of samples. 
+    preds : Array
+        An array object that is compatible with the Python array API standard and
+        contains the probability/logit scores for the positive class. The expected
+        shape of the array is `(N, ...)` where `N` is the number of samples. If
+        `preds` contains floating point values that are not in the range `[0, 1]`,
+        a sigmoid function will be applied to each value before thresholding.
+    max_fpr : float, optional, default=None
+        If not `None`, compute the standardized partial area under the ROC curve
+        over the false positive rate range [0, `max_fpr`]. Must be a float in the
+        range (0, 1].
+    thresholds : Union[int, List[float], Array], optional, default=None
+        The thresholds to use for computing the ROC curve. Can be one of the following:
+        - `None`: use all unique values in `preds` as thresholds.
+        - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range
+          [0, 1].
+        - `List[float]`: use the values in the list as bins for the thresholds.
+        - `Array`: use the values in the Array as bins for the thresholds. The
+          array must be 1D.
+    ignore_index : int, optional, default=None
+        The value in `target` that should be ignored when computing the AUROC.
+        If `None`, all values in `target` are used.
+
+    Returns
+    -------
+    Array
+        The area under the ROC curve.
+
+    Raises
+    ------
+    TypeError
+        If `thresholds` is not `None` and not an integer, a list of floats or an
+        Array of floats.
+    ValueError
+        If `thresholds` is an integer and smaller than 2.
+    ValueError
+        If `thresholds` is a list of floats with values outside the range [0, 1]
+        and not monotonically increasing.
+    ValueError
+        If `thresholds` is an Array of floats and not all values are in the range
+        [0, 1] or the array is not one-dimensional.
+    ValueError
+        If `ignore_index` is neither `None` nor an integer.
+    ValueError
+        If `max_fpr` is not `None` and not a float in the range (0, 1].
+    TypeError
+        If `target` and `preds` are not compatible with the Python array API standard.
+    ValueError
+        If `target` and `preds` are empty.
+    ValueError
+        If `target` and `preds` are not numeric arrays.
+    ValueError
+        If `target` and `preds` do not have the same shape.
+    ValueError
+        If `target` contains floating point values.
+    ValueError
+        If `preds` contains non-floating point values.
+    RuntimeError
+        If `target` contains values outside the range [0, 1] or does not contain
+        `ignore_index` if `ignore_index` is not `None`.
+    ValueError
+        If the array API namespace of `target` and `preds` is not the same as the
+        array API namespace of `thresholds`.
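+
+    Notes
+    -----
+    When `max_fpr` is given, `_binary_auroc_compute` (defined above) integrates only
+    the portion of the ROC curve with false positive rate at most `max_fpr`, linearly
+    interpolating the true positive rate at `max_fpr`. The partial area is then
+    standardized with the McClish correction,
+    `0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))`, where
+    `max_area = max_fpr` and `min_area = 0.5 * max_fpr ** 2`, so that a
+    non-discriminant classifier scores 0.5 and a perfect classifier scores 1.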
+ + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import binary_auroc + >>> import numpy.array_api as anp + >>> target = anp.asarray([0, 1, 0, 1, 0, 1]) + >>> preds = anp.asarray([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_auroc(target, preds, thresholds=None) + Array(0.6666667, dtype=float32) + >>> binary_auroc(target, preds, thresholds=5) + Array(0.5555556, dtype=float32) + >>> binary_auroc(target, preds, max_fpr=0.2) + Array(0.6296296, dtype=float32) + + """ + _binary_auroc_validate_args(max_fpr, thresholds, ignore_index) + xp = _binary_precision_recall_curve_validate_arrays( + target, + preds, + thresholds=thresholds, + ignore_index=ignore_index, + ) + target, preds, thresholds = _binary_precision_recall_curve_format_arrays( + target, + preds, + thresholds=thresholds, + ignore_index=ignore_index, + xp=xp, + ) + state = _binary_precision_recall_curve_update(target, preds, thresholds, xp=xp) + return _binary_auroc_compute(state, thresholds=thresholds, max_fpr=max_fpr) + + +def _reduce_auroc( + fpr: Union[Array, List[Array]], + tpr: Union[Array, List[Array]], + average: Optional[Literal["macro", "weighted", "none"]] = None, + weights: Optional[Array] = None, +) -> Array: + """Compute the area under the ROC curve and apply `average` method. + + Parameters + ---------- + fpr : Array or list of Array + False positive rate. + tpr : Array or list of Array + True positive rate. + average : {"macro", "weighted", "none"}, default=None + If not None, apply the method to compute the average area under the ROC curve. + weights : Array, optional, default=None + Sample weights. + + Returns + ------- + Array + Area under the ROC curve. + + Raises + ------ + ValueError + If ``average`` is not one of ``macro`` or ``weighted`` or if + ``average`` is ``weighted`` and ``weights`` is None. + + Warns + ----- + UserWarning + If the AUROC for one or more classes is `nan` and ``average`` is not ``none``. + + """ + xp = apc.array_namespace((fpr[0], tpr[0]) if isinstance(fpr, list) else (fpr, tpr)) + if apc.is_array_api_obj(fpr) and apc.is_array_api_obj(tpr): + res = _auc_compute(fpr, tpr, 1.0, axis=1) # type: ignore + else: + res = xp.stack( + [_auc_compute(x, y, 1.0) for x, y in zip(fpr, tpr)], # type: ignore + ) + if average is None or average == "none": + return res + + if xp.any(xp.isnan(res)): + warnings.warn( + "The AUROC for one or more classes was `nan`. 
Ignoring these classes " + f"in {average}-average", + UserWarning, + stacklevel=1, + ) + idx = ~xp.isnan(res) + if average == "macro": + return xp.mean(res[idx]) # type: ignore[no-any-return] + if average == "weighted" and weights is not None: + weights = safe_divide(weights[idx], xp.sum(weights[idx])) + return xp.sum((res[idx] * weights)) # type: ignore[no-any-return] + raise ValueError( + "Received an incompatible combinations of inputs to make reduction.", + ) + + +def _multiclass_auroc_validate_args( + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> None: + """Validate arguments for multiclass AUROC computation.""" + _multiclass_precision_recall_curve_validate_args( + num_classes, + thresholds=thresholds, + ignore_index=ignore_index, + ) + allowed_average = ("macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError( + f"Expected argument `average` to be one of {allowed_average} but got {average}", + ) + + +def _multiclass_auroc_compute( + state: Union[Array, Tuple[Array, Array]], + num_classes: int, + thresholds: Optional[Array] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", +) -> Array: + """Compute the area under the ROC curve for multiclass classification tasks.""" + fpr, tpr, _ = _multiclass_roc_compute(state, num_classes, thresholds=thresholds) + xp = apc.array_namespace(state) + return _reduce_auroc( + fpr, + tpr, + average=average, + weights=xp.astype(bincount(state[0], minlength=num_classes), xp.float32) + if thresholds is None + else xp.sum(state[0, ...][:, 1, :], axis=-1), # type: ignore[call-overload] + ) + + +def multiclass_auroc( + target: Array, + preds: Array, + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> Array: + """Compute the area under the ROC curve for multiclass classification tasks. + + Parameters + ---------- + target : Array + An array object that is compatible with the Python array API standard + and contains the ground truth labels in the range [0, `num_classes`] + (except if `ignore_index` is specified). The expected shape of the array + is `(N, ...)`, where `N` is the number of samples. + preds : Array + An array object that is compatible with the Python array API standard and + contains the probability/logit scores for each sample. The expected shape + of the array is `(N, C, ...)` where `N` is the number of samples and `C` + is the number of classes. If `preds` contains floating point values that + are not in the range `[0, 1]`, a softmax function will be applied to each + value before thresholding. + num_classes : int + The number of classes in the classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the ROC curve. Can be one of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the AUROC. 
Can be one of + the following: + - `"macro"`: interpolates the curves from each class at a combined set of + thresholds and then average over the classwise interpolated curves. + - `"weighted"`: average over the classwise curves weighted by the support + (the number of true instances for each class). + - `"none"`: do not average over the classwise curves. + ignore_index : int or Tuple[int], optional, default=None + The value(s) in `target` that should be ignored when computing the AUROC. + If `None`, all values in `target` are used. + + Returns + ------- + Array + The area under the ROC curve. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `num_classes` is not an integer larger than 1. + ValueError + If `ignore_index` is not `None`, an integer or a tuple of integers. + ValueError + If `average` is not `"macro"`, `"weighted"`, `"none"` or `None`. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `preds` does not have one more dimension than `target`. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + ValueError + If the second dimension of `preds` is not equal to `num_classes`. + ValueError + If the first dimension of `preds` is not equal to the first dimension of + `target` or the third dimension of `preds` is not equal to the second + dimension of `target`. + RuntimeError + If `target` contains more unique values than `num_classes` or `num_classes` + plus the number of values in `ignore_index` if `ignore_index` is not `None`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import multiclass_auroc + >>> import numpy.array_api as anp + >>> target = anp.asarray([0, 1, 2, 0, 1, 2]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44], + ... [0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44]]) + >>> multiclass_auroc(target, preds, num_classes=3, thresholds=None) + Array(0.33333334, dtype=float32) + >>> multiclass_auroc(target, preds, num_classes=3, thresholds=5) + Array(0.33333334, dtype=float32) + >>> multiclass_auroc(target, preds, num_classes=3, average=None) + Array([0. 
, 0.5, 0.5], dtype=float32) + + """ + _multiclass_auroc_validate_args( + num_classes, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + xp = _multiclass_precision_recall_curve_validate_arrays( + target, + preds, + num_classes, + ignore_index=ignore_index, + ) + target, preds, thresholds = _multiclass_precision_recall_curve_format_arrays( + target, + preds, + num_classes, + thresholds=thresholds, + ignore_index=ignore_index, + xp=xp, + ) + state = _multiclass_precision_recall_curve_update( + target, + preds, + num_classes, + thresholds=thresholds, + xp=xp, + ) + return _multiclass_auroc_compute( + state, + num_classes, + thresholds=thresholds, + average=average, + ) + + +def _multilabel_auroc_validate_args( + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, +) -> None: + """Validate arguments for multilabel AUROC computation.""" + _multilabel_precision_recall_curve_validate_args( + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError( + f"Expected argument `average` to be one of {allowed_average} but got {average}", + ) + + +def _multilabel_auroc_compute( + state: Union[Array, Tuple[Array, Array]], + num_labels: int, + thresholds: Optional[Array], + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, +) -> Array: + """Compute the area under the ROC curve for multilabel classification tasks.""" + if average == "micro": + if apc.is_array_api_obj(state) and thresholds is not None: + xp = apc.array_namespace(state) + return _binary_auroc_compute( + xp.sum(state, axis=1), + thresholds, + max_fpr=None, + ) + + target = flatten(state[0]) + preds = flatten(state[1]) + if ignore_index is not None: + target, preds = remove_ignore_index(target, preds, ignore_index) + return _binary_auroc_compute((target, preds), thresholds, max_fpr=None) + + fpr, tpr, _ = _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) + xp = apc.array_namespace(state) + return _reduce_auroc( + fpr, + tpr, + average, + weights=xp.astype( + xp.sum(xp.astype(state[0] == 1, xp.int32), axis=0), + xp.float32, + ) + if thresholds is None + else xp.sum(state[0, ...][:, 1, :], axis=-1), # type: ignore[call-overload] + ) + + +def multilabel_auroc( + target: Array, + preds: Array, + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, +) -> Array: + """Compute the area under the ROC curve for multilabel classification tasks. + + Parameters + ---------- + target : Array + The target array of shape `(N, L, ...)` containing the ground truth labels + in the range [0, 1], where `N` is the number of samples and `L` is the + number of labels. + preds : Array + The prediction array of shape `(N, L, ...)` containing the probability/logit + scores for each sample, where `N` is the number of samples and `L` is the + number of labels. If `preds` contains floating point values that are not + in the range [0,1], they will be converted to probabilities using the + sigmoid function. + num_labels : int + The number of labels in the multilabel classification problem. 
+ thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the ROC curve. Can be one of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"micro", "macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the AUROC. Can be one of + the following: + - `"micro"`: compute the AUROC globally by considering each element of the + label indicator matrix as a label. + - `"macro"`: compute the AUROC for each label and average them. + - `"weighted"`: compute the AUROC for each label and average them weighted + by the support (the number of true instances for each label). + - `"none"`: do not average over the labelwise AUROC. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the AUROC. + If `None`, all values in `target` are used. + + Returns + ------- + Array + The area under the ROC curve. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `ignore_index` is not `None` or an integer. + ValueError + If `num_labels` is not an integer larger than 1. + ValueError + If `average` is not `"micro"`, `"macro"`, `"none"` or `None`. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `target` and `preds` do not have the same shape. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + RuntimeError + If `target` contains values outside the range [0, 1] or does not contain + `ignore_index` if `ignore_index` is not `None`. + ValueError + If the array API namespace of `target` and `preds` are not the same as the + array API namespace of `thresholds`. + ValueError + If the second dimension of `preds` is not equal to `num_labels`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import multilabel_auroc + >>> import numpy.array_api as anp + >>> target = anp.asarray([[0, 1, 0], [1, 1, 0], [0, 0, 1]]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], [0.84, 0.73, 0.12], [0.33, 0.92, 0.44]], + ... ) + >>> multilabel_auroc(target, preds, num_labels=3, thresholds=None) + Array(0.5, dtype=float32) + >>> multilabel_auroc(target, preds, num_labels=3, thresholds=5) + Array(0.5, dtype=float32) + >>> multilabel_auroc(target, preds, num_labels=3, average=None) + Array([1. , 0. 
, 0.5], dtype=float32) + + """ + _multilabel_auroc_validate_args( + num_labels, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + xp = _multilabel_precision_recall_curve_validate_arrays( + target, + preds, + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + target, preds, thresholds = _multilabel_precision_recall_curve_format_arrays( + target, + preds, + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + xp=xp, + ) + state = _multilabel_precision_recall_curve_update( + target, + preds, + num_labels, + thresholds=thresholds, + xp=xp, + ) + return _multilabel_auroc_compute( + state, + num_labels, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) diff --git a/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py b/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py index 609548cf8..0c2409670 100644 --- a/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/experimental/functional/precision_recall_curve.py @@ -158,8 +158,8 @@ def _format_thresholds( def _binary_precision_recall_curve_format_arrays( target: Array, preds: Array, - thresholds: Optional[Union[int, List[float], Array]], - ignore_index: Optional[int], + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[int] = None, *, xp: ModuleType, ) -> Tuple[Array, Array, Optional[Array]]: @@ -449,8 +449,8 @@ def binary_precision_recall_curve( def _multiclass_precision_recall_curve_validate_args( num_classes: int, thresholds: Optional[Union[int, List[float], Array]] = None, - ignore_index: Optional[Union[int, Tuple[int]]] = None, average: Optional[Literal["macro", "micro", "none"]] = None, + ignore_index: Optional[Union[int, Tuple[int]]] = None, ) -> None: """Validate the arguments for the `multiclass_precision_recall_curve` function.""" _validate_thresholds(thresholds) @@ -482,7 +482,7 @@ def _multiclass_precision_recall_curve_validate_arrays( target: Array, preds: Array, num_classes: int, - ignore_index: Optional[Union[int, Tuple[int]]], + ignore_index: Optional[Union[int, Tuple[int]]] = None, ) -> ModuleType: """Validate the arrays for the `multiclass_precision_recall_curve` function.""" _basic_input_array_checks(target, preds) @@ -537,8 +537,8 @@ def _multiclass_precision_recall_curve_format_arrays( preds: Array, num_classes: int, thresholds: Optional[Union[int, List[float], Array]], - ignore_index: Optional[Union[int, Tuple[int]]], - average: Optional[Literal["macro", "micro", "none"]], + ignore_index: Optional[Union[int, Tuple[int]]] = None, + average: Optional[Literal["macro", "micro", "none"]] = None, *, xp: ModuleType, ) -> Tuple[Array, Array, Optional[Array]]: @@ -828,15 +828,15 @@ class is returned with 1-D Arrays of shape `(num_thresholds,)`. 
Otherwise,
     """  # noqa: W505
     _multiclass_precision_recall_curve_validate_args(
         num_classes,
-        thresholds,
-        ignore_index,
-        average,
+        thresholds=thresholds,
+        average=average,
+        ignore_index=ignore_index,
     )
     xp = _multiclass_precision_recall_curve_validate_arrays(
         target,
         preds,
         num_classes,
-        ignore_index,
+        ignore_index=ignore_index,
     )
     target, preds, thresholds = _multiclass_precision_recall_curve_format_arrays(
         target,
diff --git a/cyclops/evaluate/metrics/experimental/functional/roc.py b/cyclops/evaluate/metrics/experimental/functional/roc.py
index 2371a5a83..e95948d5b 100644
--- a/cyclops/evaluate/metrics/experimental/functional/roc.py
+++ b/cyclops/evaluate/metrics/experimental/functional/roc.py
@@ -196,14 +196,14 @@ def binary_roc(
     xp = _binary_precision_recall_curve_validate_arrays(
         target,
         preds,
-        thresholds,
-        ignore_index,
+        thresholds=thresholds,
+        ignore_index=ignore_index,
     )
     target, preds, thresholds = _binary_precision_recall_curve_format_arrays(
         target,
         preds,
-        thresholds,
-        ignore_index,
+        thresholds=thresholds,
+        ignore_index=ignore_index,
         xp=xp,
     )
     state = _binary_precision_recall_curve_update(target, preds, thresholds, xp=xp)
@@ -214,7 +214,7 @@ def _multiclass_roc_compute(
     state: Union[Array, Tuple[Array, Array]],
     num_classes: int,
     thresholds: Optional[Array],
-    average: Optional[Literal["macro", "micro", "none"]],
+    average: Optional[Literal["macro", "micro", "none"]] = None,
 ) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]:
     """Compute the multiclass ROC curve."""
     if average == "micro":
@@ -417,9 +417,9 @@ def multiclass_roc(
     """  # noqa: W505
     _multiclass_precision_recall_curve_validate_args(
         num_classes,
-        thresholds,
-        ignore_index,
-        average,
+        thresholds=thresholds,
+        average=average,
+        ignore_index=ignore_index,
     )
     xp = _multiclass_precision_recall_curve_validate_arrays(
         target,
diff --git a/cyclops/evaluate/metrics/experimental/utils/ops.py b/cyclops/evaluate/metrics/experimental/utils/ops.py
index 5d9279e56..83d148ae8 100644
--- a/cyclops/evaluate/metrics/experimental/utils/ops.py
+++ b/cyclops/evaluate/metrics/experimental/utils/ops.py
@@ -695,7 +695,72 @@ def _array_indexing(arr: Array, idx: Array) -> Array:
         return xp.asarray(np_arr, dtype=arr.dtype, device=apc.device(arr))


-def _cumsum(x: Array, axis: Optional[int], dtype: Optional[Any] = None) -> Array:
+def _auc_compute(
+    x: Array,
+    y: Array,
+    direction: Optional[float] = None,
+    axis: int = -1,
+    reorder: bool = False,
+) -> Array:
+    """Compute the area under the curve using the trapezoidal rule.
+
+    Adapted from: https://github.com/Lightning-AI/torchmetrics/blob/fd2e332b66df1b484728efedad9d430c7efae990/src/torchmetrics/utilities/compute.py#L99-L115
+
+    Parameters
+    ----------
+    x : Array
+        The x-coordinates of the curve.
+    y : Array
+        The y-coordinates of the curve.
+    direction : float, optional, default=None
+        The direction of the curve. If None, the direction will be inferred from the
+        values in `x`.
+    axis : int, optional, default=-1
+        The axis along which to compute the area under the curve.
+    reorder : bool, optional, default=False
+        Whether to sort the arrays `x` and `y` by `x` before computing the area under
+        the curve.
+    """
+    xp = apc.array_namespace(x, y)
+    if reorder:
+        x_idx = xp.argsort(x, stable=True)
+        x, y = _array_indexing(x, x_idx), _array_indexing(y, x_idx)
+
+    if direction is None:
+        dx = x[1:] - x[:-1]
+        if xp.any(dx < 0):
+            if xp.all(dx <= 0):
+                direction = -1.0
+            else:
+                raise ValueError(
+                    "The array `x` is neither increasing nor decreasing. 
" + "Try setting the reorder argument to `True`.", + ) + else: + direction = 1.0 + + return xp.astype(_trapz(y, x, axis=axis) * direction, xp.float32) + + +def _cumsum(x: Array, axis: Optional[int] = None, dtype: Optional[Any] = None) -> Array: + """Compute the cumulative sum of an array along a given axis. + + Parameters + ---------- + x : Array + The input array. + axis : int, optional, default=None + The axis along which to compute the cumulative sum. If None, the input array + will be flattened before computing the cumulative sum. + dtype : Any, optional, default=None + The data type of the output array. If None, the data type of the input array + will be used. + + Returns + ------- + Array + An array containing the cumulative sum of the input array along the given axis. + """ xp = apc.array_namespace(x) if hasattr(xp, "cumsum"): return xp.cumsum(x, axis, dtype=dtype) @@ -734,23 +799,147 @@ def _cumsum(x: Array, axis: Optional[int], dtype: Optional[Any] = None) -> Array return result +def _diff( + a: Array, + n: int = 1, + axis: int = -1, + prepend: Optional[Array] = None, + append: Optional[Array] = None, +) -> Array: + """Calculate the n-th discrete difference along the given axis. + + Adapted from: https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/function_base.py#L1324-L1454 + + Parameters + ---------- + a : Array + Input array. + n : int, optional, default=1 + The number of times values are differenced. If zero, the input is returned + as-is. + axis : int, optional, default=-1 + The axis along which the difference is taken, default is the last axis. + prepend : Array, optional, default=None + Values to prepend to `a` along `axis` prior to performing the difference. + append : Array, optional, default=None + Values to append to `a` along `axis` after performing the difference. + + Returns + ------- + Array + The n-th differences. The shape of the output is the same as `a` except along + `axis` where the dimension is smaller by `n`. The type of the output is the + same as the type of the difference between any two elements of `a`. This is + the same type as `a` in most cases. + """ + xp = apc.array_namespace(a) + + if prepend is not None and not apc.is_array_api_obj(prepend): + raise TypeError( + "Expected argument `prepend` to be an object that is compatible with the " + f"Python array API standard. Got {type(prepend)} instead.", + ) + if append is not None and not apc.is_array_api_obj(append): + raise TypeError( + "Expected argument `append` to be an object that is compatible with the " + f"Python array API standard. 
Got {type(append)} instead.", + ) + + if n == 0: + return a + if n < 0: + raise ValueError("order must be non-negative but got " + repr(n)) + + nd = a.ndim + if nd == 0: + raise ValueError("diff requires input that is at least one dimensional") + + combined = [] + if prepend is not None: + if prepend.ndim == 0: + shape = list(a.shape) + shape[axis] = 1 + prepend = xp.broadcast_to(prepend, tuple(shape)) + combined.append(prepend) + + combined.append(a) + + if append is not None: + if append.ndim == 0: + shape = list(a.shape) + shape[axis] = 1 + append = xp.broadcast_to(append, tuple(shape)) + combined.append(append) + + if len(combined) > 1: + a = xp.concat(combined, axis) + + slice1 = [slice(None)] * nd + slice2 = [slice(None)] * nd + slice1[axis] = slice(1, None) + slice2[axis] = slice(None, -1) + slice1 = tuple(slice1) # type: ignore[assignment] + slice2 = tuple(slice2) # type: ignore[assignment] + + op = xp.not_equal if a.dtype == xp.bool else xp.subtract + for _ in range(n): + a = op(a[slice1], a[slice2]) + + return a + + def _interp(x: Array, xcoords: Array, ycoords: Array) -> Array: - """Perform linear interpolation for 1D arrays.""" + """Perform linear interpolation for 1D arrays. + + Parameters + ---------- + x : Array + The 1D array of points on which to interpolate. + xcoords : Array + The 1D array of x-coordinates containing known data points. + ycoords : Array + The 1D array of y-coordinates containing known data points. + + Returns + ------- + Array + The interpolated values. + """ xp = apc.array_namespace(x, xcoords, ycoords) if hasattr(xp, "interp"): return xp.interp(x, xcoords, ycoords) + if _is_torch_array(x): + weight = (x - xcoords[0]) / (xcoords[-1] - xcoords[0]) + return xp.lerp(ycoords[0], ycoords[-1], weight) + + if xcoords.ndim != 1 or ycoords.ndim != 1: + raise ValueError( + "Expected `xcoords` and `ycoords` to be 1D arrays. " + f"Got xcoords.ndim={xcoords.ndim} and ycoords.ndim={ycoords.ndim}.", + ) + if xcoords.shape[0] != ycoords.shape[0]: + raise ValueError( + "Expected `xcoords` and `ycoords` to have the same shape along axis 0. " + f"Got xcoords.shape={xcoords.shape} and ycoords.shape={ycoords.shape}.", + ) + m = safe_divide(ycoords[1:] - ycoords[:-1], xcoords[1:] - xcoords[:-1]) b = ycoords[:-1] - (m * xcoords[:-1]) - indices = xp.sum(x[:, None] >= xcoords[None, :], 1) - 1 - _min_val = xp.asarray(0, dtype=xp.float32, device=apc.device(x)) + # create slices to work for any ndim of x and xcoords + indices = ( + xp.sum(xp.astype(x[..., None] >= xcoords[None, ...], xp.int32), axis=1) - 1 + ) + _min_val = xp.asarray(0, dtype=xp.int32, device=apc.device(x)) _max_val = xp.asarray( m.shape[0] if m.ndim > 0 else 1 - 1, - dtype=xp.float32, + dtype=xp.int32, device=apc.device(x), ) - indices = xp.min(xp.max(indices, _min_val), _max_val) + # clamp indices to _min_val and _max_val + indices = xp.where(indices < _min_val, _min_val, indices) + indices = xp.where(indices > _max_val, _max_val, indices) return _array_indexing(m, indices) * x + _array_indexing(b, indices) @@ -845,6 +1034,53 @@ def _select_topk( # noqa: PLR0912 return xp.asarray(result, device=apc.device(scores)) +def _searchsorted( + a: Array, + v: Array, + side: str = "left", + sorter: Optional[Array] = None, +) -> Array: + """Find indices where elements of `v` should be inserted to maintain order. + + Parameters + ---------- + a : Array + Input array. Must be sorted in ascending order if `sorter` is `None`. + v : Array + Values to insert into `a`. 
+ side : {'left', 'right'}, optional, default='left' + If 'left', the index of the first suitable location found is given. + If 'right', return the last such index. If there is no suitable index, + return either 0 or `N` (where N is the length of `a`). + sorter : Array, optional, default=None + An optional array of integer indices that sort array `a` into ascending order. + This is typically the result of `argsort`. + + Returns + ------- + Array + Array of insertion points with the same shape as `v`. + + Warnings + -------- + This method uses `numpy.from_dlpack` to convert the input arrays to NumPy arrays + and then uses `numpy.searchsorted` to perform the search. This may result in + unexpected behavior for some array namespaces. + + """ + xp = apc.array_namespace(a, v) + if hasattr(xp, "searchsorted"): + return xp.searchsorted(a, v, side=side, sorter=sorter) + + np_a = np.from_dlpack(apc.to_device(a, "cpu")) + np_v = np.from_dlpack(apc.to_device(v, "cpu")) + np_sorter = ( + np.from_dlpack(apc.to_device(sorter, "cpu")) if sorter is not None else None + ) + np_result = np.searchsorted(np_a, np_v, side=side, sorter=np_sorter) # type: ignore[call-overload] + return xp.asarray(np_result, dtype=xp.int32, device=apc.device(a)) + + def _to_one_hot( array: Array, num_classes: Optional[int] = None, @@ -893,3 +1129,62 @@ def _to_one_hot( output_shape = input_shape + (num_classes,) return xp.reshape(categorical, output_shape) + + +def _trapz( + y: Array, + x: Optional[Array] = None, + dx: float = 1.0, + axis: int = -1, +) -> Array: + """Integrate along the given axis using the composite trapezoidal rule. + + Adapted from: https://github.com/cupy/cupy/blob/v12.3.0/cupy/_math/sumprod.py#L580-L626 + + Parameters + ---------- + y : Array + Input array to integrate. + x : Array, optional, default=None + Sample points over which to integrate. If `x` is None, the sample points are + assumed to be evenly spaced `dx` apart. + dx : float, optional, default=1.0 + Spacing between sample points when `x` is None. + axis : int, optional, default=-1 + Axis along which to integrate. + + Returns + ------- + Array + Definite integral as approximated by trapezoidal rule. 
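+
+    Examples
+    --------
+    A minimal sketch with hand-picked values, assuming a `numpy.array_api`
+    input and the default spacing `dx=1.0`:
+
+    >>> import numpy.array_api as anp
+    >>> y = anp.asarray([1.0, 2.0, 3.0])
+    >>> _trapz(y)
+    Array(4., dtype=float32)
+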
+ """ + xp = apc.array_namespace(y) + + if not apc.is_array_api_obj(y): + raise TypeError( + "The type for `y` should be compatible with the Python array API standard.", + ) + + if x is None: + d = dx + else: + if not apc.is_array_api_obj(x): + raise TypeError( + "The type for `x` should be compatible with the Python array API standard.", + ) + if x.ndim == 1: + d = _diff(x) # type: ignore[assignment] + # reshape to correct shape + shape = [1] * y.ndim + shape[axis] = d.shape[0] # type: ignore[attr-defined] + d = xp.reshape(d, shape) + else: + d = _diff(x, axis=axis) # type: ignore[assignment] + + nd = y.ndim + slice1 = [slice(None)] * nd + slice2 = [slice(None)] * nd + slice1[axis] = slice(1, None) + slice2[axis] = slice(None, -1) + product = d * (y[tuple(slice1)] + y[tuple(slice2)]) / 2.0 + return xp.sum(product, dtype=xp.float32, axis=axis) diff --git a/tests/cyclops/evaluate/metrics/experimental/inputs.py b/tests/cyclops/evaluate/metrics/experimental/inputs.py index 2a0197600..05c35a47d 100644 --- a/tests/cyclops/evaluate/metrics/experimental/inputs.py +++ b/tests/cyclops/evaluate/metrics/experimental/inputs.py @@ -1,6 +1,7 @@ """Input data for tests of metrics in cyclops/evaluate/metrics/experimental.""" import random from collections import namedtuple +from types import ModuleType from typing import Any import array_api_compat as apc @@ -32,6 +33,13 @@ def _inv_sigmoid(arr: Array) -> Array: set_random_seed(1) + +def _thresholds(*, xp: ModuleType) -> list: + """Return thresholds for AUROC.""" + thresh_list = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0] + return [None, 5, thresh_list, xp.asarray(thresh_list)] + + # binary # NOTE: the test will loop over the first dimension of the input _binary_labels_0d = np.random.randint(0, 2, size=(NUM_BATCHES, 1)) diff --git a/tests/cyclops/evaluate/metrics/experimental/test_auroc.py b/tests/cyclops/evaluate/metrics/experimental/test_auroc.py new file mode 100644 index 000000000..b3ea68c89 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/test_auroc.py @@ -0,0 +1,517 @@ +"""Test AUROC metric.""" +from functools import partial + +import array_api_compat as apc +import array_api_compat.torch +import numpy.array_api as anp +import pytest +import torch.utils.dlpack +from torchmetrics.functional.classification import ( + binary_auroc as tm_binary_auroc, +) +from torchmetrics.functional.classification import ( + multiclass_auroc as tm_multiclass_auroc, +) +from torchmetrics.functional.classification import ( + multilabel_auroc as tm_multilabel_auroc, +) + +from cyclops.evaluate.metrics.experimental.auroc import ( + BinaryAUROC, + MulticlassAUROC, + MultilabelAUROC, +) +from cyclops.evaluate.metrics.experimental.functional.auroc import ( + binary_auroc, + multiclass_auroc, + multilabel_auroc, +) +from cyclops.evaluate.metrics.experimental.utils.ops import to_int +from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point + +from ..conftest import NUM_CLASSES, NUM_LABELS +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases, _thresholds +from .testers import MetricTester, _inject_ignore_index + + +def _binary_auroc_reference( + target, + preds, + max_fpr, + thresholds, + ignore_index, +) -> torch.Tensor: + """Return the reference binary AUROC.""" + return tm_binary_auroc( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + max_fpr=max_fpr, + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, 
+ ) + + +class TestBinaryAUROC(MetricTester): + """Test binary AUROC function and class.""" + + atol = 1e-7 + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) + @pytest.mark.parametrize("max_fpr", [None, 0.5]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_auroc_function_with_numpy_array_api_arrays( + self, + inputs, + max_fpr, + thresholds, + ignore_index, + ) -> None: + """Test function for binary AUROC using array_api arrays.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=binary_auroc, + metric_args={ + "max_fpr": max_fpr, + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_auroc_reference, + max_fpr=max_fpr, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) + @pytest.mark.parametrize("max_fpr", [None, 0.5]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_auroc_class_with_numpy_array_api_arrays( + self, + inputs, + max_fpr, + thresholds, + ignore_index, + ) -> None: + """Test class for binary AUROC using array_api arrays.""" + target, preds = inputs + + if ( + preds.shape[1] == 1 + and is_floating_point(preds) + and not anp.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryAUROC, + metric_args={ + "max_fpr": max_fpr, + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_auroc_reference, + max_fpr=max_fpr, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)[3:]) + @pytest.mark.parametrize("max_fpr", [None, 0.5]) + @pytest.mark.parametrize( + "thresholds", + _thresholds(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_auroc_with_torch_tensors( + self, + inputs, + max_fpr, + thresholds, + ignore_index, + ) -> None: + """Test binary AUROC class with torch tensors.""" + target, preds = inputs + + if ( + preds.shape[1] == 1 + and is_floating_point(preds) + and not torch.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result 
because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + if target.shape[1] == 1 and torch.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryAUROC, + metric_args={ + "max_fpr": max_fpr, + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_auroc_reference, + max_fpr=max_fpr, + thresholds=thresholds, + ignore_index=ignore_index, + ), + device=device, + use_device_for_ref=True, + ) + + +def _multiclass_auroc_reference( + target, + preds, + num_classes=NUM_CLASSES, + thresholds=None, + average="macro", + ignore_index=None, +) -> torch.Tensor: + """Return the reference multiclass AUROC.""" + if preds.ndim == 1 and is_floating_point(preds): + xp = apc.array_namespace(preds) + preds = xp.argmax(preds, axis=0) + + return tm_multiclass_auroc( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_classes, + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + average=average, + ignore_index=ignore_index, + ) + + +class TestMulticlassAUROC(MetricTester): + """Test multiclass AUROC function and class.""" + + atol = 2e-7 + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multiclass_auroc_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test function for multiclass AUROC.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multiclass_auroc, + metric_args={ + "num_classes": NUM_CLASSES, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _multiclass_auroc_reference, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_auroc_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multiclass AUROC.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are 
single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassAUROC, + reference_metric=partial( + _multiclass_auroc_reference, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "average": average, + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)[4:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_auroc_class_with_torch_tensors( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multiclass AUROC.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and torch.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassAUROC, + reference_metric=partial( + _multiclass_auroc_reference, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + device=device, + use_device_for_ref=True, + ) + + +def _multilabel_auroc_reference( + preds, + target, + num_labels=NUM_LABELS, + thresholds=None, + average="macro", + ignore_index=None, +) -> torch.Tensor: + """Return the reference multilabel AUROC.""" + return tm_multilabel_auroc( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_labels, + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + average=average, + ignore_index=ignore_index, + ) + + +class TestMultilabelAUROC(MetricTester): + """Test multilabel AUROC function and class.""" + + atol = 2e-7 + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_auroc_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test function for multilabel AUROC.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multilabel_auroc, + reference_metric=partial( + _multilabel_auroc_reference, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "num_labels": 
NUM_LABELS, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_auroc_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multilabel AUROC.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelAUROC, + reference_metric=partial( + _multilabel_auroc_reference, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "num_labels": NUM_LABELS, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=array_api_compat.torch)[2:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("average", [None, "none", "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_auroc_class_with_torch_tensors( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multilabel AUROC.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelAUROC, + reference_metric=partial( + _multilabel_auroc_reference, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "num_labels": NUM_LABELS, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + device=device, + use_device_for_ref=True, + ) diff --git a/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py index f0608d059..4dc5989fd 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall_curve.py @@ -1,6 +1,5 @@ """Test precision-recall curve metric.""" from functools import partial -from types import ModuleType from typing import List, Tuple, Union import array_api_compat as apc @@ -32,16 +31,10 @@ from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point from ..conftest import NUM_CLASSES, NUM_LABELS -from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases, _thresholds from .testers import MetricTester, _inject_ignore_index -def _thresholds_for_prc(*, xp: ModuleType) -> list: - """Return thresholds for precision-recall curve.""" - thresh_list = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0] - return [None, 5, thresh_list, xp.asarray(thresh_list)] - - def _binary_precision_recall_curve_reference( target, preds, @@ -63,7 +56,7 @@ class TestBinaryPrecisionRecallCurve(MetricTester): """Test binary 
precision-recall curve function and class.""" @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_binary_precision_recall_curve_function_with_numpy_array_api_arrays( self, @@ -99,7 +92,7 @@ def test_binary_precision_recall_curve_function_with_numpy_array_api_arrays( ) @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_binary_precision_recall_curve_class_with_numpy_array_api_arrays( self, @@ -149,7 +142,7 @@ def test_binary_precision_recall_curve_class_with_numpy_array_api_arrays( @pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)[3:]) @pytest.mark.parametrize( "thresholds", - _thresholds_for_prc(xp=array_api_compat.torch), + _thresholds(xp=array_api_compat.torch), ) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_binary_precision_recall_curve_with_torch_tensors( @@ -233,7 +226,7 @@ class TestMulticlassPrecisionRecallCurve(MetricTester): """Test multiclass precision-recall curve function and class.""" @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("average", [None, "none"]) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multiclass_precision_recall_curve_with_numpy_array_api_arrays( @@ -273,7 +266,7 @@ def test_multiclass_precision_recall_curve_with_numpy_array_api_arrays( ) @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("average", [None, "none"]) @pytest.mark.parametrize("ignore_index", [None, 1, -1]) def test_multiclass_precision_recall_curve_class_with_numpy_array_api_arrays( @@ -316,7 +309,7 @@ def test_multiclass_precision_recall_curve_class_with_numpy_array_api_arrays( @pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)[4:]) @pytest.mark.parametrize( "thresholds", - _thresholds_for_prc(xp=array_api_compat.torch), + _thresholds(xp=array_api_compat.torch), ) @pytest.mark.parametrize("average", [None, "none"]) @pytest.mark.parametrize("ignore_index", [None, 1, -1]) @@ -389,7 +382,7 @@ class TestMultilabelPrecisionRecallCurve(MetricTester): """Test multilabel precision-recall curve function and class.""" @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multilabel_precision_recall_curve_with_numpy_array_api_arrays( self, @@ -420,7 +413,7 @@ def test_multilabel_precision_recall_curve_with_numpy_array_api_arrays( ) @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_prc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multilabel_precision_recall_curve_class_with_numpy_array_api_arrays( 
self, @@ -454,7 +447,7 @@ def test_multilabel_precision_recall_curve_class_with_numpy_array_api_arrays( @pytest.mark.parametrize("inputs", _multilabel_cases(xp=array_api_compat.torch)[2:]) @pytest.mark.parametrize( "thresholds", - _thresholds_for_prc(xp=array_api_compat.torch), + _thresholds(xp=array_api_compat.torch), ) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multilabel_precision_recall_curve_class_with_torch_tensors( diff --git a/tests/cyclops/evaluate/metrics/experimental/test_roc.py b/tests/cyclops/evaluate/metrics/experimental/test_roc.py index c1a977268..ddc4f9556 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_roc.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_roc.py @@ -1,6 +1,5 @@ """Test roc curve metric.""" from functools import partial -from types import ModuleType from typing import List, Tuple, Union import array_api_compat as apc @@ -32,16 +31,10 @@ from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point from ..conftest import NUM_CLASSES, NUM_LABELS -from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases, _thresholds from .testers import MetricTester, _inject_ignore_index -def _thresholds_for_roc(*, xp: ModuleType) -> list: - """Return thresholds for roc curve.""" - thresh_list = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0] - return [None, 5, thresh_list, xp.asarray(thresh_list)] - - def _binary_roc_reference( target, preds, @@ -63,7 +56,7 @@ class TestBinaryROC(MetricTester): """Test binary roc curve function and class.""" @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_binary_roc_function_with_numpy_array_api_arrays( self, @@ -99,7 +92,7 @@ def test_binary_roc_function_with_numpy_array_api_arrays( ) @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_binary_roc_class_with_numpy_array_api_arrays( self, @@ -149,7 +142,7 @@ def test_binary_roc_class_with_numpy_array_api_arrays( @pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)[3:]) @pytest.mark.parametrize( "thresholds", - _thresholds_for_roc(xp=array_api_compat.torch), + _thresholds(xp=array_api_compat.torch), ) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_binary_roc_with_torch_tensors( @@ -233,7 +226,7 @@ class TestMulticlassROC(MetricTester): """Test multiclass roc curve function and class.""" @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("average", [None, "none"]) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multiclass_roc_with_numpy_array_api_arrays( @@ -273,7 +266,7 @@ def test_multiclass_roc_with_numpy_array_api_arrays( ) @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("average", [None, "none"]) @pytest.mark.parametrize("ignore_index", [None, 1, 
-1]) def test_multiclass_roc_class_with_numpy_array_api_arrays( @@ -316,7 +309,7 @@ def test_multiclass_roc_class_with_numpy_array_api_arrays( @pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)[4:]) @pytest.mark.parametrize( "thresholds", - _thresholds_for_roc(xp=array_api_compat.torch), + _thresholds(xp=array_api_compat.torch), ) @pytest.mark.parametrize("average", [None, "none"]) @pytest.mark.parametrize("ignore_index", [None, 1, -1]) @@ -389,7 +382,7 @@ class TestMultilabelROC(MetricTester): """Test multilabel roc curve function and class.""" @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multilabel_roc_with_numpy_array_api_arrays( self, @@ -420,7 +413,7 @@ def test_multilabel_roc_with_numpy_array_api_arrays( ) @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) - @pytest.mark.parametrize("thresholds", _thresholds_for_roc(xp=anp)) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multilabel_roc_class_with_numpy_array_api_arrays( self, @@ -454,7 +447,7 @@ def test_multilabel_roc_class_with_numpy_array_api_arrays( @pytest.mark.parametrize("inputs", _multilabel_cases(xp=array_api_compat.torch)[2:]) @pytest.mark.parametrize( "thresholds", - _thresholds_for_roc(xp=array_api_compat.torch), + _thresholds(xp=array_api_compat.torch), ) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) def test_multilabel_roc_class_with_torch_tensors( From 5c4ebb240f60c1c9060477e6f3b002499e0494b0 Mon Sep 17 00:00:00 2001 From: Franklin <41602287+fcogidi@users.noreply.github.com> Date: Tue, 16 Jan 2024 09:19:53 -0500 Subject: [PATCH 5/6] Add regression metrics (#547) * Add regression support * update docstrings * Ignore `no-any-return` errors --- .../evaluate/metrics/experimental/__init__.py | 9 + .../experimental/functional/__init__.py | 11 ++ .../metrics/experimental/functional/mae.py | 88 +++++++++ .../metrics/experimental/functional/mape.py | 124 ++++++++++++ .../metrics/experimental/functional/mse.py | 140 ++++++++++++++ .../metrics/experimental/functional/smape.py | 126 +++++++++++++ .../metrics/experimental/functional/wmape.py | 106 +++++++++++ cyclops/evaluate/metrics/experimental/mae.py | 48 +++++ cyclops/evaluate/metrics/experimental/mape.py | 65 +++++++ cyclops/evaluate/metrics/experimental/mse.py | 84 +++++++++ .../evaluate/metrics/experimental/smape.py | 66 +++++++ .../evaluate/metrics/experimental/wmape.py | 72 +++++++ pyproject.toml | 4 +- .../evaluate/metrics/experimental/inputs.py | 20 ++ .../metrics/experimental/test_mean_error.py | 176 ++++++++++++++++++ 15 files changed, 1137 insertions(+), 2 deletions(-) create mode 100644 cyclops/evaluate/metrics/experimental/functional/mae.py create mode 100644 cyclops/evaluate/metrics/experimental/functional/mape.py create mode 100644 cyclops/evaluate/metrics/experimental/functional/mse.py create mode 100644 cyclops/evaluate/metrics/experimental/functional/smape.py create mode 100644 cyclops/evaluate/metrics/experimental/functional/wmape.py create mode 100644 cyclops/evaluate/metrics/experimental/mae.py create mode 100644 cyclops/evaluate/metrics/experimental/mape.py create mode 100644 cyclops/evaluate/metrics/experimental/mse.py create mode 100644 cyclops/evaluate/metrics/experimental/smape.py create mode 100644 
cyclops/evaluate/metrics/experimental/wmape.py create mode 100644 tests/cyclops/evaluate/metrics/experimental/test_mean_error.py diff --git a/cyclops/evaluate/metrics/experimental/__init__.py b/cyclops/evaluate/metrics/experimental/__init__.py index e052ee224..ec6c72609 100644 --- a/cyclops/evaluate/metrics/experimental/__init__.py +++ b/cyclops/evaluate/metrics/experimental/__init__.py @@ -22,7 +22,10 @@ MultilabelF1Score, MultilabelFBetaScore, ) +from cyclops.evaluate.metrics.experimental.mae import MeanAbsoluteError +from cyclops.evaluate.metrics.experimental.mape import MeanAbsolutePercentageError from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict +from cyclops.evaluate.metrics.experimental.mse import MeanSquaredError from cyclops.evaluate.metrics.experimental.negative_predictive_value import ( BinaryNPV, MulticlassNPV, @@ -55,6 +58,9 @@ MulticlassROC, MultilabelROC, ) +from cyclops.evaluate.metrics.experimental.smape import ( + SymmetricMeanAbsolutePercentageError, +) from cyclops.evaluate.metrics.experimental.specificity import ( BinarySpecificity, BinaryTNR, @@ -63,3 +69,6 @@ MultilabelSpecificity, MultilabelTNR, ) +from cyclops.evaluate.metrics.experimental.wmape import ( + WeightedMeanAbsolutePercentageError, +) diff --git a/cyclops/evaluate/metrics/experimental/functional/__init__.py b/cyclops/evaluate/metrics/experimental/functional/__init__.py index 14a887191..e24543e64 100644 --- a/cyclops/evaluate/metrics/experimental/functional/__init__.py +++ b/cyclops/evaluate/metrics/experimental/functional/__init__.py @@ -22,6 +22,11 @@ multilabel_f1_score, multilabel_fbeta_score, ) +from cyclops.evaluate.metrics.experimental.functional.mae import mean_absolute_error +from cyclops.evaluate.metrics.experimental.functional.mape import ( + mean_absolute_percentage_error, +) +from cyclops.evaluate.metrics.experimental.functional.mse import mean_squared_error from cyclops.evaluate.metrics.experimental.functional.negative_predictive_value import ( binary_npv, multiclass_npv, @@ -51,6 +56,9 @@ multiclass_roc, multilabel_roc, ) +from cyclops.evaluate.metrics.experimental.functional.smape import ( + symmetric_mean_absolute_percentage_error, +) from cyclops.evaluate.metrics.experimental.functional.specificity import ( binary_specificity, binary_tnr, @@ -59,3 +67,6 @@ multilabel_specificity, multilabel_tnr, ) +from cyclops.evaluate.metrics.experimental.functional.wmape import ( + weighted_mean_absolute_percentage_error, +) diff --git a/cyclops/evaluate/metrics/experimental/functional/mae.py b/cyclops/evaluate/metrics/experimental/functional/mae.py new file mode 100644 index 000000000..870259e3b --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/functional/mae.py @@ -0,0 +1,88 @@ +"""Functional interface for the mean absolute error metric.""" +from typing import Tuple, Union + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.utils.types import Array +from cyclops.evaluate.metrics.experimental.utils.validation import ( + _basic_input_array_checks, + _check_same_shape, + is_floating_point, +) + + +def _mean_absolute_error_update(target: Array, preds: Array) -> Tuple[Array, int]: + """Update and return variables required to compute Mean Absolute Error.""" + _basic_input_array_checks(target, preds) + _check_same_shape(target, preds) + xp = apc.array_namespace(target, preds) + + target = target if is_floating_point(target) else xp.astype(target, xp.float32) + preds = preds if is_floating_point(preds) else xp.astype(preds, xp.float32) + + 
sum_abs_error = xp.sum(xp.abs(preds - target), dtype=xp.float32)
+    num_obs = int(apc.size(target) or 0)
+    return sum_abs_error, num_obs
+
+
+def _mean_absolute_error_compute(
+    sum_abs_error: Array,
+    num_obs: Union[int, Array],
+) -> Array:
+    """Compute Mean Absolute Error.
+
+    Parameters
+    ----------
+    sum_abs_error : Array
+        Sum of absolute value of errors over all observations.
+    num_obs : int, Array
+        Total number of observations.
+
+    Returns
+    -------
+    Array
+        The mean absolute error.
+
+    """
+    return sum_abs_error / num_obs  # type: ignore[no-any-return]
+
+
+def mean_absolute_error(target: Array, preds: Array) -> Array:
+    """Compute the mean absolute error.
+
+    Parameters
+    ----------
+    target : Array
+        Ground truth target values.
+    preds : Array
+        Estimated target values.
+
+    Returns
+    -------
+    Array
+        The mean absolute error.
+
+    Raises
+    ------
+    TypeError
+        If `target` or `preds` is not an array object that is compatible with
+        the Python array API standard.
+    ValueError
+        If `target` or `preds` is empty.
+    ValueError
+        If `target` or `preds` is not a numeric array.
+    ValueError
+        If the shape of `target` and `preds` are not the same.
+
+    Examples
+    --------
+    >>> import numpy.array_api as anp
+    >>> from cyclops.evaluate.metrics.experimental.functional import mean_absolute_error
+    >>> target = anp.asarray([0.009, 1.05, 2., 3.])
+    >>> preds = anp.asarray([0., 1., 2., 2.])
+    >>> mean_absolute_error(target, preds)
+    Array(0.26475, dtype=float32)
+
+    """
+    sum_abs_error, num_obs = _mean_absolute_error_update(target, preds)
+    return _mean_absolute_error_compute(sum_abs_error, num_obs)
diff --git a/cyclops/evaluate/metrics/experimental/functional/mape.py b/cyclops/evaluate/metrics/experimental/functional/mape.py
new file mode 100644
index 000000000..4c90cd733
--- /dev/null
+++ b/cyclops/evaluate/metrics/experimental/functional/mape.py
@@ -0,0 +1,124 @@
+"""Functional interface for the mean absolute percentage error metric."""
+from typing import Tuple, Union
+
+import array_api_compat as apc
+
+from cyclops.evaluate.metrics.experimental.utils.types import Array
+from cyclops.evaluate.metrics.experimental.utils.validation import (
+    _basic_input_array_checks,
+    _check_same_shape,
+)
+
+
+def _mean_absolute_percentage_error_update(
+    target: Array,
+    preds: Array,
+    epsilon: float = 1.17e-06,
+) -> Tuple[Array, int]:
+    """Update and return variables required to compute the Mean Percentage Error.
+
+    Parameters
+    ----------
+    target : Array
+        Ground truth target values.
+    preds : Array
+        Estimated target values.
+    epsilon : float, optional, default=1.17e-06
+        Specifies the lower bound for target values. Any target value below epsilon
+        is set to epsilon (avoids division by zero errors).
+
+    Returns
+    -------
+    Tuple[Array, int]
+        Sum of absolute value of percentage errors over all observations and number
+        of observations.
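+
+    Examples
+    --------
+    A small illustrative sketch with hand-picked values, assuming
+    `numpy.array_api` inputs whose targets are all above `epsilon`:
+
+    >>> import numpy.array_api as anp
+    >>> target = anp.asarray([1.0, 2.0, 4.0])
+    >>> preds = anp.asarray([1.0, 1.0, 5.0])
+    >>> _mean_absolute_percentage_error_update(target, preds)
+    (Array(0.75, dtype=float32), 3)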
+ + """ + _basic_input_array_checks(target, preds) + _check_same_shape(target, preds) + xp = apc.array_namespace(target, preds) + + abs_diff = xp.abs(preds - target) + abs_target = xp.abs(target) + clamped_abs_target = xp.where( + abs_target < epsilon, + xp.asarray(epsilon, dtype=abs_target.dtype, device=apc.device(abs_target)), + abs_target, + ) + abs_per_error = abs_diff / clamped_abs_target + + sum_abs_per_error = xp.sum(abs_per_error, dtype=xp.float32) + + num_obs = int(apc.size(target) or 0) + + return sum_abs_per_error, num_obs + + +def _mean_absolute_percentage_error_compute( + sum_abs_per_error: Array, + num_obs: Union[int, Array], +) -> Array: + """Compute the Mean Absolute Percentage Error. + + Parameters + ---------- + sum_abs_per_error : Array + Sum of absolute value of percentage errors over all observations. + ``(percentage error = (target - prediction) / target)`` + num_obs : int, Array + Number of observations. + + Returns + ------- + Array + The mean absolute percentage error. + + """ + return sum_abs_per_error / num_obs # type: ignore[no-any-return] + + +def mean_absolute_percentage_error(target: Array, preds: Array) -> Array: + """Compute the mean absolute percentage error. + + Parameters + ---------- + target : Array + Ground truth target values. + preds : Array + Estimated target values. + + Returns + ------- + Array + The mean absolute percentage error. + + Raises + ------ + TypeError + If `target` or `preds` is not an array object that is compatible with + the Python array API standard. + ValueError + If `target` or `preds` is empty. + ValueError + If `target` or `preds` is not a numeric array. + ValueError + If the shape of `target` and `preds` are not the same. + + Notes + ----- + The epsilon value is taken from `scikit-learn's implementation of MAPE`. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental.functional import ( + ... mean_absolute_percentage_error + ... ) + >>> target = anp.asarray([1., 10., 1e6]) + >>> preds = anp.asarray([0.9, 15., 1.2e6]) + >>> mean_absolute_percentage_error(target, preds) + Array(0.26666668, dtype=float32) + + """ + sum_abs_per_error, num_obs = _mean_absolute_percentage_error_update(target, preds) + return _mean_absolute_percentage_error_compute(sum_abs_per_error, num_obs) diff --git a/cyclops/evaluate/metrics/experimental/functional/mse.py b/cyclops/evaluate/metrics/experimental/functional/mse.py new file mode 100644 index 000000000..4af5dd535 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/functional/mse.py @@ -0,0 +1,140 @@ +"""Functional interface for the mean squared error metric.""" +from typing import Tuple, Union + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.utils.ops import flatten, squeeze_all +from cyclops.evaluate.metrics.experimental.utils.types import Array +from cyclops.evaluate.metrics.experimental.utils.validation import ( + _basic_input_array_checks, + _check_same_shape, +) + + +def _mean_squared_error_update( + target: Array, + preds: Array, + num_outputs: int, +) -> Tuple[Array, int]: + """Update and returns variables required to compute the Mean Squared Error. + + Parameters + ---------- + target : Array + Ground truth target values. + preds : Array + Estimated target values. + num_outputs : int + Number of outputs in multioutput setting. + + Returns + ------- + Tuple[Array, int] + Sum of square of errors over all observations and number of observations. 
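+
+    Examples
+    --------
+    A small illustrative sketch with hand-picked values, assuming
+    `numpy.array_api` inputs and a single output (`num_outputs=1`):
+
+    >>> import numpy.array_api as anp
+    >>> target = anp.asarray([0.0, 1.0, 2.0])
+    >>> preds = anp.asarray([0.0, 1.0, 3.0])
+    >>> _mean_squared_error_update(target, preds, num_outputs=1)
+    (Array(1., dtype=float32), 3)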
+ + """ + _basic_input_array_checks(target, preds) + _check_same_shape(target, preds) + xp = apc.array_namespace(target, preds) + + if num_outputs == 1: + target = flatten(target, copy=False) + preds = flatten(preds, copy=False) + + diff = preds - target + sum_squared_error = xp.sum(diff * diff, axis=0, dtype=xp.float32) + return sum_squared_error, target.shape[0] + + +def _mean_squared_error_compute( + sum_squared_error: Array, + num_obs: Union[int, Array], + squared: bool = True, +) -> Array: + """Compute Mean Squared Error. + + Parameters + ---------- + sum_squared_error : Array + Sum of square of errors over all observations. + num_obs : Array + Number of predictions or observations. + squared : bool, optional, default=True + Whether to return MSE or RMSE. If set to False, returns RMSE. + + Returns + ------- + Array + The mean squared error or root mean squared error. + + """ + xp = apc.array_namespace(sum_squared_error) + return squeeze_all( + sum_squared_error / num_obs + if squared + else xp.sqrt(sum_squared_error / num_obs), + ) + + +def mean_squared_error( + target: Array, + preds: Array, + squared: bool = True, + num_outputs: int = 1, +) -> Array: + """Compute mean squared error. + + Parameters + ---------- + target : Array + Ground truth target values. + preds : Array + Estimated target values. + squared : bool, optional, default=True + Whether to return mean squared error or root mean squared error. If set + to `False`, returns the root mean squared error. + num_outputs : int, optional, default=1 + Number of outputs in multioutput setting. + + Raises + ------ + TypeError + If `target` or `preds` is not an array object that is compatible with + the Python array API standard. + ValueError + If `target` or `preds` is empty. + ValueError + If `target` or `preds` is not a numeric array. + ValueError + If the shape of `target` and `preds` are not the same. + + Returns + ------- + Array + The mean squared error or root mean squared error. 
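+
+    Notes
+    -----
+    The sum of squared errors is accumulated in `float32`, so the result is a
+    `float32` array even when `target` and `preds` use a higher-precision dtype.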
+ + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental.functional import mean_squared_error + >>> target = anp.asarray([0., 1., 2., 3.]) + >>> preds = anp.asarray([0.025, 1., 2., 2.44]) + >>> mean_squared_error(target, preds) + Array(0.07855625, dtype=float32) + >>> mean_squared_error(target, preds, squared=False) + Array(0.2802789, dtype=float32) + >>> target = anp.asarray([[0., 1.], [2., 3.]]) + >>> preds = anp.asarray([[0.025, 1.], [2., 2.44]]) + >>> mean_squared_error(target, preds, num_outputs=2) + Array([0.0003125, 0.1568 ], dtype=float32) + >>> mean_squared_error(target, preds, squared=False, num_outputs=2) + Array([0.01767767, 0.3959798 ], dtype=float32) + + + """ + sum_squared_error, num_obs = _mean_squared_error_update( + preds, + target, + num_outputs=num_outputs, + ) + return _mean_squared_error_compute(sum_squared_error, num_obs, squared=squared) diff --git a/cyclops/evaluate/metrics/experimental/functional/smape.py b/cyclops/evaluate/metrics/experimental/functional/smape.py new file mode 100644 index 000000000..badefeeb1 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/functional/smape.py @@ -0,0 +1,126 @@ +"""Functional interface for the Symmetric Mean Absolute Percentage Error metric.""" +from typing import Tuple, Union + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.utils.types import Array +from cyclops.evaluate.metrics.experimental.utils.validation import ( + _basic_input_array_checks, + _check_same_shape, +) + + +def _symmetric_mean_absolute_percentage_error_update( + target: Array, + preds: Array, + epsilon: float = 1.17e-06, +) -> Tuple[Array, int]: + """Update and return variables required to compute Symmetric MAPE. + + Parameters + ---------- + target : Array + Ground truth target values. + preds : Array + Estimated target values. + epsilon : float, optional, default=1.17e-06 + Specifies the lower bound for target values. Any target value below epsilon + is set to epsilon (avoids division by zero errors). + + Returns + ------- + Tuple[Array, int] + Sum of absolute value of percentage errors over all observations and number + of observations. + + """ + _basic_input_array_checks(target, preds) + _check_same_shape(target, preds) + xp = apc.array_namespace(target, preds) + + abs_diff = xp.abs(preds - target) + arr_sum = xp.abs(target) + xp.abs(preds) + clamped_val = xp.where( + arr_sum < epsilon, + xp.asarray(epsilon, dtype=arr_sum.dtype, device=apc.device(arr_sum)), + arr_sum, + ) + abs_per_error = abs_diff / clamped_val + + sum_abs_per_error = 2 * xp.sum(abs_per_error, dtype=xp.float32) + + num_obs = int(apc.size(target) or 0) + + return sum_abs_per_error, num_obs + + +def _symmetric_mean_absolute_percentage_error_compute( + sum_abs_per_error: Array, + num_obs: Union[int, Array], +) -> Array: + """Compute the Symmetric Mean Absolute Percentage Error. + + Parameters + ---------- + sum_abs_per_error : Array + Sum of absolute value of percentage errors over all observations. + ``(percentage error = (target - prediction) / target)`` + num_obs : int, Array + Total number of observations. + + Returns + ------- + Array + The symmetric mean absolute percentage error. + + """ + return sum_abs_per_error / num_obs # type: ignore[no-any-return] + + +def symmetric_mean_absolute_percentage_error(target: Array, preds: Array) -> Array: + """Compute the symmetric mean absolute percentage error (SMAPE). + + Parameters + ---------- + target : Array + Ground truth target values. 
+ preds : Array + Estimated target values. + + Returns + ------- + Array + The symmetric mean absolute percentage error. + + Raises + ------ + TypeError + If `target` or `preds` is not an array object that is compatible with + the Python array API standard. + ValueError + If `target` or `preds` is empty. + ValueError + If `target` or `preds` is not a numeric array. + ValueError + If the shape of `target` and `preds` are not the same. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental.functional import ( + ... symmetric_mean_absolute_percentage_error, + ... ) + >>> target = anp.asarray([1., 10., 1e6]) + >>> preds = anp.asarray([0.9, 15., 1.2e6]) + >>> symmetric_mean_absolute_percentage_error(target, preds) + Array(0.2290271, dtype=float32) + + """ + sum_abs_per_error, num_obs = _symmetric_mean_absolute_percentage_error_update( + target, + preds, + ) + return _symmetric_mean_absolute_percentage_error_compute( + sum_abs_per_error, + num_obs, + ) diff --git a/cyclops/evaluate/metrics/experimental/functional/wmape.py b/cyclops/evaluate/metrics/experimental/functional/wmape.py new file mode 100644 index 000000000..0dcd5c9c8 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/functional/wmape.py @@ -0,0 +1,106 @@ +"""Functional interface for the Weighted Mean Absolute Percentage Error metric.""" +from typing import Tuple + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.utils.types import Array +from cyclops.evaluate.metrics.experimental.utils.validation import ( + _basic_input_array_checks, + _check_same_shape, +) + + +def _weighted_mean_absolute_percentage_error_update( + target: Array, + preds: Array, +) -> Tuple[Array, Array]: + """Update and return variables required to compute the weighted MAPE.""" + _basic_input_array_checks(target, preds) + _check_same_shape(target, preds) + xp = apc.array_namespace(target, preds) + + sum_abs_error = xp.sum(xp.abs((preds - target)), dtype=xp.float32) + sum_scale = xp.sum(xp.abs(target), dtype=xp.float32) + + return sum_abs_error, sum_scale + + +def _weighted_mean_absolute_percentage_error_compute( + sum_abs_error: Array, + sum_scale: Array, + epsilon: float = 1.17e-06, +) -> Array: + """Compute Weighted Absolute Percentage Error. + + Parameters + ---------- + sum_abs_error : Array + Sum of absolute value of errors over all observations. + sum_scale : Array + Sum of absolute value of target values over all observations. + epsilon : float, optional, default=1.17e-06 + Specifies the lower bound for target values. Any target value below epsilon + is set to epsilon (avoids division by zero errors). + + """ + xp = apc.array_namespace(sum_abs_error, sum_scale) + clamped_sum_scale = xp.where( + sum_scale < epsilon, + xp.asarray(epsilon, dtype=sum_scale.dtype, device=apc.device(sum_scale)), + sum_scale, + ) + return sum_abs_error / clamped_sum_scale # type: ignore[no-any-return] + + +def weighted_mean_absolute_percentage_error( + target: Array, + preds: Array, + epsilon: float = 1.17e-06, +) -> Array: + """Compute the weighted mean absolute percentage error (`WMAPE`). + + Parameters + ---------- + target : Array + Ground truth target values. + preds : Array + Estimated target values. + epsilon : float, optional, default=1.17e-06 + Specifies the lower bound for target values. Any target value below epsilon + is set to epsilon (avoids division by zero errors). + + Returns + ------- + Array + The weighted mean absolute percentage error. 
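+
+    Notes
+    -----
+    The returned value is ``sum(|preds - target|) / sum(|target|)``, with the
+    denominator clamped to at least ``epsilon`` to avoid division by zero. As
+    an illustrative check, for ``target = [1.2, 2.4, 3.6, 4.8, 6.0, 7.2]`` and
+    ``preds = [1.24, 2.3, 3.4, 4.5, 5.6, 6.7]`` the absolute errors sum to
+    ``1.54`` and the absolute targets sum to ``25.2``, giving about ``0.0611``.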
+ + Raises + ------ + TypeError + If `target` or `preds` is not an array object that is compatible with + the Python array API standard. + ValueError + If `target` or `preds` is empty. + ValueError + If `target` or `preds` is not a numeric array. + ValueError + If the shape of `target` and `preds` are not the same. + + Examples + -------- + >>> import numpy.array_api as anp + >>> preds = anp.asarray([1.24, 2.3, 3.4, 4.5, 5.6, 6.7]) + >>> target = anp.asarray([1.2, 2.4, 3.6, 4.8, 6.0, 7.2]) + >>> weighted_mean_absolute_percentage_error(target, preds) + Array(0.06111111, dtype=float32) + + """ + sum_abs_error, sum_scale = _weighted_mean_absolute_percentage_error_update( + target, + preds, + ) + return _weighted_mean_absolute_percentage_error_compute( + sum_abs_error, + sum_scale, + epsilon=epsilon, + ) diff --git a/cyclops/evaluate/metrics/experimental/mae.py b/cyclops/evaluate/metrics/experimental/mae.py new file mode 100644 index 000000000..3221d3340 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/mae.py @@ -0,0 +1,48 @@ +"""Mean Absolute Error metric.""" +from cyclops.evaluate.metrics.experimental.functional.mae import ( + _mean_absolute_error_compute, + _mean_absolute_error_update, +) +from cyclops.evaluate.metrics.experimental.metric import Metric +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +class MeanAbsoluteError(Metric): + """Mean Absolute Error. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import MeanAbsoluteError + >>> target = anp.asarray([0.009, 1.05, 2., 3.]) + >>> preds = anp.asarray([0., 1., 2., 2.]) + >>> metric = MeanAbsoluteError() + >>> metric(target, preds) + Array(0.26475, dtype=float32) + + """ + + name: str = "Mean Absolute Error" + + def __init__(self) -> None: + super().__init__() + self.add_state_default_factory( + "sum_abs_error", + lambda xp: xp.asarray(0.0, dtype=xp.float32, device=self.device), # type: ignore + dist_reduce_fn="sum", + ) + self.add_state_default_factory( + "num_obs", + lambda xp: xp.asarray(0.0, dtype=xp.float32, device=self.device), # type: ignore + dist_reduce_fn="sum", + ) + + def _update_state(self, target: Array, preds: Array) -> None: + """Update state of metric.""" + sum_abs_error, num_obs = _mean_absolute_error_update(target, preds) + self.sum_abs_error += sum_abs_error # type: ignore + self.num_obs += num_obs # type: ignore + + def _compute_metric(self) -> Array: + """Compute the Mean Absolute Error.""" + return _mean_absolute_error_compute(self.sum_abs_error, self.num_obs) # type: ignore diff --git a/cyclops/evaluate/metrics/experimental/mape.py b/cyclops/evaluate/metrics/experimental/mape.py new file mode 100644 index 000000000..dede691f1 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/mape.py @@ -0,0 +1,65 @@ +"""Mean Absolute Percentage Error (MAPE) metric.""" +from cyclops.evaluate.metrics.experimental.functional.mape import ( + _mean_absolute_percentage_error_compute, + _mean_absolute_percentage_error_update, +) +from cyclops.evaluate.metrics.experimental.metric import Metric +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +class MeanAbsolutePercentageError(Metric): + """Mean Absolute Percentage Error. + + Parameters + ---------- + epsilon : float, optional, default=1.17e-06 + Specifies the lower bound for target values. Any target value below epsilon + is set to epsilon (avoids division by zero errors). 
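+
+    Notes
+    -----
+    As an illustrative summary (the computation is delegated to the helpers in
+    ``functional.mape``), the score is roughly the mean of
+    ``|preds - target| / max(|target|, epsilon)`` over all observations. For
+    the example below, the element-wise errors are ``1.0``, ``0.0476``, ``0.0``
+    and ``0.3333``, which average to about ``0.3452``.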
+
+    Examples
+    --------
+    >>> import numpy.array_api as anp
+    >>> from cyclops.evaluate.metrics.experimental import MeanAbsolutePercentageError
+    >>> target = anp.asarray([0.009, 1.05, 2., 3.])
+    >>> preds = anp.asarray([0., 1., 2., 2.])
+    >>> metric = MeanAbsolutePercentageError()
+    >>> metric(target, preds)
+    Array(0.34523812, dtype=float32)
+
+    """
+
+    name: str = "Mean Absolute Percentage Error"
+
+    def __init__(self, epsilon: float = 1.17e-6) -> None:
+        super().__init__()
+        if not isinstance(epsilon, float):
+            raise TypeError(f"Expected `epsilon` to be a float. Got {type(epsilon)}")
+        self.epsilon = epsilon
+
+        self.add_state_default_factory(
+            "sum_abs_per_error",
+            lambda xp: xp.asarray(0.0, dtype=xp.float32, device=self.device),  # type: ignore
+            dist_reduce_fn="sum",
+        )
+        self.add_state_default_factory(
+            "num_obs",
+            lambda xp: xp.asarray(0.0, dtype=xp.float32, device=self.device),  # type: ignore
+            dist_reduce_fn="sum",
+        )
+
+    def _update_state(self, target: Array, preds: Array) -> None:
+        """Update state of metric."""
+        sum_abs_per_error, num_obs = _mean_absolute_percentage_error_update(
+            target,
+            preds,
+            self.epsilon,
+        )
+        self.sum_abs_per_error += sum_abs_per_error  # type: ignore
+        self.num_obs += num_obs  # type: ignore
+
+    def _compute_metric(self) -> Array:
+        """Compute the Mean Absolute Percentage Error."""
+        return _mean_absolute_percentage_error_compute(
+            self.sum_abs_per_error,  # type: ignore
+            self.num_obs,  # type: ignore
+        )
diff --git a/cyclops/evaluate/metrics/experimental/mse.py b/cyclops/evaluate/metrics/experimental/mse.py
new file mode 100644
index 000000000..b8ef4b435
--- /dev/null
+++ b/cyclops/evaluate/metrics/experimental/mse.py
@@ -0,0 +1,84 @@
+"""Mean Squared Error metric."""
+from cyclops.evaluate.metrics.experimental.functional.mse import (
+    _mean_squared_error_compute,
+    _mean_squared_error_update,
+)
+from cyclops.evaluate.metrics.experimental.metric import Metric
+from cyclops.evaluate.metrics.experimental.utils.types import Array
+
+
+class MeanSquaredError(Metric):
+    """Mean Squared Error.
+
+    Parameters
+    ----------
+    squared : bool, optional, default=True
+        Whether to return mean squared error or root mean squared error. If set
+        to `False`, returns the root mean squared error.
+    num_outputs : int, optional, default=1
+        Number of outputs in multioutput setting.
+
+    Examples
+    --------
+    >>> import numpy.array_api as anp
+    >>> from cyclops.evaluate.metrics.experimental import MeanSquaredError
+    >>> target = anp.asarray([0.009, 1.05, 2., 3.])
+    >>> preds = anp.asarray([0., 1., 2., 2.])
+    >>> metric = MeanSquaredError()
+    >>> metric(target, preds)
+    Array(0.25064525, dtype=float32)
+    >>> metric = MeanSquaredError(squared=False)
+    >>> metric(target, preds)
+    Array(0.50064486, dtype=float32)
+    >>> metric = MeanSquaredError(num_outputs=2)
+    >>> target = anp.asarray([[0.009, 1.05], [2., 3.]])
+    >>> preds = anp.asarray([[0., 1.], [2., 2.]])
+    >>> metric(target, preds)
+    Array([4.0500e-05, 5.0125e-01], dtype=float32)
+    >>> metric = MeanSquaredError(squared=False, num_outputs=2)
+    >>> metric(target, preds)
+    Array([0.00636396, 0.7079901 ], dtype=float32)
+
+    """
+
+    name: str = "Mean Squared Error"
+
+    def __init__(self, squared: bool = True, num_outputs: int = 1) -> None:
+        super().__init__()
+        if not isinstance(squared, bool):
+            raise TypeError(f"Expected `squared` to be a boolean. Got {type(squared)}")
+        if not (isinstance(num_outputs, int) and num_outputs > 0):
+            raise TypeError(
+                f"Expected `num_outputs` to be a positive integer. 
Got {type(num_outputs)}", + ) + self.num_outputs = num_outputs + self.squared = squared + + self.add_state_default_factory( + "sum_squared_error", + lambda xp: xp.zeros(num_outputs, dtype=xp.float32, device=self.device), # type: ignore + dist_reduce_fn="sum", + ) + self.add_state_default_factory( + "num_obs", + lambda xp: xp.asarray(0.0, dtype=xp.float32, device=self.device), # type: ignore + dist_reduce_fn="sum", + ) + + def _update_state(self, target: Array, preds: Array) -> None: + """Update state of metric.""" + sum_squared_error, num_obs = _mean_squared_error_update( + target, + preds, + self.num_outputs, + ) + self.sum_squared_error += sum_squared_error # type: ignore + self.num_obs += num_obs # type: ignore + + def _compute_metric(self) -> Array: + """Compute the Mean Squared Error.""" + return _mean_squared_error_compute( + self.sum_squared_error, # type: ignore + self.num_obs, # type: ignore + self.squared, + ) diff --git a/cyclops/evaluate/metrics/experimental/smape.py b/cyclops/evaluate/metrics/experimental/smape.py new file mode 100644 index 000000000..df2392ce4 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/smape.py @@ -0,0 +1,66 @@ +"""Symmetric Mean Absolute Percentage Error metric.""" +from cyclops.evaluate.metrics.experimental.functional.smape import ( + _symmetric_mean_absolute_percentage_error_compute, + _symmetric_mean_absolute_percentage_error_update, +) +from cyclops.evaluate.metrics.experimental.metric import Metric +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +class SymmetricMeanAbsolutePercentageError(Metric): + """Symmetric Mean Absolute Percentage Error. + + Parameters + ---------- + epsilon : float, optional, default=1.17e-6 + Specifies the lower bound for target values. Any target value below epsilon + is set to epsilon (avoids division by zero errors). + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import ( + ... SymmetricMeanAbsolutePercentageError, + ... ) + >>> target = anp.asarray([0.009, 1.05, 2., 3.]) + >>> preds = anp.asarray([0., 1., 2., 2.]) + >>> metric = SymmetricMeanAbsolutePercentageError() + >>> metric(target, preds) + Array(0.61219513, dtype=float32) + + """ + + name: str = "Symmetric Mean Absolute Percentage Error" + + def __init__(self, epsilon: float = 1.17e-6) -> None: + super().__init__() + if not isinstance(epsilon, float): + raise TypeError(f"Expected `epsilon` to be a float. 
Got {type(epsilon)}") + self.epsilon = epsilon + + self.add_state_default_factory( + "sum_abs_per_error", + lambda xp: xp.asarray(0.0, dtype=xp.float32, device=self.device), # type: ignore + dist_reduce_fn="sum", + ) + self.add_state_default_factory( + "num_obs", + lambda xp: xp.asarray(0.0, dtype=xp.float32, device=self.device), # type: ignore + dist_reduce_fn="sum", + ) + + def _update_state(self, target: Array, preds: Array) -> None: + """Update state of metric.""" + sum_abs_per_error, num_obs = _symmetric_mean_absolute_percentage_error_update( + target, + preds, + ) + self.sum_abs_per_error += sum_abs_per_error # type: ignore + self.num_obs += num_obs # type: ignore + + def _compute_metric(self) -> Array: + """Compute the Symmetric Mean Absolute Percentage Error.""" + return _symmetric_mean_absolute_percentage_error_compute( + self.sum_abs_per_error, # type: ignore + self.num_obs, # type: ignore + ) diff --git a/cyclops/evaluate/metrics/experimental/wmape.py b/cyclops/evaluate/metrics/experimental/wmape.py new file mode 100644 index 000000000..a24e8eac5 --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/wmape.py @@ -0,0 +1,72 @@ +"""Weighted Mean Absolute Percentage Error metric.""" +from types import ModuleType + +from cyclops.evaluate.metrics.experimental.functional.wmape import ( + _weighted_mean_absolute_percentage_error_compute, + _weighted_mean_absolute_percentage_error_update, +) +from cyclops.evaluate.metrics.experimental.metric import Metric +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +class WeightedMeanAbsolutePercentageError(Metric): + """Weighted Mean Absolute Percentage Error. + + Parameters + ---------- + epsilon : float, optional, default=1.17e-6 + Specifies the lower bound for target values. Any target value below epsilon + is set to epsilon (avoids division by zero errors). + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import ( + ... WeightedMeanAbsolutePercentageError, + ... ) + >>> target = anp.asarray([0.009, 1.05, 2., 3.]) + >>> preds = anp.asarray([0., 1., 2., 2.]) + >>> metric = WeightedMeanAbsolutePercentageError() + >>> metric(target, preds) + Array(0.17478132, dtype=float32) + + """ + + name: str = "Weighted Mean Absolute Percentage Error" + + def __init__(self, epsilon: float = 1.17e-6) -> None: + super().__init__() + if not isinstance(epsilon, float): + raise TypeError(f"Expected `epsilon` to be a float. 
Got {type(epsilon)}") + self.epsilon = epsilon + + def default_factory(*, xp: ModuleType) -> Array: + return xp.asarray(0.0, dtype=xp.float32, device=self.device) # type: ignore[no-any-return] + + self.add_state_default_factory( + "sum_abs_error", + default_factory=default_factory, # type: ignore + dist_reduce_fn="sum", + ) + self.add_state_default_factory( + "sum_scale", + default_factory=default_factory, # type: ignore + dist_reduce_fn="sum", + ) + + def _update_state(self, target: Array, preds: Array) -> None: + """Update state of metric.""" + sum_abs_error, sum_scale = _weighted_mean_absolute_percentage_error_update( + target, + preds, + ) + self.sum_abs_error += sum_abs_error # type: ignore + self.sum_scale += sum_scale # type: ignore + + def _compute_metric(self) -> Array: + """Compute the Weighted Mean Absolute Percentage Error.""" + return _weighted_mean_absolute_percentage_error_compute( + self.sum_abs_error, # type: ignore + self.sum_scale, # type: ignore + self.epsilon, + ) diff --git a/pyproject.toml b/pyproject.toml index 519dbd7ee..ef989e98c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,7 +102,7 @@ mypy = "^1.7.0" ruff = "^0.1.0" nbqa = { version = "^1.7.0", extras = ["toolchain"] } cycquery = "^0.1.2" # used for integration test -torchmetrics = {version = "^1.2.0", extras = ["classification"]} +torchmetrics = {version = "^1.2.0", extras = ["classification", "regression"]} [tool.poetry.group.docs] optional = true @@ -132,7 +132,7 @@ jupyter = "^1.0.0" jupyterlab = "^3.4.2" ipympl = "^0.9.3" ipywidgets = "^8.0.6" -torchmetrics = {version = "^1.2.0", extras = ["classification"]} +torchmetrics = {version = "^1.2.0", extras = ["classification", "regression"]} cupy = "^12.2.0" mpi4py = {git = "https://github.com/mpi4py/mpi4py"} lightning = "^2.1.0" diff --git a/tests/cyclops/evaluate/metrics/experimental/inputs.py b/tests/cyclops/evaluate/metrics/experimental/inputs.py index 05c35a47d..92af7b9e6 100644 --- a/tests/cyclops/evaluate/metrics/experimental/inputs.py +++ b/tests/cyclops/evaluate/metrics/experimental/inputs.py @@ -337,3 +337,23 @@ def _multilabel_cases(*, xp: Any): id="input[multidim-logits]", ), ) + + +def _regression_cases(*, xp: Any): + """Return regression input cases for the given array namespace.""" + return ( + pytest.param( + InputSpec( + target=xp.asarray(np.random.rand(NUM_BATCHES, BATCH_SIZE)), + preds=xp.asarray(np.random.rand(NUM_BATCHES, BATCH_SIZE)), + ), + id="input[single-targets]", + ), + pytest.param( + InputSpec( + target=xp.asarray(np.random.rand(NUM_BATCHES, BATCH_SIZE, NUM_LABELS)), + preds=xp.asarray(np.random.rand(NUM_BATCHES, BATCH_SIZE, NUM_LABELS)), + ), + id="input[multi-targets]", + ), + ) diff --git a/tests/cyclops/evaluate/metrics/experimental/test_mean_error.py b/tests/cyclops/evaluate/metrics/experimental/test_mean_error.py new file mode 100644 index 000000000..b20069932 --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/test_mean_error.py @@ -0,0 +1,176 @@ +"""Tests for mean error metrics.""" +from functools import partial + +import array_api_compat.torch +import numpy.array_api as anp +import pytest +import torch +import torch.utils.dlpack +from array_api_compat.common._helpers import _is_torch_array +from torchmetrics.functional import ( + mean_absolute_error as tm_mean_absolute_error, +) +from torchmetrics.functional import ( + mean_absolute_percentage_error as tm_mean_abs_percentage_error, +) +from torchmetrics.functional import ( + mean_squared_error as tm_mean_squared_error, +) +from torchmetrics.functional 
import ( + symmetric_mean_absolute_percentage_error as tm_smape, +) +from torchmetrics.functional import ( + weighted_mean_absolute_percentage_error as tm_wmape, +) + +from cyclops.evaluate.metrics.experimental import ( + MeanAbsoluteError, + MeanAbsolutePercentageError, + MeanSquaredError, + SymmetricMeanAbsolutePercentageError, + WeightedMeanAbsolutePercentageError, +) +from cyclops.evaluate.metrics.experimental.functional import ( + mean_absolute_error, + mean_absolute_percentage_error, + mean_squared_error, + symmetric_mean_absolute_percentage_error, + weighted_mean_absolute_percentage_error, +) + +from .inputs import NUM_LABELS, _regression_cases +from .testers import MetricTester + + +def _tm_metric_wrapper(target, preds, tm_fn, metric_args) -> torch.Tensor: + target = torch.utils.dlpack.from_dlpack(target) + preds = torch.utils.dlpack.from_dlpack(preds) + return tm_fn(preds, target, **metric_args) + + +@pytest.mark.parametrize( + "inputs", + (*_regression_cases(xp=anp), *_regression_cases(xp=array_api_compat.torch)), +) +@pytest.mark.parametrize( + "metric_class, metric_functional, tm_fn, metric_args", + [ + (MeanAbsoluteError, mean_absolute_error, tm_mean_absolute_error, {}), + ( + MeanSquaredError, + mean_squared_error, + tm_mean_squared_error, + {"squared": True}, + ), + ( + MeanSquaredError, + mean_squared_error, + tm_mean_squared_error, + {"squared": False}, + ), + ( + MeanSquaredError, + mean_squared_error, + tm_mean_squared_error, + {"squared": True, "num_outputs": NUM_LABELS}, + ), + ( + MeanAbsolutePercentageError, + mean_absolute_percentage_error, + tm_mean_abs_percentage_error, + {}, + ), + ( + SymmetricMeanAbsolutePercentageError, + symmetric_mean_absolute_percentage_error, + tm_smape, + {}, + ), + ( + WeightedMeanAbsolutePercentageError, + weighted_mean_absolute_percentage_error, + tm_wmape, + {}, + ), + ], +) +class TestMeanError(MetricTester): + """Test class for `MeanError` metric.""" + + atol = 2e-6 + + def test_mean_error_class( + self, + inputs, + metric_class, + metric_functional, + tm_fn, + metric_args, + ): + """Test class implementation of metric.""" + target, preds = inputs + device = "cpu" + if _is_torch_array(target) and torch.cuda.is_available(): + device = "cuda" + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=metric_class, + reference_metric=partial( + _tm_metric_wrapper, + tm_fn=tm_fn, + metric_args=metric_args, + ), + metric_args=metric_args, + device=device, + use_device_for_ref=_is_torch_array(target), + ) + + def test_mean_error_functional( + self, + inputs, + metric_class, + metric_functional, + tm_fn, + metric_args, + ): + """Test functional implementation of metric.""" + target, preds = inputs + device = "cpu" + if _is_torch_array(target) and torch.cuda.is_available(): + device = "cuda" + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=metric_functional, + reference_metric=partial( + _tm_metric_wrapper, + tm_fn=tm_fn, + metric_args=metric_args, + ), + metric_args=metric_args, + device=device, + use_device_for_ref=_is_torch_array(target), + ) + + +@pytest.mark.parametrize( + "metric_class", + [ + MeanSquaredError, + MeanAbsoluteError, + MeanAbsolutePercentageError, + WeightedMeanAbsolutePercentageError, + SymmetricMeanAbsolutePercentageError, + ], +) +def test_error_on_different_shape(metric_class): + """Test that error is raised on different shapes of input.""" + metric = metric_class() + with pytest.raises( + ValueError, + match="Expected `target` and `preds` to have the same 
shape, but got `target`.*", + ): + metric(torch.randn(100), torch.randn(50)) From 8fb3cf1141b2496298bcb075544f5e059f99e99c Mon Sep 17 00:00:00 2001 From: Franklin <41602287+fcogidi@users.noreply.github.com> Date: Mon, 29 Jan 2024 19:11:51 -0500 Subject: [PATCH 6/6] Integrate experimental metrics with other modules (#549) * integrate experimental metrics with other modules * add average precision metric to experimental metrics package * fix tutorials * Add type hints and keyword arguments to metrics classes * Update nbsphinx version to 0.9.3 * Update nbconvert version to 7.14.2 * Fix type annotations and formatting issues * Update kernel display name in mortality_prediction.ipynb * Add guard clause to prevent module execution on import * Update `torch_distributed.py` with type hints * Add multiclass and multilabel average precision metrics * Change jupyter kernel * Fix type annotations for metric values in ClassificationPlotter --------- Co-authored-by: Amrit K --- cyclops/evaluate/evaluator.py | 33 +- cyclops/evaluate/fairness/config.py | 5 +- cyclops/evaluate/fairness/evaluator.py | 89 ++- .../evaluate/metrics/experimental/__init__.py | 5 + .../evaluate/metrics/experimental/auroc.py | 25 +- .../metrics/experimental/average_precision.py | 272 +++++++ .../distributed_backends/torch_distributed.py | 10 +- .../evaluate/metrics/experimental/f_score.py | 12 +- .../experimental/functional/__init__.py | 5 + .../metrics/experimental/functional/auroc.py | 6 +- .../functional/average_precision.py | 677 ++++++++++++++++++ cyclops/evaluate/metrics/experimental/mae.py | 11 +- cyclops/evaluate/metrics/experimental/mape.py | 8 +- cyclops/evaluate/metrics/experimental/mse.py | 13 +- .../experimental/negative_predictive_value.py | 6 +- .../metrics/experimental/precision_recall.py | 30 +- .../experimental/precision_recall_curve.py | 17 +- cyclops/evaluate/metrics/experimental/roc.py | 18 +- .../evaluate/metrics/experimental/smape.py | 8 +- .../metrics/experimental/specificity.py | 12 +- .../evaluate/metrics/experimental/wmape.py | 7 +- cyclops/evaluate/metrics/factory.py | 27 +- cyclops/report/plot/classification.py | 34 +- cyclops/tasks/classification.py | 38 +- .../kaggle/heart_failure_prediction.ipynb | 66 +- .../mimiciv/mortality_prediction.ipynb | 55 +- .../tutorials/nihcxr/cxr_classification.ipynb | 90 +-- .../nihcxr/generate_nihcxr_report.py | 96 +-- .../tutorials/synthea/los_prediction.ipynb | 62 +- poetry.lock | 49 +- pyproject.toml | 3 +- .../experimental/test_average_precision.py | 503 +++++++++++++ .../experimental/test_precision_recall.py | 2 + 33 files changed, 1900 insertions(+), 394 deletions(-) create mode 100644 cyclops/evaluate/metrics/experimental/average_precision.py create mode 100644 cyclops/evaluate/metrics/experimental/functional/average_precision.py create mode 100644 tests/cyclops/evaluate/metrics/experimental/test_average_precision.py diff --git a/cyclops/evaluate/evaluator.py b/cyclops/evaluate/evaluator.py index e7af9a2c0..3763cf37e 100644 --- a/cyclops/evaluate/evaluator.py +++ b/cyclops/evaluate/evaluator.py @@ -1,5 +1,4 @@ """Evaluate one or more models on a dataset.""" - import logging import warnings from dataclasses import asdict @@ -16,7 +15,9 @@ ) from cyclops.evaluate.fairness.config import FairnessConfig from cyclops.evaluate.fairness.evaluator import evaluate_fairness -from cyclops.evaluate.metrics.metric import Metric, MetricCollection +from cyclops.evaluate.metrics.experimental.metric import Metric +from cyclops.evaluate.metrics.experimental.metric_dict import 
MetricDict +from cyclops.evaluate.metrics.experimental.utils.types import Array from cyclops.evaluate.utils import _format_column_names, choose_split from cyclops.utils.log import setup_logging @@ -27,7 +28,7 @@ def evaluate( dataset: Union[str, Dataset, DatasetDict], - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection], + metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict], target_columns: Union[str, List[str]], prediction_columns: Union[str, List[str]], ignore_columns: Optional[Union[str, List[str]]] = None, @@ -47,7 +48,7 @@ def evaluate( The dataset to evaluate on. If a string, the dataset will be loaded using `datasets.load_dataset`. If `DatasetDict`, the `split` argument must be specified. - metrics : Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection] + metrics : Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict] The metrics to compute. target_columns : Union[str, List[str]] The name of the column(s) containing the target values. A string value @@ -202,28 +203,28 @@ def _load_data( def _prepare_metrics( - metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection], -) -> MetricCollection: + metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict], +) -> MetricDict: """Prepare metrics for evaluation.""" - # TODO: wrap in BootstrappedMetric if computing confidence intervals + # TODO [fcogidi]: wrap in BootstrappedMetric if computing confidence intervals if isinstance(metrics, (Metric, Sequence, Dict)) and not isinstance( metrics, - MetricCollection, + MetricDict, ): - return MetricCollection(metrics) - if isinstance(metrics, MetricCollection): + return MetricDict(metrics) # type: ignore[arg-type] + if isinstance(metrics, MetricDict): return metrics raise TypeError( f"Invalid type for `metrics`: {type(metrics)}. 
" "Expected one of: Metric, Sequence[Metric], Dict[str, Metric], " - "MetricCollection.", + "MetricDict.", ) def _compute_metrics( dataset: Dataset, - metrics: MetricCollection, + metrics: MetricDict, slice_spec: SliceSpec, target_columns: Union[str, List[str]], prediction_columns: Union[str, List[str]], @@ -266,8 +267,8 @@ def _compute_metrics( RuntimeWarning, stacklevel=1, ) - metric_output = { - metric_name: float("NaN") for metric_name in metrics + metric_output: Dict[str, Array] = { + metric_name: float("NaN") for metric_name in metrics # type: ignore[attr-defined,misc] } elif ( batch_size is None or batch_size < 0 @@ -293,10 +294,10 @@ def _compute_metrics( ) # update the metric state - metrics.update_state(targets, predictions) + metrics.update(targets, predictions) metric_output = metrics.compute() - metrics.reset_state() + metrics.reset() model_name: str = "model_for_%s" % prediction_column results.setdefault(model_name, {}) diff --git a/cyclops/evaluate/fairness/config.py b/cyclops/evaluate/fairness/config.py index 3f220f4b4..f6e2aaebe 100644 --- a/cyclops/evaluate/fairness/config.py +++ b/cyclops/evaluate/fairness/config.py @@ -5,14 +5,15 @@ from datasets import Dataset, config -from cyclops.evaluate.metrics.metric import Metric, MetricCollection +from cyclops.evaluate.metrics.experimental.metric import Metric +from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict @dataclass class FairnessConfig: """Configuration for fairness metrics.""" - metrics: Union[str, Callable[..., Any], Metric, MetricCollection] + metrics: Union[str, Callable[..., Any], Metric, MetricDict] dataset: Dataset groups: Union[str, List[str]] target_columns: Union[str, List[str]] diff --git a/cyclops/evaluate/fairness/evaluator.py b/cyclops/evaluate/fairness/evaluator.py index 1296f0e56..cd44a07f1 100644 --- a/cyclops/evaluate/fairness/evaluator.py +++ b/cyclops/evaluate/fairness/evaluator.py @@ -1,5 +1,4 @@ """Fairness evaluator.""" - import inspect import itertools import logging @@ -7,8 +6,8 @@ from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union +import array_api_compat.numpy import numpy as np -import numpy.typing as npt import pandas as pd from datasets import Dataset, config from datasets.features import Features @@ -21,15 +20,14 @@ get_columns_as_numpy_array, set_decode, ) -from cyclops.evaluate.metrics.factory import create_metric -from cyclops.evaluate.metrics.functional.precision_recall_curve import ( +from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( _format_thresholds, + _validate_thresholds, ) -from cyclops.evaluate.metrics.metric import Metric, MetricCollection, OperatorMetric -from cyclops.evaluate.metrics.utils import ( - _check_thresholds, - _get_value_if_singleton_array, -) +from cyclops.evaluate.metrics.experimental.metric import Metric, OperatorMetric +from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict +from cyclops.evaluate.metrics.experimental.utils.types import Array +from cyclops.evaluate.metrics.factory import create_metric from cyclops.evaluate.utils import _format_column_names from cyclops.utils.log import setup_logging @@ -39,7 +37,7 @@ def evaluate_fairness( - metrics: Union[str, Callable[..., Any], Metric, MetricCollection], + metrics: Union[str, Callable[..., Any], Metric, MetricDict], dataset: Dataset, groups: Union[str, List[str]], target_columns: Union[str, List[str]], @@ -62,7 +60,7 @@ def evaluate_fairness( Parameters ---------- - metrics : Union[str, 
Callable[..., Any], Metric, MetricCollection] + metrics : Union[str, Callable[..., Any], Metric, MetricDict] The metric or metrics to compute. If a string, it should be the name of a metric provided by CyclOps. If a callable, it should be a function that takes target, prediction, and optionally threshold/thresholds as arguments @@ -147,18 +145,14 @@ def evaluate_fairness( raise TypeError( "Expected `dataset` to be of type `Dataset`, but got " f"{type(dataset)}.", ) + _validate_thresholds(thresholds) - _check_thresholds(thresholds) - fmt_thresholds: npt.NDArray[np.float_] = _format_thresholds( # type: ignore - thresholds, - ) - - metrics_: Union[Callable[..., Any], MetricCollection] = _format_metrics( + metrics_: Union[Callable[..., Any], MetricDict] = _format_metrics( metrics, metric_name, **(metric_kwargs or {}), ) - + fmt_thresholds = _format_thresholds(thresholds, xp=array_api_compat.numpy) fmt_groups: List[str] = _format_column_names(groups) fmt_target_columns: List[str] = _format_column_names(target_columns) fmt_prediction_columns: List[str] = _format_column_names(prediction_columns) @@ -361,15 +355,15 @@ def warn_too_many_unique_values( def _format_metrics( - metrics: Union[str, Callable[..., Any], Metric, MetricCollection], + metrics: Union[str, Callable[..., Any], Metric, MetricDict], metric_name: Optional[str] = None, **metric_kwargs: Any, -) -> Union[Callable[..., Any], Metric, MetricCollection]: +) -> Union[Callable[..., Any], Metric, MetricDict]: """Format the metrics argument. Parameters ---------- - metrics : Union[str, Callable[..., Any], Metric, MetricCollection] + metrics : Union[str, Callable[..., Any], Metric, MetricDict] The metrics to use for computing the metric results. metric_name : str, optional, default=None The name of the metric. This is only used if `metrics` is a callable. @@ -379,23 +373,23 @@ def _format_metrics( Returns ------- - Union[Callable[..., Any], Metric, MetricCollection] + Union[Callable[..., Any], Metric, MetricDict] The formatted metrics. Raises ------ TypeError - If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricCollection`. + If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricDict`. 
""" if isinstance(metrics, str): - metrics = create_metric(metric_name=metrics, **metric_kwargs) + metrics = create_metric(metric_name=metrics, experimental=True, **metric_kwargs) if isinstance(metrics, Metric): if metric_name is not None and isinstance(metrics, OperatorMetric): # single metric created from arithmetic operation, with given name - return MetricCollection({metric_name: metrics}) - return MetricCollection(metrics) - if isinstance(metrics, MetricCollection): + return MetricDict({metric_name: metrics}) + return MetricDict(metrics) + if isinstance(metrics, MetricDict): return metrics if callable(metrics): if metric_name is None: @@ -407,7 +401,7 @@ def _format_metrics( return metrics raise TypeError( - f"Expected `metrics` to be of type `str`, `Metric`, `MetricCollection`, or " + f"Expected `metrics` to be of type `str`, `Metric`, `MetricDict`, or " f"`Callable`, but got {type(metrics)}.", ) @@ -701,7 +695,7 @@ def _get_slice_spec( def _compute_metrics( # noqa: C901, PLR0912 - metrics: Union[Callable[..., Any], MetricCollection], + metrics: Union[Callable[..., Any], MetricDict], dataset: Dataset, target_columns: List[str], prediction_column: str, @@ -713,7 +707,7 @@ def _compute_metrics( # noqa: C901, PLR0912 Parameters ---------- - metrics : Union[Callable, MetricCollection] + metrics : Union[Callable, MetricDict] The metrics to compute. dataset : Dataset The dataset to compute the metrics on. @@ -738,12 +732,19 @@ def _compute_metrics( # noqa: C901, PLR0912 "Encountered empty dataset while computing metrics. " "The metric values will be set to `None`." ) - if isinstance(metrics, MetricCollection): + if isinstance(metrics, MetricDict): if threshold is not None: # set the threshold for each metric in the collection for name, metric in metrics.items(): - if hasattr(metric, "threshold"): + if isinstance(metric, Metric) and hasattr(metric, "threshold"): metric.threshold = threshold + elif isinstance(metric, OperatorMetric): + if hasattr(metric.metric_a, "threshold") and hasattr( + metric.metric_b, + "threshold", + ): + metric.metric_a.threshold = threshold + metric.metric_b.threshold = threshold # type: ignore[union-attr] else: LOGGER.warning( "Metric %s does not have a threshold attribute. " @@ -754,7 +755,7 @@ def _compute_metrics( # noqa: C901, PLR0912 if len(dataset) == 0: warnings.warn(empty_dataset_msg, RuntimeWarning, stacklevel=1) results: Dict[str, Any] = { - metric_name: float("NaN") for metric_name in metrics + metric_name: float("NaN") for metric_name in metrics # type: ignore[attr-defined] } elif ( batch_size is None or batch_size <= 0 @@ -779,11 +780,11 @@ def _compute_metrics( # noqa: C901, PLR0912 columns=prediction_column, ) - metrics.update_state(targets, predictions) + metrics.update(targets, predictions) results = metrics.compute() - metrics.reset_state() + metrics.reset() return results if callable(metrics): @@ -817,26 +818,26 @@ def _compute_metrics( # noqa: C901, PLR0912 return {metric_name.title(): output} raise TypeError( - "The `metrics` argument must be a string, a Metric, a MetricCollection, " + "The `metrics` argument must be a string, a Metric, a MetricDict, " f"or a callable. 
Got {type(metrics)}.", ) def _get_metric_results_for_prediction_and_slice( - metrics: Union[Callable[..., Any], MetricCollection], + metrics: Union[Callable[..., Any], MetricDict], dataset: Dataset, target_columns: List[str], prediction_column: str, slice_name: str, batch_size: Optional[int] = config.DEFAULT_MAX_BATCH_SIZE, metric_name: Optional[str] = None, - thresholds: Optional[npt.NDArray[np.float_]] = None, + thresholds: Optional[Array] = None, ) -> Dict[str, Dict[str, Any]]: """Compute metrics for a slice of a dataset. Parameters ---------- - metrics : Union[Callable, MetricCollection] + metrics : Union[Callable, MetricDict] The metrics to compute. dataset : Dataset The dataset to compute the metrics on. @@ -850,7 +851,7 @@ def _get_metric_results_for_prediction_and_slice( The batch size to use for the computation. metric_name : Optional[str] The name of the metric to compute. - thresholds : Optional[List[float]] + thresholds : Optional[Array] The thresholds to use for the metrics. Returns @@ -873,7 +874,7 @@ def _get_metric_results_for_prediction_and_slice( return {slice_name: metric_output} results: Dict[str, Dict[str, Any]] = {} - for threshold in thresholds: + for threshold in thresholds: # type: ignore[attr-defined] metric_output = _compute_metrics( metrics=metrics, dataset=dataset, @@ -969,11 +970,7 @@ def _compute_parity_metrics( ) parity_results[key].setdefault(slice_name, {}).update( - { - parity_metric_name: _get_value_if_singleton_array( - parity_metric_value, - ), - }, + {parity_metric_name: parity_metric_value}, ) return parity_results diff --git a/cyclops/evaluate/metrics/experimental/__init__.py b/cyclops/evaluate/metrics/experimental/__init__.py index ec6c72609..3a5b9974a 100644 --- a/cyclops/evaluate/metrics/experimental/__init__.py +++ b/cyclops/evaluate/metrics/experimental/__init__.py @@ -9,6 +9,11 @@ MulticlassAUROC, MultilabelAUROC, ) +from cyclops.evaluate.metrics.experimental.average_precision import ( + BinaryAveragePrecision, + MulticlassAveragePrecision, + MultilabelAveragePrecision, +) from cyclops.evaluate.metrics.experimental.confusion_matrix import ( BinaryConfusionMatrix, MulticlassConfusionMatrix, diff --git a/cyclops/evaluate/metrics/experimental/auroc.py b/cyclops/evaluate/metrics/experimental/auroc.py index 17c6af31f..bd139cb77 100644 --- a/cyclops/evaluate/metrics/experimental/auroc.py +++ b/cyclops/evaluate/metrics/experimental/auroc.py @@ -1,5 +1,5 @@ """Classes for computing the area under the ROC curve.""" -from typing import List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Union from cyclops.evaluate.metrics.experimental.functional.auroc import ( _binary_auroc_compute, @@ -18,7 +18,7 @@ from cyclops.evaluate.metrics.experimental.utils.types import Array -class BinaryAUROC(BinaryPrecisionRecallCurve): +class BinaryAUROC(BinaryPrecisionRecallCurve, registry_key="binary_auroc"): """Area under the Receiver Operating Characteristic (ROC) curve. Parameters @@ -37,6 +37,8 @@ class BinaryAUROC(BinaryPrecisionRecallCurve): ignore_index : int, optional, default=None The value in `target` that should be ignored when computing the AUROC. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments that are common to all metrics. 
Examples -------- @@ -59,9 +61,10 @@ def __init__( max_fpr: Optional[float] = None, thresholds: Optional[Union[int, List[float], Array]] = None, ignore_index: Optional[int] = None, + **kwargs: Any, ) -> None: """Initialize the BinaryAUROC metric.""" - super().__init__(thresholds=thresholds, ignore_index=ignore_index) + super().__init__(thresholds=thresholds, ignore_index=ignore_index, **kwargs) _binary_auroc_validate_args( max_fpr=max_fpr, thresholds=thresholds, @@ -70,7 +73,7 @@ def __init__( self.max_fpr = max_fpr def _compute_metric(self) -> Array: # type: ignore[override] - """Compute the AUROC.""" "" + """Compute the AUROC.""" state = ( (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] if self.thresholds is None @@ -79,7 +82,7 @@ def _compute_metric(self) -> Array: # type: ignore[override] return _binary_auroc_compute(state, thresholds=self.thresholds, max_fpr=self.max_fpr) # type: ignore -class MulticlassAUROC(MulticlassPrecisionRecallCurve): +class MulticlassAUROC(MulticlassPrecisionRecallCurve, registry_key="multiclass_auroc"): """Area under the Receiver Operating Characteristic (ROC) curve. Parameters @@ -105,6 +108,8 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve): ignore_index : int or Tuple[int], optional, default=None The value(s) in `target` that should be ignored when computing the AUROC. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments that are common to all metrics. Examples -------- @@ -140,12 +145,14 @@ def __init__( thresholds: Optional[Union[int, List[float], Array]] = None, average: Optional[Literal["macro", "weighted", "none"]] = "macro", ignore_index: Optional[Union[int, Tuple[int]]] = None, + **kwargs: Any, ) -> None: """Initialize the MulticlassAUROC metric.""" super().__init__( num_classes, thresholds=thresholds, ignore_index=ignore_index, + **kwargs, ) _multiclass_auroc_validate_args( num_classes=num_classes, @@ -170,9 +177,11 @@ def _compute_metric(self) -> Array: # type: ignore[override] ) -class MultilabelAUROC(MultilabelPrecisionRecallCurve): +class MultilabelAUROC(MultilabelPrecisionRecallCurve, registry_key="multilabel_auroc"): """Area under the Receiver Operating Characteristic (ROC) curve. + Parameters + ---------- num_labels : int The number of labels in the multilabel classification problem. thresholds : Union[int, List[float], Array], optional, default=None @@ -195,6 +204,8 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve): ignore_index : int, optional, default=None The value in `target` that should be ignored when computing the AUROC. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments that are common to all metrics. 
Examples -------- @@ -227,12 +238,14 @@ def __init__( thresholds: Optional[Union[int, List[float], Array]] = None, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, + **kwargs: Any, ) -> None: """Initialize the MultilabelAUROC metric.""" super().__init__( num_labels, thresholds=thresholds, ignore_index=ignore_index, + **kwargs, ) _multilabel_auroc_validate_args( num_labels=num_labels, diff --git a/cyclops/evaluate/metrics/experimental/average_precision.py b/cyclops/evaluate/metrics/experimental/average_precision.py new file mode 100644 index 000000000..f8f8692ac --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/average_precision.py @@ -0,0 +1,272 @@ +"""Classes for computing area under the Average Precision (AUPRC).""" + +from typing import Any, List, Literal, Optional, Tuple, Union + +from cyclops.evaluate.metrics.experimental.functional.average_precision import ( + _binary_average_precision_compute, + _multiclass_average_precision_compute, + _multiclass_average_precision_validate_args, + _multilabel_average_precision_compute, + _multilabel_average_precision_validate_args, +) +from cyclops.evaluate.metrics.experimental.precision_recall_curve import ( + BinaryPrecisionRecallCurve, + MulticlassPrecisionRecallCurve, + MultilabelPrecisionRecallCurve, +) +from cyclops.evaluate.metrics.experimental.utils.ops import dim_zero_cat +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +class BinaryAveragePrecision( + BinaryPrecisionRecallCurve, + registry_key="binary_average_precision", +): + """A summary of the precision-recall curve via a weighted mean of the points. + + The average precision score summarizes a precision-recall curve as the weighted + mean of precisions achieved at each threshold, with the increase in recall from + the previous threshold used as the weight. + + Parameters + ---------- + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the precision and recall. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the average + precision. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import BinaryAveragePrecision + >>> target = anp.asarray([0, 1, 0, 1]) + >>> preds = anp.asarray([0.1, 0.4, 0.35, 0.8]) + >>> metric = BinaryAveragePrecision(thresholds=3) + >>> metric(target, preds) + Array(0.75, dtype=float32) + >>> metric.reset() + >>> target = [[0, 1, 0, 1], [1, 1, 0, 0]] + >>> preds = [[0.1, 0.4, 0.35, 0.8], [0.6, 0.3, 0.1, 0.7]] + >>> for t, p in zip(target, preds): + ... 
metric.update(anp.asarray(t), anp.asarray(p)) + >>> metric.compute() + Array(0.5833334, dtype=float32) + + """ + + name: str = "Average Precision Score" + + def _compute_metric(self) -> Array: # type: ignore[override] + """Compute the metric.""" + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + + return _binary_average_precision_compute( + state, + self.thresholds, # type: ignore + pos_label=1, + ) + + +class MulticlassAveragePrecision( + MulticlassPrecisionRecallCurve, + registry_key="multiclass_average_precision", +): + """A summary of the precision-recall curve via a weighted mean of the points. + + The average precision score summarizes a precision-recall curve as the weighted + mean of precisions achieved at each threshold, with the increase in recall from + the previous threshold used as the weight. + + Parameters + ---------- + num_classes : int + The number of classes in the classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the average precision score. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the average precision score. Can + be one of the following: + - `"macro"`: compute the average precision score for each class and average + over the classes. + - `"weighted"`: computes the average of the precision for each class and + average over the classwise scores using the support of each class as + weights. + - `"none"`: do not average over the classwise scores. + ignore_index : int or Tuple[int], optional, default=None + The value(s) in `target` that should be ignored when computing the average + precision score. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import MulticlassAveragePrecision + >>> target = anp.asarray([0, 1, 2, 0, 1, 2]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44], + ... [0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44]]) + >>> metric = MulticlassAveragePrecision( + ... num_classes=3, thresholds=None, average=None, + ... 
) + >>> metric(target, preds) + Array([0.33333334, 0.5 , 0.5 ], dtype=float32) + + """ + + name: str = "Average Precision Score" + + def __init__( + self, + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + ignore_index: Optional[Union[int, Tuple[int]]] = None, + **kwargs: Any, + ) -> None: + """Initialize a `MulticlassAveragePrecision` instance.""" + super().__init__(num_classes, thresholds, ignore_index=ignore_index, **kwargs) + _multiclass_average_precision_validate_args( + num_classes, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + self.average = average # type: ignore[assignment] + + def _compute_metric(self) -> Array: # type: ignore[override] + """Compute the metric.""" + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + + return _multiclass_average_precision_compute( + state, + self.num_classes, + thresholds=self.thresholds, # type: ignore[arg-type] + average=self.average, # type: ignore[arg-type] + ) + + +class MultilabelAveragePrecision( + MultilabelPrecisionRecallCurve, + registry_key="multilabel_average_precision", +): + """A summary of the precision-recall curve via a weighted mean of the points. + + The average precision score summarizes a precision-recall curve as the weighted + mean of precisions achieved at each threshold, with the increase in recall from + the previous threshold used as the weight. + + Parameters + ---------- + num_labels : int + The number of labels in the multilabel classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the average precision score. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"micro", "macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the average precision score. Can + be one of the following: + - `"micro"`: computes the average precision score globally by summing over + the average precision scores for each label. + - `"macro"`: compute the average precision score for each label and average + over the labels. + - `"weighted"`: computes the average of the precision for each label and + average over the labelwise scores using the support of each label as + weights. + - `"none"`: do not average over the labelwise scores. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the average + precision score. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental import MultilabelAveragePrecision + >>> target = anp.asarray([[0, 1, 0], [1, 1, 0], [0, 0, 1]]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], [0.84, 0.73, 0.12], [0.33, 0.92, 0.44]], + ... ) + >>> metric = MultilabelAveragePrecision( + ... num_labels=3, thresholds=None, average=None, + ... ) + >>> metric(target, preds) + Array([1. 
, 0.5833334, 0.5 ], dtype=float32) + """ + + name: str = "Average Precision Score" + + def __init__( + self, + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, + **kwargs: Any, + ) -> None: + """Initialize a `MultilabelAveragePrecision` instance.""" + super().__init__( + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + **kwargs, + ) + _multilabel_average_precision_validate_args( + num_labels, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + self.average = average + + def _compute_metric(self) -> Array: # type: ignore[override] + """Compute the metric.""" + state = ( + (dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined] + if self.thresholds is None + else self.confmat # type: ignore[attr-defined] + ) + + return _multilabel_average_precision_compute( + state, + self.num_labels, + thresholds=self.thresholds, # type: ignore[arg-type] + average=self.average, + ignore_index=self.ignore_index, + ) diff --git a/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py b/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py index 851b74f8a..25b0ea1cd 100644 --- a/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py +++ b/cyclops/evaluate/metrics/experimental/distributed_backends/torch_distributed.py @@ -10,8 +10,10 @@ if TYPE_CHECKING: import torch import torch.distributed as torch_dist + from torch import Tensor else: torch = import_optional_module("torch", error="warn") + Tensor = import_optional_module("torch", attribute="Tensor", error="warn") torch_dist = import_optional_module("torch.distributed", error="warn") @@ -47,13 +49,13 @@ def world_size(self) -> int: """Return the world size of the current process group.""" return torch_dist.get_world_size() - def _simple_all_gather(self, data: torch.Tensor) -> List[torch.Tensor]: + def _simple_all_gather(self, data: Tensor) -> List[Tensor]: """Gather tensors of the same shape from all processes.""" gathered_data = [torch.zeros_like(data) for _ in range(self.world_size)] torch_dist.all_gather(gathered_data, data) # type: ignore[no-untyped-call] return gathered_data - def all_gather(self, data: torch.Tensor) -> List[torch.Tensor]: # type: ignore[override] + def all_gather(self, data: Tensor) -> List[Tensor]: # type: ignore[override] """Gather Arrays from current proccess and return as a list. Parameters @@ -95,3 +97,7 @@ def all_gather(self, data: torch.Tensor) -> List[torch.Tensor]: # type: ignore[ slice_param = [slice(dim_size) for dim_size in item_size] gathered_data[idx] = gathered_data[idx][slice_param] return gathered_data + + +if __name__ == "__main__": # prevent execution of module on import + pass diff --git a/cyclops/evaluate/metrics/experimental/f_score.py b/cyclops/evaluate/metrics/experimental/f_score.py index 1092e499c..7e9bc7a20 100644 --- a/cyclops/evaluate/metrics/experimental/f_score.py +++ b/cyclops/evaluate/metrics/experimental/f_score.py @@ -28,7 +28,7 @@ class BinaryFBetaScore(_AbstractBinaryStatScores, registry_key="binary_fbeta_sco Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. 
Examples @@ -106,6 +106,8 @@ class MulticlassFBetaScore( Specifies a target class that is ignored when computing the F-beta score. Ignoring a target class means that the corresponding predictions do not contribute to the F-beta score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -203,6 +205,8 @@ class MultilabelFBetaScore( ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the F-beta score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -277,7 +281,7 @@ class BinaryF1Score(BinaryFBetaScore, registry_key="binary_f1_score"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -341,6 +345,8 @@ class MulticlassF1Score(MulticlassFBetaScore, registry_key="multiclass_f1_score" Specifies a target class that is ignored when computing the F1 score. Ignoring a target class means that the corresponding predictions do not contribute to the F1 score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -413,6 +419,8 @@ class MultilabelF1Score( ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the F1 score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- diff --git a/cyclops/evaluate/metrics/experimental/functional/__init__.py b/cyclops/evaluate/metrics/experimental/functional/__init__.py index e24543e64..1a2e5902b 100644 --- a/cyclops/evaluate/metrics/experimental/functional/__init__.py +++ b/cyclops/evaluate/metrics/experimental/functional/__init__.py @@ -9,6 +9,11 @@ multiclass_auroc, multilabel_auroc, ) +from cyclops.evaluate.metrics.experimental.functional.average_precision import ( + binary_average_precision, + multiclass_average_precision, + multilabel_average_precision, +) from cyclops.evaluate.metrics.experimental.functional.confusion_matrix import ( binary_confusion_matrix, multiclass_confusion_matrix, diff --git a/cyclops/evaluate/metrics/experimental/functional/auroc.py b/cyclops/evaluate/metrics/experimental/functional/auroc.py index c6e7c83c5..7abe73990 100644 --- a/cyclops/evaluate/metrics/experimental/functional/auroc.py +++ b/cyclops/evaluate/metrics/experimental/functional/auroc.py @@ -1,5 +1,6 @@ """Functions for computing the area under the ROC curve (AUROC).""" import warnings +from types import ModuleType from typing import List, Literal, Optional, Tuple, Union import array_api_compat as apc @@ -194,6 +195,8 @@ def _reduce_auroc( tpr: Union[Array, List[Array]], average: Optional[Literal["macro", "weighted", "none"]] = None, weights: Optional[Array] = None, + *, + xp: ModuleType, ) -> Array: """Compute the area under the ROC curve and apply `average` method. @@ -225,7 +228,6 @@ def _reduce_auroc( If the AUROC for one or more classes is `nan` and ``average`` is not ``none``. 
""" - xp = apc.array_namespace((fpr[0], tpr[0]) if isinstance(fpr, list) else (fpr, tpr)) if apc.is_array_api_obj(fpr) and apc.is_array_api_obj(tpr): res = _auc_compute(fpr, tpr, 1.0, axis=1) # type: ignore else: @@ -288,6 +290,7 @@ def _multiclass_auroc_compute( weights=xp.astype(bincount(state[0], minlength=num_classes), xp.float32) if thresholds is None else xp.sum(state[0, ...][:, 1, :], axis=-1), # type: ignore[call-overload] + xp=xp, ) @@ -492,6 +495,7 @@ def _multilabel_auroc_compute( ) if thresholds is None else xp.sum(state[0, ...][:, 1, :], axis=-1), # type: ignore[call-overload] + xp=xp, ) diff --git a/cyclops/evaluate/metrics/experimental/functional/average_precision.py b/cyclops/evaluate/metrics/experimental/functional/average_precision.py new file mode 100644 index 000000000..257fd119c --- /dev/null +++ b/cyclops/evaluate/metrics/experimental/functional/average_precision.py @@ -0,0 +1,677 @@ +"""Functions for computing average precision (AUPRC) for classification tasks.""" +import warnings +from types import ModuleType +from typing import List, Literal, Optional, Tuple, Union + +import array_api_compat as apc + +from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import ( + _binary_precision_recall_curve_compute, + _binary_precision_recall_curve_format_arrays, + _binary_precision_recall_curve_update, + _binary_precision_recall_curve_validate_args, + _binary_precision_recall_curve_validate_arrays, + _multiclass_precision_recall_curve_compute, + _multiclass_precision_recall_curve_format_arrays, + _multiclass_precision_recall_curve_update, + _multiclass_precision_recall_curve_validate_args, + _multiclass_precision_recall_curve_validate_arrays, + _multilabel_precision_recall_curve_compute, + _multilabel_precision_recall_curve_format_arrays, + _multilabel_precision_recall_curve_update, + _multilabel_precision_recall_curve_validate_args, + _multilabel_precision_recall_curve_validate_arrays, +) +from cyclops.evaluate.metrics.experimental.utils.ops import ( + _diff, + bincount, + flatten, + remove_ignore_index, + safe_divide, +) +from cyclops.evaluate.metrics.experimental.utils.types import Array + + +def _binary_average_precision_compute( + state: Union[Tuple[Array, Array], Array], + thresholds: Optional[Array], + pos_label: int = 1, +) -> Array: + """Compute average precision for binary classification task. + + Parameters + ---------- + state : Array or Tuple[Array, Array] + State from which the precision-recall curve can be computed. Can be + either a tuple of (target, preds) or a multi-threshold confusion matrix. + thresholds : Array, optional + Thresholds used for computing the precision and recall scores. If not None, + must be a 1D numpy array of floats in the [0, 1] range and monotonically + increasing. + pos_label : int, optional, default=1 + The label of the positive class. + + Returns + ------- + Array + The average precision score. + + Raises + ------ + ValueError + If ``thresholds`` is None. + + """ + precision, recall, _ = _binary_precision_recall_curve_compute( + state, + thresholds, + pos_label, + ) + xp = apc.array_namespace(precision, recall) + return -xp.sum(_diff(recall) * precision[:-1], dtype=xp.float32) # type: ignore + + +def binary_average_precision( + target: Array, + preds: Array, + thresholds: Optional[Union[int, List[float], Array]] = None, + ignore_index: Optional[int] = None, +) -> Array: + """Compute average precision score for binary classification task. 
+ + The average precision score summarizes a precision-recall curve as the weighted + mean of precisions achieved at each threshold, with the increase in recall from + the previous threshold used as the weight. + + Parameters + ---------- + target : Array + An array object that is compatible with the Python array API standard + and contains the ground truth labels in the range [0, 1]. The expected + shape of the array is `(N, ...)`, where `N` is the number of samples. + preds : Array + An array object that is compatible with the Python array API standard and + contains the probability/logit scores for the positive class. The expected + shape of the array is `(N, ...)` where `N` is the number of samples. If + `preds` contains floating point values that are not in the range `[0, 1]`, + a sigmoid function will be applied to each value before thresholding. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the average precision. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the average + precision. If `None`, all values in `target` are used. + + Returns + ------- + Array + The average precision score. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `ignore_index` is not `None` or an integer. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `target` and `preds` do not have the same shape. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + RuntimeError + If `target` contains values outside the range [0, 1] or does not contain + `ignore_index` if `ignore_index` is not `None`. + ValueError + If the array API namespace of `target` and `preds` are not the same as the + array API namespace of `thresholds`. + + Examples + -------- + >>> import numpy.array_api as anp + >>> from cyclops.evaluate.metrics.experimental.functional import ( + ... binary_average_precision + ... 
) + >>> target = anp.asarray([0, 1, 1, 0]) + >>> preds = anp.asarray([0, 0.5, 0.7, 0.8]) + >>> binary_average_precision(target, preds, thresholds=None) + Array(0.5833334, dtype=float32) + + """ + _binary_precision_recall_curve_validate_args(thresholds, ignore_index) + xp = _binary_precision_recall_curve_validate_arrays( + target, + preds, + thresholds, + ignore_index, + ) + target, preds, thresholds = _binary_precision_recall_curve_format_arrays( + target, + preds, + thresholds=thresholds, + ignore_index=ignore_index, + xp=xp, + ) + state = _binary_precision_recall_curve_update( + target, + preds, + thresholds=thresholds, + xp=xp, + ) + return _binary_average_precision_compute(state, thresholds, pos_label=1) + + +def _reduce_average_precision( + precision: Union[Array, List[Array]], + recall: Union[Array, List[Array]], + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + weights: Optional[Array] = None, + *, + xp: ModuleType, +) -> Array: + """Reduce the precision-recall curve to a single average precision score. + + Applies the specified `average` after computing the average precision score + for each class/label. + + Parameters + ---------- + precision : Array or List[Array] + The precision values for each class/label, computed at different thresholds. + recall : Array or List[Array] + The recall values for each class/label, computed at different thresholds. + average : {"macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the average precision score. Can + be one of the following: + - `"macro"`: computes the average precision score for each class/label and + average over the scores. + - `"weighted"`: computes the average of the precision score for each + class/label and average over the classwise/labelwise scores using + `weights` as weights. + - `"none"`: do not average over the classwise/labelwise scores. + weights : Array, optional, default=None + The weights to use for computing the weighted average precision score. + xp : ModuleType + The array API module to use for computations. + + Returns + ------- + Array + The average precision score. + + Raises + ------ + ValueError + If `average` is not `"macro"`, `"weighted"` or `"none"` or `None` or + average is `"weighted"` and `weights` is `None`. + """ + if apc.is_array_api_obj(precision) and apc.is_array_api_obj(recall): + avg_prec = -xp.sum( + (recall[:, 1:] - recall[:, :-1]) * precision[:, :-1], # type: ignore + axis=1, + dtype=xp.float32, + ) + else: + avg_prec = xp.stack( + [ + -xp.sum((rec[1:] - rec[:-1]) * prec[:-1], dtype=xp.float32) + for prec, rec in zip(precision, recall) # type: ignore[arg-type] + ], + ) + if average is None or average == "none": + return avg_prec # type: ignore[no-any-return] + if xp.any(xp.isnan(avg_prec)): + warnings.warn( + f"Average precision score for one or more classes was `nan`. 
Ignoring these classes in {average}-average", + UserWarning, + stacklevel=1, + ) + idx = ~xp.isnan(avg_prec) + if average == "macro": + return xp.mean(avg_prec[idx]) # type: ignore[no-any-return] + if average == "weighted" and weights is not None: + weights = safe_divide(weights[idx], xp.sum(weights[idx])) + return xp.sum(avg_prec[idx] * weights, dtype=xp.float32) # type: ignore[no-any-return] + raise ValueError( + "Received an incompatible combinations of inputs to make reduction.", + ) + + +def _multiclass_average_precision_validate_args( + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> None: + """Validate the arguments for the `multiclass_average_precision` function.""" + _multiclass_precision_recall_curve_validate_args( + num_classes, + thresholds=thresholds, + ignore_index=ignore_index, + ) + allowed_averages = ["macro", "weighted", "none"] + if average is not None and average not in allowed_averages: + raise ValueError( + f"Expected `average` to be one of {allowed_averages}, got {average}.", + ) + + +def _multiclass_average_precision_compute( + state: Union[Tuple[Array, Array], Array], + num_classes: int, + thresholds: Optional[Array], + average: Optional[Literal["macro", "weighted", "none"]] = "macro", +) -> Array: + """Compute the average precision score for multiclass classification task.""" + precision, recall, _ = _multiclass_precision_recall_curve_compute( + state, + num_classes, + thresholds=thresholds, + average=None, + ) + xp = apc.array_namespace(state) + return _reduce_average_precision( + precision, + recall, + average=average, + weights=xp.astype(bincount(state[0], minlength=num_classes), xp.float32) + if thresholds is None + else xp.sum(state[0, ...][:, 1, :], axis=-1, dtype=xp.float32), # type: ignore + xp=xp, + ) + + +def multiclass_average_precision( + target: Array, + preds: Array, + num_classes: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + ignore_index: Optional[Union[int, Tuple[int]]] = None, +) -> Array: + """Compute the average precision score for multiclass classification task. + + The average precision score summarizes a precision-recall curve as the weighted + mean of precisions achieved at each threshold, with the increase in recall from + the previous threshold used as the weight. + + Parameters + ---------- + target : Array + An array object that is compatible with the Python array API standard + and contains the ground truth labels in the range [0, `num_classes`] + (except if `ignore_index` is specified). The expected shape of the array + is `(N, ...)`, where `N` is the number of samples. + preds : Array + An array object that is compatible with the Python array API standard and + contains the probability/logit scores for each sample. The expected shape + of the array is `(N, C, ...)` where `N` is the number of samples and `C` + is the number of classes. If `preds` contains floating point values that + are not in the range `[0, 1]`, a softmax function will be applied to each + value before thresholding. + num_classes : int + The number of classes in the classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the average precision score. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. 
+ - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the average precision score. Can + be one of the following: + - `"macro"`: compute the average precision score for each class and average + over the classes. + - `"weighted"`: computes the average of the precision for each class and + average over the classwise scores using the support of each class as + weights. + - `"none"`: do not average over the classwise scores. + ignore_index : int or Tuple[int], optional, default=None + The value(s) in `target` that should be ignored when computing the average + precision score. If `None`, all values in `target` are used. + + Returns + ------- + Array + The average precision score. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `num_classes` is not an integer larger than 1. + ValueError + If `ignore_index` is not `None`, an integer or a tuple of integers. + ValueError + If `average` is not `"macro"`, `"weighted"`, `"none"` or `None`. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `preds` does not have one more dimension than `target`. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + ValueError + If the second dimension of `preds` is not equal to `num_classes`. + ValueError + If the first dimension of `preds` is not equal to the first dimension of + `target` or the third dimension of `preds` is not equal to the second + dimension of `target`. + RuntimeError + If `target` contains more unique values than `num_classes` or `num_classes` + plus the number of values in `ignore_index` if `ignore_index` is not `None`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import ( + ... multiclass_average_precision, + ... ) + >>> import numpy.array_api as anp + >>> target = anp.asarray([0, 1, 2, 0, 1, 2]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44], + ... [0.11, 0.22, 0.67], + ... [0.84, 0.73, 0.12], + ... [0.33, 0.92, 0.44]]) + >>> multiclass_average_precision( + ... target, preds, num_classes=3, thresholds=None, average=None, + ... ) + Array([0.33333334, 0.5 , 0.5 ], dtype=float32) + >>> multiclass_average_precision( + ... target, preds, num_classes=3, thresholds=None, average="macro", + ... ) + Array(0.44444445, dtype=float32) + >>> multiclass_average_precision( + ... target, preds, num_classes=3, thresholds=None, average="weighted", + ... 
) + Array(0.44444448, dtype=float32) + """ + _multiclass_average_precision_validate_args( + num_classes, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + xp = _multiclass_precision_recall_curve_validate_arrays( + target, + preds, + num_classes, + ignore_index=ignore_index, + ) + target, preds, thresholds = _multiclass_precision_recall_curve_format_arrays( + target, + preds, + num_classes, + thresholds=thresholds, + ignore_index=ignore_index, + xp=xp, + ) + state = _multiclass_precision_recall_curve_update( + target, + preds, + num_classes, + thresholds=thresholds, + xp=xp, + ) + return _multiclass_average_precision_compute( + state, + num_classes, + thresholds=thresholds, + average=average, + ) + + +def _multilabel_average_precision_validate_args( + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, +) -> None: + """Validate the arguments for the `multilabel_average_precision` function.""" + _multilabel_precision_recall_curve_validate_args( + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + allowed_averages = ["micro", "macro", "weighted", "none"] + if average is not None and average not in allowed_averages: + raise ValueError( + f"Expected `average` to be one of {allowed_averages}, got {average}.", + ) + + +def _multilabel_average_precision_compute( + state: Union[Tuple[Array, Array], Array], + num_labels: int, + thresholds: Optional[Array], + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, +) -> Array: + """Compute the average precision score for multilabel classification task.""" + xp = apc.array_namespace(state) + if average == "micro": + if apc.is_array_api_obj(state) and thresholds is not None: + state = xp.sum(state, axis=1) + else: + target, preds = flatten(state[0]), flatten(state[1]) + target, preds = remove_ignore_index(target, preds, ignore_index) + state = (target, preds) + return _binary_average_precision_compute(state, thresholds) + + precision, recall, _ = _multilabel_precision_recall_curve_compute( + state, + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + return _reduce_average_precision( + precision, + recall, + average=average, + weights=xp.sum(xp.astype(state[0] == 1, xp.int32), axis=0, dtype=xp.float32) + if thresholds is None + else xp.sum(state[0, ...][:, 1, :], axis=-1), # type: ignore + xp=xp, + ) + + +def multilabel_average_precision( + target: Array, + preds: Array, + num_labels: int, + thresholds: Optional[Union[int, List[float], Array]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, +) -> Array: + """Compute the average precision score for multilabel classification task. + + The average precision score summarizes a precision-recall curve as the weighted + mean of precisions achieved at each threshold, with the increase in recall from + the previous threshold used as the weight. + + Parameters + ---------- + target : Array + The target array of shape `(N, L, ...)` containing the ground truth labels + in the range [0, 1], where `N` is the number of samples and `L` is the + number of labels. + preds : Array + The prediction array of shape `(N, L, ...)` containing the probability/logit + scores for each sample, where `N` is the number of samples and `L` is the + number of labels. 
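[Editor's illustration, not part of this patch.] The helpers above reduce a precision-recall curve with the rectangle-rule sum AP = Σ_n (R_n − R_{n−1})·P_n: `_binary_average_precision_compute` evaluates it as `-xp.sum(_diff(recall) * precision[:-1], ...)` (the leading minus compensates for the direction in which the curve is traversed), `_reduce_average_precision` does the same per class/label, and with `average="micro"` the multilabel case is first flattened into one binary problem. A small sketch in plain NumPy, with curve points the editor derived for the binary doctest earlier in this file:

import numpy as np

# Precision/recall points for target=[0, 1, 1, 0], preds=[0.0, 0.5, 0.7, 0.8],
# taken at decreasing score thresholds so that recall is non-decreasing.
precision = np.array([0.0, 0.5, 2 / 3, 0.5])
recall = np.array([0.0, 0.5, 1.0, 1.0])

# Rectangle-rule sum: AP = sum_n (R_n - R_{n-1}) * P_n
ap = np.sum(np.diff(recall, prepend=0.0) * precision)
print(round(float(ap), 7))  # 0.5833333 -- agrees with the binary doctest above up to float32 rounding

# With average="micro", the (N, L) multilabel target/preds are flattened first
# and this same binary computation is applied to the single flattened problem.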
If `preds` contains floating point values that are not + in the range [0,1], they will be converted to probabilities using the + sigmoid function. + num_labels : int + The number of labels in the multilabel classification problem. + thresholds : Union[int, List[float], Array], optional, default=None + The thresholds to use for computing the average precision score. Can be one + of the following: + - `None`: use all unique values in `preds` as thresholds. + - `int`: use `int` (larger than 1) uniformly spaced thresholds in the range + [0, 1]. + - `List[float]`: use the values in the list as bins for the thresholds. + - `Array`: use the values in the Array as bins for the thresholds. The + array must be 1D. + average : {"micro", "macro", "weighted", "none"}, optional, default="macro" + The type of averaging to use for computing the average precision score. Can + be one of the following: + - `"micro"`: computes the average precision score globally by summing over + the average precision scores for each label. + - `"macro"`: compute the average precision score for each label and average + over the labels. + - `"weighted"`: computes the average of the precision for each label and + average over the labelwise scores using the support of each label as + weights. + - `"none"`: do not average over the labelwise scores. + ignore_index : int, optional, default=None + The value in `target` that should be ignored when computing the average + precision score. If `None`, all values in `target` are used. + + Returns + ------- + Array + The average precision score. + + Raises + ------ + TypeError + If `thresholds` is not `None` and not an integer, a list of floats or an + Array of floats. + ValueError + If `thresholds` is an integer and smaller than 2. + ValueError + If `thresholds` is a list of floats with values outside the range [0, 1] + and not monotonically increasing. + ValueError + If `thresholds` is an Array of floats and not all values are in the range + [0, 1] or the array is not one-dimensional. + ValueError + If `ignore_index` is not `None` or an integer. + ValueError + If `num_labels` is not an integer larger than 1. + ValueError + If `average` is not `"micro"`, `"macro"`, `"weighted"`, `"none"` or `None`. + TypeError + If `target` and `preds` are not compatible with the Python array API standard. + ValueError + If `target` and `preds` are empty. + ValueError + If `target` and `preds` are not numeric arrays. + ValueError + If `target` and `preds` do not have the same shape. + ValueError + If `target` contains floating point values. + ValueError + If `preds` contains non-floating point values. + RuntimeError + If `target` contains values outside the range [0, 1] or does not contain + `ignore_index` if `ignore_index` is not `None`. + ValueError + If the array API namespace of `target` and `preds` are not the same as the + array API namespace of `thresholds`. + ValueError + If the second dimension of `preds` is not equal to `num_labels`. + + Examples + -------- + >>> from cyclops.evaluate.metrics.experimental.functional import ( + ... multilabel_average_precision, + ... ) + >>> import numpy.array_api as anp + >>> target = anp.asarray([[0, 1, 0], [1, 1, 0], [0, 0, 1]]) + >>> preds = anp.asarray( + ... [[0.11, 0.22, 0.67], [0.84, 0.73, 0.12], [0.33, 0.92, 0.44]], + ... ) + >>> multilabel_average_precision( + ... target, preds, num_labels=3, thresholds=None, average=None, + ... ) + Array([1. , 0.5833334, 0.5 ], dtype=float32) + >>> multilabel_average_precision( + ... 
target, preds, num_labels=3, thresholds=None, average="micro", + ... ) + Array(0.58452386, dtype=float32) + >>> multilabel_average_precision( + ... target, preds, num_labels=3, thresholds=None, average="macro", + ... ) + Array(0.6944445, dtype=float32) + >>> multilabel_average_precision( + ... target, preds, num_labels=3, thresholds=None, average="weighted", + ... ) + Array(0.6666667, dtype=float32) + """ + _multilabel_average_precision_validate_args( + num_labels, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) + xp = _multilabel_precision_recall_curve_validate_arrays( + target, + preds, + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + ) + target, preds, thresholds = _multilabel_precision_recall_curve_format_arrays( + target, + preds, + num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + xp=xp, + ) + state = _multilabel_precision_recall_curve_update( + target, + preds, + num_labels, + thresholds=thresholds, + xp=xp, + ) + return _multilabel_average_precision_compute( + state, + num_labels, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ) diff --git a/cyclops/evaluate/metrics/experimental/mae.py b/cyclops/evaluate/metrics/experimental/mae.py index 3221d3340..dab2f5a5d 100644 --- a/cyclops/evaluate/metrics/experimental/mae.py +++ b/cyclops/evaluate/metrics/experimental/mae.py @@ -1,4 +1,6 @@ """Mean Absolute Error metric.""" +from typing import Any + from cyclops.evaluate.metrics.experimental.functional.mae import ( _mean_absolute_error_compute, _mean_absolute_error_update, @@ -10,6 +12,11 @@ class MeanAbsoluteError(Metric): """Mean Absolute Error. + Parameters + ---------- + **kwargs : Any + Keyword arguments to pass to the `Metric` base class. + Examples -------- >>> import numpy.array_api as anp @@ -24,8 +31,8 @@ class MeanAbsoluteError(Metric): name: str = "Mean Absolute Error" - def __init__(self) -> None: - super().__init__() + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) self.add_state_default_factory( "sum_abs_error", lambda xp: xp.asarray(0.0, dtype=xp.float32, device=self.device), # type: ignore diff --git a/cyclops/evaluate/metrics/experimental/mape.py b/cyclops/evaluate/metrics/experimental/mape.py index dede691f1..6d9d4afbf 100644 --- a/cyclops/evaluate/metrics/experimental/mape.py +++ b/cyclops/evaluate/metrics/experimental/mape.py @@ -1,4 +1,6 @@ """Mean Absolute Percentage Error (MAPE) metric.""" +from typing import Any + from cyclops.evaluate.metrics.experimental.functional.mape import ( _mean_absolute_percentage_error_compute, _mean_absolute_percentage_error_update, @@ -15,6 +17,8 @@ class MeanAbsolutePercentageError(Metric): epsilon : float, optional, default=1.17e-06 Specifies the lower bound for target values. Any target value below epsilon is set to epsilon (avoids division by zero errors). + **kwargs : Any + Keyword arguments to pass to the `Metric` base class. Examples -------- @@ -30,8 +34,8 @@ class MeanAbsolutePercentageError(Metric): name: str = "Mean Absolute Percentage Error" - def __init__(self, epsilon: float = 1.17e-6) -> None: - super().__init__() + def __init__(self, epsilon: float = 1.17e-6, **kwargs: Any) -> None: + super().__init__(**kwargs) if not isinstance(epsilon, float): raise TypeError(f"Expected `epsilon` to be a float. 
Got {type(epsilon)}") self.epsilon = epsilon diff --git a/cyclops/evaluate/metrics/experimental/mse.py b/cyclops/evaluate/metrics/experimental/mse.py index b8ef4b435..6210055a2 100644 --- a/cyclops/evaluate/metrics/experimental/mse.py +++ b/cyclops/evaluate/metrics/experimental/mse.py @@ -1,4 +1,6 @@ """Mean Squared Error metric.""" +from typing import Any + from cyclops.evaluate.metrics.experimental.functional.mse import ( _mean_squared_error_compute, _mean_squared_error_update, @@ -17,6 +19,8 @@ class MeanSquaredError(Metric): to `False`, returns the root mean squared error. num_outputs : int, optional, default=1 Number of outputs in multioutput setting. + **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. Examples -------- @@ -43,8 +47,13 @@ class MeanSquaredError(Metric): name: str = "Mean Squared Error" - def __init__(self, squared: bool = True, num_outputs: int = 1) -> None: - super().__init__() + def __init__( + self, + squared: bool = True, + num_outputs: int = 1, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) if not isinstance(squared, bool): raise TypeError(f"Expected `squared` to be a boolean. Got {type(squared)}") if not isinstance(num_outputs, int) and num_outputs > 0: diff --git a/cyclops/evaluate/metrics/experimental/negative_predictive_value.py b/cyclops/evaluate/metrics/experimental/negative_predictive_value.py index 7a5f1e5ee..99602555d 100644 --- a/cyclops/evaluate/metrics/experimental/negative_predictive_value.py +++ b/cyclops/evaluate/metrics/experimental/negative_predictive_value.py @@ -20,7 +20,7 @@ class BinaryNPV(_AbstractBinaryStatScores, registry_key="binary_npv"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -82,6 +82,8 @@ class MulticlassNPV( Specifies a target class that is ignored when computing the negative predictive value. Ignoring a target class means that the corresponding predictions do not contribute to the negative predictive value. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -151,6 +153,8 @@ class MultilabelNPV( ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the negative predictive value. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- diff --git a/cyclops/evaluate/metrics/experimental/precision_recall.py b/cyclops/evaluate/metrics/experimental/precision_recall.py index d704aff89..57253a1d6 100644 --- a/cyclops/evaluate/metrics/experimental/precision_recall.py +++ b/cyclops/evaluate/metrics/experimental/precision_recall.py @@ -20,7 +20,7 @@ class BinaryPrecision(_AbstractBinaryStatScores, registry_key="binary_precision" Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -59,7 +59,7 @@ class BinaryPPV(BinaryPrecision, registry_key="binary_ppv"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. 
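[Editor's sketch, not part of this patch.] To make the negative-predictive-value docstrings above concrete: NPV is TN / (TN + FN), the fraction of negative predictions that are correct. The toy values below are the editor's, and the sketch assumes the functional `binary_npv` keeps the `(target, preds)` call pattern used throughout this module:

import numpy.array_api as anp

from cyclops.evaluate.metrics.experimental.functional.negative_predictive_value import (
    binary_npv,
)

target = anp.asarray([0, 1, 0, 1])
preds = anp.asarray([0, 1, 1, 0])  # two negative predictions: one TN, one FN
print(binary_npv(target, preds))   # expected NPV = TN / (TN + FN) = 0.5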
Examples @@ -114,6 +114,8 @@ class MulticlassPrecision( Specifies a target class that is ignored when computing the precision score. Ignoring a target class means that the corresponding predictions do not contribute to the precision score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -179,6 +181,8 @@ class MulticlassPPV(MulticlassPrecision, registry_key="multiclass_ppv"): Specifies a target class that is ignored when computing the positive predictive value. Ignoring a target class means that the corresponding predictions do not contribute to the positive predictive value. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -235,6 +239,8 @@ class MultilabelPrecision( ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the precision score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -301,6 +307,8 @@ class MultilabelPPV(MultilabelPrecision, registry_key="multilabel_ppv"): ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the positive predictive value. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -334,7 +342,7 @@ class BinaryRecall(_AbstractBinaryStatScores, registry_key="binary_recall"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -373,7 +381,7 @@ class BinarySensitivity(BinaryRecall, registry_key="binary_sensitivity"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -407,7 +415,7 @@ class BinaryTPR(BinaryRecall, registry_key="binary_tpr"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -459,6 +467,8 @@ class MulticlassRecall(_AbstractMulticlassStatScores, registry_key="multiclass_r Specifies a target class that is ignored when computing the recall score. Ignoring a target class means that the corresponding predictions do not contribute to the recall score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -523,6 +533,8 @@ class MulticlassSensitivity(MulticlassRecall, registry_key="multiclass_sensitivi Specifies a target class that is ignored when computing the sensitivity score. Ignoring a target class means that the corresponding predictions do not contribute to the sensitivity score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -575,6 +587,8 @@ class MulticlassTPR(MulticlassRecall, registry_key="multiclass_tpr"): Specifies a target class that is ignored when computing the true positive rate. Ignoring a target class means that the corresponding predictions do not contribute to the true positive rate. + **kwargs : Any + Additional keyword arguments common to all metrics. 
Examples -------- @@ -628,6 +642,8 @@ class MultilabelRecall(_AbstractMultilabelStatScores, registry_key="multilabel_r ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the recall score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -694,6 +710,8 @@ class MultilabelSensitivity(MultilabelRecall, registry_key="multilabel_sensitivi ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the sensitivity score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -748,6 +766,8 @@ class MultilabelTPR(MultilabelRecall, registry_key="multilabel_tpr"): ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the true positive rate. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- diff --git a/cyclops/evaluate/metrics/experimental/precision_recall_curve.py b/cyclops/evaluate/metrics/experimental/precision_recall_curve.py index 46bfba20e..6567e407f 100644 --- a/cyclops/evaluate/metrics/experimental/precision_recall_curve.py +++ b/cyclops/evaluate/metrics/experimental/precision_recall_curve.py @@ -1,6 +1,6 @@ """Classes for computing the precision-recall curve.""" from types import ModuleType -from typing import List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Union import array_api_compat as apc @@ -43,6 +43,8 @@ class BinaryPrecisionRecallCurve(Metric, registry_key="binary_precision_recall_c ignore_index : int, optional, default=None The value in `target` that should be ignored when computing the precision and recall. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. Examples -------- @@ -69,9 +71,10 @@ def __init__( self, thresholds: Optional[Union[int, List[float], Array]] = None, ignore_index: Optional[int] = None, + **kwargs: Any, ) -> None: """Initialize a `BinaryPrecisionRecallCurve` instance.""" - super().__init__() + super().__init__(**kwargs) _binary_precision_recall_curve_validate_args(thresholds, ignore_index) self.ignore_index = ignore_index self.thresholds = thresholds @@ -173,6 +176,8 @@ class MulticlassPrecisionRecallCurve( ignore_index : int or Tuple[int], optional, default=None The value(s) in `target` that should be ignored when computing the precision and recall. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. Examples -------- @@ -219,9 +224,10 @@ def __init__( thresholds: Optional[Union[int, List[float], Array]] = None, average: Optional[Literal["macro", "micro", "none"]] = None, ignore_index: Optional[Union[int, Tuple[int]]] = None, + **kwargs: Any, ) -> None: """Initialize a `MulticlassPrecisionRecallCurve` instance.""" - super().__init__() + super().__init__(**kwargs) _multiclass_precision_recall_curve_validate_args( num_classes, thresholds=thresholds, @@ -345,6 +351,8 @@ class MultilabelPrecisionRecallCurve( ignore_index : int, optional, default=None The value in `target` that should be ignored when computing the precision and recall. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments to pass to the `Metric` base class. 
Examples -------- @@ -385,9 +393,10 @@ def __init__( num_labels: int, thresholds: Optional[Union[int, List[float], Array]] = None, ignore_index: Optional[int] = None, + **kwargs: Any, ) -> None: """Initialize a `MultilabelPrecisionRecallCurve` instance.""" - super().__init__() + super().__init__(**kwargs) _multilabel_precision_recall_curve_validate_args( num_labels, thresholds=thresholds, diff --git a/cyclops/evaluate/metrics/experimental/roc.py b/cyclops/evaluate/metrics/experimental/roc.py index 942cc4e89..6c6fbecb5 100644 --- a/cyclops/evaluate/metrics/experimental/roc.py +++ b/cyclops/evaluate/metrics/experimental/roc.py @@ -15,7 +15,7 @@ from cyclops.evaluate.metrics.experimental.utils.types import Array -class BinaryROC(BinaryPrecisionRecallCurve): +class BinaryROC(BinaryPrecisionRecallCurve, registry_key="binary_roc_curve"): """The receiver operating characteristic (ROC) curve. Parameters @@ -31,6 +31,8 @@ class BinaryROC(BinaryPrecisionRecallCurve): ignore_index : int, optional, default=None The value in `target` that should be ignored when computing the ROC curve. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -62,7 +64,10 @@ def _compute_metric(self) -> Tuple[Array, Array, Array]: return _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type] -class MulticlassROC(MulticlassPrecisionRecallCurve): +class MulticlassROC( + MulticlassPrecisionRecallCurve, + registry_key="multiclass_roc_curve", +): """The reciever operator characteristics (ROC) curve. Parameters @@ -89,6 +94,8 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): ignore_index : int or Tuple[int], optional, default=None The value(s) in `target` that should be ignored when computing the ROC curve. If `None`, all values in `target` are used. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -145,7 +152,10 @@ def _compute_metric( ) -class MultilabelROC(MultilabelPrecisionRecallCurve): +class MultilabelROC( + MultilabelPrecisionRecallCurve, + registry_key="multilabel_roc_curve", +): """The reciever operator characteristics (ROC) curve. Parameters @@ -163,6 +173,8 @@ class MultilabelROC(MultilabelPrecisionRecallCurve): ignore_index : int, optional, default=None The value in `target` that should be ignored when computing the ROC Curve. If `None`, all values in `target` are used. + **kwargs + Additional keyword arguments common to all metrics. Examples -------- diff --git a/cyclops/evaluate/metrics/experimental/smape.py b/cyclops/evaluate/metrics/experimental/smape.py index df2392ce4..a7e61c027 100644 --- a/cyclops/evaluate/metrics/experimental/smape.py +++ b/cyclops/evaluate/metrics/experimental/smape.py @@ -1,4 +1,6 @@ """Symmetric Mean Absolute Percentage Error metric.""" +from typing import Any + from cyclops.evaluate.metrics.experimental.functional.smape import ( _symmetric_mean_absolute_percentage_error_compute, _symmetric_mean_absolute_percentage_error_update, @@ -15,6 +17,8 @@ class SymmetricMeanAbsolutePercentageError(Metric): epsilon : float, optional, default=1.17e-6 Specifies the lower bound for target values. Any target value below epsilon is set to epsilon (avoids division by zero errors). + **kwargs : Any + Keyword arguments to pass to the `Metric` base class. 
Examples -------- @@ -32,8 +36,8 @@ class SymmetricMeanAbsolutePercentageError(Metric): name: str = "Symmetric Mean Absolute Percentage Error" - def __init__(self, epsilon: float = 1.17e-6) -> None: - super().__init__() + def __init__(self, epsilon: float = 1.17e-6, **kwargs: Any) -> None: + super().__init__(**kwargs) if not isinstance(epsilon, float): raise TypeError(f"Expected `epsilon` to be a float. Got {type(epsilon)}") self.epsilon = epsilon diff --git a/cyclops/evaluate/metrics/experimental/specificity.py b/cyclops/evaluate/metrics/experimental/specificity.py index b289b046d..768b8939e 100644 --- a/cyclops/evaluate/metrics/experimental/specificity.py +++ b/cyclops/evaluate/metrics/experimental/specificity.py @@ -20,7 +20,7 @@ class BinarySpecificity(_AbstractBinaryStatScores, registry_key="binary_specific Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -81,6 +81,8 @@ class MulticlassSpecificity( Specifies a target class that is ignored when computing the specificity score. Ignoring a target class means that the corresponding predictions do not contribute to the specificity score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -150,6 +152,8 @@ class MultilabelSpecificity( ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the specificity score. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -196,7 +200,7 @@ class BinaryTNR(BinarySpecificity, registry_key="binary_tnr"): Threshold for converting probabilities into binary values. ignore_index : int, optional Values in the target array to ignore when computing the metric. - **kwargs + **kwargs : Any Additional keyword arguments common to all metrics. Examples @@ -249,6 +253,8 @@ class MulticlassTNR(MulticlassSpecificity, registry_key="multiclass_tnr"): Specifies a target class that is ignored when computing the true negative rate. Ignoring a target class means that the corresponding predictions do not contribute to the true negative rate. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- @@ -303,6 +309,8 @@ class MultilabelTNR(MultilabelSpecificity, registry_key="multilabel_tnr"): ignore_index : int, optional, default=None Specifies a value in the target array(s) that is ignored when computing the true negative rate. + **kwargs : Any + Additional keyword arguments common to all metrics. Examples -------- diff --git a/cyclops/evaluate/metrics/experimental/wmape.py b/cyclops/evaluate/metrics/experimental/wmape.py index a24e8eac5..fc37cf6a6 100644 --- a/cyclops/evaluate/metrics/experimental/wmape.py +++ b/cyclops/evaluate/metrics/experimental/wmape.py @@ -1,5 +1,6 @@ """Weighted Mean Absolute Percentage Error metric.""" from types import ModuleType +from typing import Any from cyclops.evaluate.metrics.experimental.functional.wmape import ( _weighted_mean_absolute_percentage_error_compute, @@ -17,6 +18,8 @@ class WeightedMeanAbsolutePercentageError(Metric): epsilon : float, optional, default=1.17e-6 Specifies the lower bound for target values. Any target value below epsilon is set to epsilon (avoids division by zero errors). + **kwargs : Any + Keyword arguments to pass to the `Metric` base class. 
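[Editor's note, not part of this patch.] The `registry_key` arguments added to the ROC classes above are what make those metrics reachable by name through the factory, whose new `experimental` flag appears in the factory.py hunk below; a usage sketch mirroring the updated tutorials:

from cyclops.evaluate.metrics.factory import create_metric

# With registry_key="binary_roc_curve" registered above and the `experimental`
# flag added to the factory below, the experimental ROC metric can be built by name:
roc_curve = create_metric("binary_roc_curve", experimental=True)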
Examples -------- @@ -34,8 +37,8 @@ class WeightedMeanAbsolutePercentageError(Metric): name: str = "Weighted Mean Absolute Percentage Error" - def __init__(self, epsilon: float = 1.17e-6) -> None: - super().__init__() + def __init__(self, epsilon: float = 1.17e-6, **kwargs: Any) -> None: + super().__init__(**kwargs) if not isinstance(epsilon, float): raise TypeError(f"Expected `epsilon` to be a float. Got {type(epsilon)}") self.epsilon = epsilon diff --git a/cyclops/evaluate/metrics/factory.py b/cyclops/evaluate/metrics/factory.py index e83d76b55..bdcd6048f 100644 --- a/cyclops/evaluate/metrics/factory.py +++ b/cyclops/evaluate/metrics/factory.py @@ -1,18 +1,28 @@ """Factory for creating metrics.""" from difflib import get_close_matches -from typing import Any, List +from typing import Any, List, Union +from cyclops.evaluate.metrics.experimental.metric import ( + _METRIC_REGISTRY as _EXPERIMENTAL_METRIC_REGISTRY, +) +from cyclops.evaluate.metrics.experimental.metric import Metric as ExperimentalMetric from cyclops.evaluate.metrics.metric import _METRIC_REGISTRY, Metric -def create_metric(metric_name: str, **kwargs: Any) -> Metric: +def create_metric( + metric_name: str, + experimental: bool = False, + **kwargs: Any, +) -> Union[Metric, ExperimentalMetric]: """Create a metric instance from a name. Parameters ---------- metric_name : str The name of the metric. + experimental : bool + Whether to use metrics from `cyclops.evaluate.metrics.experimental`. **kwargs : Any The keyword arguments to pass to the metric constructor. @@ -22,11 +32,20 @@ def create_metric(metric_name: str, **kwargs: Any) -> Metric: The metric instance. """ - metric_class = _METRIC_REGISTRY.get(metric_name, None) + metric_class = ( + _METRIC_REGISTRY.get(metric_name, None) + if not experimental + else _EXPERIMENTAL_METRIC_REGISTRY.get(metric_name, None) + ) if metric_class is None: + registry_keys: List[str] = ( + list(_METRIC_REGISTRY.keys()) + if not experimental + else list(_EXPERIMENTAL_METRIC_REGISTRY.keys()) # type: ignore[arg-type] + ) similar_keys_list: List[str] = get_close_matches( metric_name, - _METRIC_REGISTRY.keys(), + registry_keys, n=5, ) similar_keys: str = ", ".join(similar_keys_list) diff --git a/cyclops/report/plot/classification.py b/cyclops/report/plot/classification.py index e176ec386..04e0130f6 100644 --- a/cyclops/report/plot/classification.py +++ b/cyclops/report/plot/classification.py @@ -131,7 +131,7 @@ def roc_curve( if auroc is not None: assert isinstance( auroc, - float, + (float, np.floating), ), "AUROCs must be a float for binary tasks" name = f"Model (AUC = {auroc:.2f})" else: @@ -227,7 +227,7 @@ def roc_curve_comparison( if aurocs and slice_name in aurocs: assert isinstance( aurocs[slice_name], - float, + (float, np.floating), ), "AUROCs must be a float for binary tasks" name = f"{slice_name} (AUC = {aurocs[slice_name]:.2f})" else: @@ -401,7 +401,7 @@ def precision_recall_curve_comparison( if auprcs and slice_name in auprcs: assert isinstance( auprcs[slice_name], - float, + (float, np.floating), ), "AUPRCs must be a float for binary tasks" name = f"{slice_name} (AUC = {auprcs[slice_name]:.2f})" else: @@ -483,8 +483,10 @@ def metrics_value( """ if self.task_type == "binary": assert all( - not isinstance(value, (list, np.ndarray)) for value in metrics.values() - ), ("Metrics must not be of type list or np.ndarray for" "binary tasks") + not isinstance(value, list) + and not (isinstance(value, np.ndarray) and value.ndim > 0) + for value in metrics.values() + ), "Metrics must not be of type 
list or np.ndarray for binary tasks" trace = bar_plot( x=list(metrics.keys()), # type: ignore[arg-type] y=list(metrics.values()), # type: ignore[arg-type] @@ -705,7 +707,8 @@ def metrics_comparison_radar( for slice_name, metrics in slice_metrics.items(): metric_names = list(metrics.keys()) assert all( - not isinstance(value, (list, np.ndarray)) + not isinstance(value, list) + and not (isinstance(value, np.ndarray) and value.ndim > 0) for value in metrics.values() ), ( "Generic metrics must not be of type list or np.ndarray for" @@ -725,7 +728,9 @@ def metrics_comparison_radar( radial_data: List[float] = [] theta_data: List[float] = [] for metric_name, metric_values in metrics.items(): - if isinstance(metric_values, (list, np.ndarray)): + if isinstance(metric_values, list) or ( + isinstance(metric_values, np.ndarray) and metric_values.ndim > 0 + ): assert ( len(metric_values) == self.class_num ), "Metric values must be of length class_num for \ @@ -736,7 +741,7 @@ def metrics_comparison_radar( for i in range(self.class_num) ] theta_data.extend(theta) # type: ignore[arg-type] - elif isinstance(metric_values, float): + elif isinstance(metric_values, (float, np.floating)): radial_data.append(metric_values) theta_data.append(metric_name) # type: ignore[arg-type] else: @@ -807,10 +812,11 @@ def metrics_comparison_bar( metric_names = list(metrics.keys()) metric_values = list(metrics.values()) assert all( - not isinstance(value, (list, np.ndarray)) + not isinstance(value, list) + and not (isinstance(value, np.ndarray) and value.ndim > 0) for value in metrics.values() ), ( - "Generic metrics must not be of type list or np.ndarray for" + "Generic metrics must not be of type list or np.ndarray for " "binary tasks" ) trace.append( @@ -856,7 +862,10 @@ def metrics_comparison_bar( metric_names = list(metrics.keys()) for num in range(self.class_num): for metric_name in metric_names: - if isinstance(metrics[metric_name], (list, np.ndarray)): + if isinstance(metrics[metric_name], list) or ( + isinstance(metrics[metric_name], np.ndarray) + and metrics[metric_name].ndim > 0 + ): metric_values = metrics[metric_name][num] # type: ignore else: metric_values = metrics[metric_name] # type: ignore @@ -926,7 +935,8 @@ def metrics_comparison_scatter( metric_names = list(metrics.keys()) metric_values = list(metrics.values()) assert all( - not isinstance(value, (list, np.ndarray)) + not isinstance(value, list) + and not (isinstance(value, np.ndarray) and value.ndim > 0) for value in metrics.values() ), ( "Generic metrics must not be of type list or np.ndarray for" diff --git a/cyclops/tasks/classification.py b/cyclops/tasks/classification.py index 0772f13d3..d5344bb82 100644 --- a/cyclops/tasks/classification.py +++ b/cyclops/tasks/classification.py @@ -14,8 +14,8 @@ from cyclops.data.slicer import SliceSpec from cyclops.evaluate.evaluator import evaluate from cyclops.evaluate.fairness.config import FairnessConfig +from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict from cyclops.evaluate.metrics.factory import create_metric -from cyclops.evaluate.metrics.metric import MetricCollection from cyclops.models.catalog import ( _img_model_keys, _model_names_mapping, @@ -261,7 +261,7 @@ def predict( def evaluate( self, dataset: Union[Dataset, DatasetDict], - metrics: Union[List[str], MetricCollection], + metrics: Union[List[str], MetricDict], model_names: Optional[Union[str, List[str]]] = None, transforms: Optional[ColumnTransformer] = None, prediction_column_prefix: str = "predictions", @@ -278,7 +278,7 
@@ def evaluate( ---------- dataset : Union[Dataset, DatasetDict] HuggingFace dataset. - metrics : Union[List[str], MetricCollection] + metrics : Union[List[str], MetricDict] Metrics to be evaluated. model_names : Union[str, List[str]], optional Model names to be evaluated, if not specified all fitted models \ @@ -315,9 +315,9 @@ def evaluate( if splits_mapping is None: splits_mapping = {"test": "test"} if isinstance(metrics, list) and len(metrics): - metrics_collection = MetricCollection( + metrics_collection = MetricDict( [ - create_metric( + create_metric( # type: ignore[misc] m, task=self.task_type, num_labels=len(self.task_features), @@ -325,7 +325,7 @@ def evaluate( for m in metrics ], ) - elif isinstance(metrics, MetricCollection): + elif isinstance(metrics, MetricDict): metrics_collection = metrics if isinstance(model_names, str): model_names = [model_names] @@ -345,6 +345,22 @@ def evaluate( only_predictions=False, splits_mapping=splits_mapping, ) + + # select the probability scores of the positive class since metrics + # expect a single column of probabilities + dataset = dataset.map( # type: ignore[union-attr] + lambda examples: { + f"{prediction_column_prefix}.{model_name}": np.array( # noqa: B023 + examples, + )[ + :, + 1, + ].tolist(), + }, + batched=True, + batch_size=batch_size, + input_columns=f"{prediction_column_prefix}.{model_name}", + ) results = evaluate( dataset=dataset, metrics=metrics_collection, @@ -448,7 +464,7 @@ def predict( def evaluate( self, dataset: Union[Dataset, DatasetDict], - metrics: Union[List[str], MetricCollection], + metrics: Union[List[str], MetricDict], model_names: Optional[Union[str, List[str]]] = None, transforms: Optional[Compose] = None, prediction_column_prefix: str = "predictions", @@ -465,7 +481,7 @@ def evaluate( ---------- dataset : Union[Dataset, DatasetDict] HuggingFace dataset. - metrics : Union[List[str], MetricCollection] + metrics : Union[List[str], MetricDict] Metrics to be evaluated. 
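[Editor's sketch, not part of this patch.] The `dataset.map` call added above keeps only the positive-class column of the stored predictions, since the metrics expect a single column of probabilities rather than the full `(N, 2)` `predict_proba` output; the toy array below is the editor's:

import numpy as np

proba = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]])  # (N, 2) predict_proba scores
positive_scores = proba[:, 1].tolist()                   # [0.1, 0.8, 0.6] -- what the map stores back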
model_names : Union[str, List[str]], optional Model names to be evaluated, required if more than one model exists, \ @@ -515,9 +531,9 @@ def add_missing_labels(examples: Dict[str, Any]) -> Dict[str, Any]: dataset = dataset.map(add_missing_labels) if isinstance(metrics, list) and len(metrics): - metrics_collection = MetricCollection( + metrics_collection = MetricDict( [ - create_metric( + create_metric( # type: ignore[misc] m, task=self.task_type, num_labels=len(self.task_target), @@ -525,7 +541,7 @@ def add_missing_labels(examples: Dict[str, Any]) -> Dict[str, Any]: for m in metrics ], ) - elif isinstance(metrics, MetricCollection): + elif isinstance(metrics, MetricDict): metrics_collection = metrics if isinstance(model_names, str): model_names = [model_names] diff --git a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb index 69d0642e3..44acea15b 100644 --- a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb +++ b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb @@ -44,7 +44,8 @@ "from cyclops.data.df.feature import TabularFeatures\n", "from cyclops.data.slicer import SliceSpec\n", "from cyclops.evaluate.fairness import FairnessConfig # noqa: E402\n", - "from cyclops.evaluate.metrics import MetricCollection, create_metric\n", + "from cyclops.evaluate.metrics import create_metric\n", + "from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict\n", "from cyclops.models.catalog import create_model\n", "from cyclops.report import ModelCardReport\n", "from cyclops.report.plot.classification import ClassificationPlotter\n", @@ -697,7 +698,7 @@ "\n", "Evaluation is done using various evaluation metrics that provide different perspectives on the model's predictive abilities i.e. standard performance metrics and fairness metrics.\n", "\n", - "The standard performance metrics can be created using the `MetricCollection` object." + "The standard performance metrics can be created using the `MetricDict` object." ] }, { @@ -709,17 +710,19 @@ "outputs": [], "source": [ "metric_names = [\n", - " \"accuracy\",\n", - " \"precision\",\n", - " \"recall\",\n", - " \"f1_score\",\n", - " \"auroc\",\n", - " \"average_precision\",\n", - " \"roc_curve\",\n", - " \"precision_recall_curve\",\n", + " \"binary_accuracy\",\n", + " \"binary_precision\",\n", + " \"binary_recall\",\n", + " \"binary_f1_score\",\n", + " \"binary_auroc\",\n", + " \"binary_average_precision\",\n", + " \"binary_roc_curve\",\n", + " \"binary_precision_recall_curve\",\n", "]\n", - "metrics = [create_metric(metric_name, task=\"binary\") for metric_name in metric_names]\n", - "metric_collection = MetricCollection(metrics)" + "metrics = [\n", + " create_metric(metric_name, experimental=True) for metric_name in metric_names\n", + "]\n", + "metric_collection = MetricDict(metrics)" ] }, { @@ -762,7 +765,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "A `MetricCollection` can also be defined for the fairness metrics." + "A `MetricDict` can also be defined for the fairness metrics." 
] }, { @@ -773,21 +776,15 @@ }, "outputs": [], "source": [ - "specificity = create_metric(\n", - " metric_name=\"specificity\",\n", - " task=\"binary\",\n", - ")\n", - "sensitivity = create_metric(\n", - " metric_name=\"sensitivity\",\n", - " task=\"binary\",\n", - ")\n", + "specificity = create_metric(metric_name=\"binary_specificity\", experimental=True)\n", + "sensitivity = create_metric(metric_name=\"binary_sensitivity\", experimental=True)\n", "\n", - "fpr = 1 - specificity\n", - "fnr = 1 - sensitivity\n", + "fpr = -specificity + 1\n", + "fnr = -sensitivity + 1\n", "\n", "ber = (fpr + fnr) / 2\n", "\n", - "fairness_metric_collection = MetricCollection(\n", + "fairness_metric_collection = MetricDict(\n", " {\n", " \"Sensitivity\": sensitivity,\n", " \"Specificity\": specificity,\n", @@ -858,8 +855,13 @@ "source": [ "results_female, _ = heart_failure_prediction_task.evaluate(\n", " dataset=dataset[\"test\"],\n", - " metrics=MetricCollection(\n", - " {\"BinaryAccuracy\": create_metric(metric_name=\"accuracy\", task=\"binary\")},\n", + " metrics=MetricDict(\n", + " {\n", + " \"BinaryAccuracy\": create_metric(\n", + " metric_name=\"binary_accuracy\",\n", + " experimental=True,\n", + " ),\n", + " },\n", " ),\n", " model_names=model_name,\n", " transforms=preprocessor,\n", @@ -889,7 +891,7 @@ "model_name = f\"model_for_preds.{model_name}\"\n", "results_flat = flatten_results_dict(\n", " results=results,\n", - " remove_metrics=[\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"],\n", + " remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n", " model_name=model_name,\n", ")\n", "results_female_flat = flatten_results_dict(\n", @@ -910,7 +912,7 @@ " report.log_quantitative_analysis(\n", " \"performance\",\n", " name=name,\n", - " value=metric,\n", + " value=metric.tolist(),\n", " description=descriptions[name],\n", " metric_slice=split,\n", " pass_fail_thresholds=0.7,\n", @@ -930,7 +932,7 @@ " report.log_quantitative_analysis(\n", " \"performance\",\n", " name=name,\n", - " value=metric,\n", + " value=metric.tolist(),\n", " description=descriptions[name],\n", " metric_slice=split,\n", " pass_fail_thresholds=0.7,\n", @@ -963,7 +965,7 @@ "source": [ "# extracting the ROC curves and AUROC results for all the slices\n", "roc_curves = {\n", - " slice_name: slice_results[\"BinaryROCCurve\"]\n", + " slice_name: slice_results[\"BinaryROC\"]\n", " for slice_name, slice_results in results[model_name].items()\n", "}\n", "aurocs = {\n", @@ -1036,7 +1038,7 @@ "overall_performance = {\n", " metric_name: metric_value\n", " for metric_name, metric_value in results[model_name][\"overall\"].items()\n", - " if metric_name not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"]\n", + " if metric_name not in [\"BinaryROC\", \"BinaryPrecisionRecallCurve\"]\n", "}" ] }, @@ -1070,7 +1072,7 @@ " slice_name: {\n", " metric_name: metric_value\n", " for metric_name, metric_value in slice_results.items()\n", - " if metric_name not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"]\n", + " if metric_name not in [\"BinaryROC\", \"BinaryPrecisionRecallCurve\"]\n", " }\n", " for slice_name, slice_results in results[model_name].items()\n", "}" diff --git a/docs/source/tutorials/mimiciv/mortality_prediction.ipynb b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb index 09fa6649f..b480f2dfd 100644 --- a/docs/source/tutorials/mimiciv/mortality_prediction.ipynb +++ b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb @@ -48,7 +48,8 @@ "from cyclops.data.df.feature import TabularFeatures\n", "from 
cyclops.data.slicer import SliceSpec\n", "from cyclops.evaluate.fairness import FairnessConfig # noqa: E402\n", - "from cyclops.evaluate.metrics import MetricCollection, create_metric\n", + "from cyclops.evaluate.metrics import create_metric\n", + "from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict\n", "from cyclops.models.catalog import create_model\n", "from cyclops.report import ModelCardReport\n", "from cyclops.report.plot.classification import ClassificationPlotter\n", @@ -767,7 +768,7 @@ "\n", "Evaluation is done using various evaluation metrics that provide different perspectives on the model's predictive abilities i.e. standard performance metrics and fairness metrics.\n", "\n", - "The standard performance metrics can be created using the `MetricCollection` object." + "The standard performance metrics can be created using the `MetricDict` object." ] }, { @@ -777,17 +778,19 @@ "outputs": [], "source": [ "metric_names = [\n", - " \"accuracy\",\n", - " \"precision\",\n", - " \"recall\",\n", - " \"f1_score\",\n", - " \"auroc\",\n", - " \"average_precision\",\n", - " \"roc_curve\",\n", - " \"precision_recall_curve\",\n", + " \"binary_accuracy\",\n", + " \"binary_precision\",\n", + " \"binary_recall\",\n", + " \"binary_f1_score\",\n", + " \"binary_auroc\",\n", + " \"binary_average_precision\",\n", + " \"binary_roc_curve\",\n", + " \"binary_precision_recall_curve\",\n", "]\n", - "metrics = [create_metric(metric_name, task=\"binary\") for metric_name in metric_names]\n", - "metric_collection = MetricCollection(metrics)" + "metrics = [\n", + " create_metric(metric_name, experimental=True) for metric_name in metric_names\n", + "]\n", + "metric_collection = MetricDict(metrics)" ] }, { @@ -830,7 +833,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "A `MetricCollection` can also be defined for the fairness metrics." + "A `MetricDict` can also be defined for the fairness metrics." 
] }, { @@ -839,18 +842,12 @@ "metadata": {}, "outputs": [], "source": [ - "specificity = create_metric(\n", - " metric_name=\"specificity\",\n", - " task=\"binary\",\n", - ")\n", - "sensitivity = create_metric(\n", - " metric_name=\"sensitivity\",\n", - " task=\"binary\",\n", - ")\n", - "fpr = 1 - specificity\n", - "fnr = 1 - sensitivity\n", + "specificity = create_metric(metric_name=\"binary_specificity\", experimental=True)\n", + "sensitivity = create_metric(metric_name=\"binary_sensitivity\", experimental=True)\n", + "fpr = -specificity + 1 # __rsub__ is not implemented for metrics\n", + "fnr = -sensitivity + 1\n", "ber = (fpr + fnr) / 2\n", - "fairness_metric_collection = MetricCollection(\n", + "fairness_metric_collection = MetricDict(\n", " {\n", " \"Sensitivity\": sensitivity,\n", " \"Specificity\": specificity,\n", @@ -929,7 +926,7 @@ "model_name = f\"model_for_preds.{model_name}\"\n", "results_flat = flatten_results_dict(\n", " results=results,\n", - " remove_metrics=[\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"],\n", + " remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n", " model_name=model_name,\n", ")" ] @@ -954,7 +951,7 @@ " report.log_quantitative_analysis(\n", " \"performance\",\n", " name=name,\n", - " value=metric,\n", + " value=metric.tolist(),\n", " description=descriptions[name],\n", " metric_slice=split,\n", " pass_fail_thresholds=0.7,\n", @@ -987,7 +984,7 @@ "source": [ "# extracting the ROC curves and AUROC results for all the slices\n", "roc_curves = {\n", - " slice_name: slice_results[\"BinaryROCCurve\"]\n", + " slice_name: slice_results[\"BinaryROC\"]\n", " for slice_name, slice_results in results[model_name].items()\n", "}\n", "aurocs = {\n", @@ -1060,7 +1057,7 @@ "overall_performance = {\n", " metric_name: metric_value\n", " for metric_name, metric_value in results[model_name][\"overall\"].items()\n", - " if metric_name not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"]\n", + " if metric_name not in [\"BinaryROC\", \"BinaryPrecisionRecallCurve\"]\n", "}" ] }, @@ -1094,7 +1091,7 @@ " slice_name: {\n", " metric_name: metric_value\n", " for metric_name, metric_value in slice_results.items()\n", - " if metric_name not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"]\n", + " if metric_name not in [\"BinaryROC\", \"BinaryPrecisionRecallCurve\"]\n", " }\n", " for slice_name, slice_results in results[model_name].items()\n", "}" diff --git a/docs/source/tutorials/nihcxr/cxr_classification.ipynb b/docs/source/tutorials/nihcxr/cxr_classification.ipynb index 32974c7e9..c4e34479f 100644 --- a/docs/source/tutorials/nihcxr/cxr_classification.ipynb +++ b/docs/source/tutorials/nihcxr/cxr_classification.ipynb @@ -29,7 +29,6 @@ "\n", "import shutil\n", "from functools import partial\n", - "from typing import Optional\n", "\n", "import numpy as np\n", "import plotly.express as px\n", @@ -45,7 +44,6 @@ "from cyclops.data.utils import apply_transforms\n", "from cyclops.evaluate import evaluator\n", "from cyclops.evaluate.metrics.factory import create_metric\n", - "from cyclops.evaluate.metrics.stat_scores import MultilabelStatScores\n", "from cyclops.models.wrappers import PTModel\n", "from cyclops.report import ModelCardReport" ] @@ -217,77 +215,35 @@ "]\n", "\n", "\n", - "class MultilabelPositivePredictiveValue(\n", - " MultilabelStatScores,\n", - " registry_key=\"positive_predictive_value\",\n", - "):\n", - " \"\"\"Compute the recall score for multilabel classification tasks.\"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " 
num_labels: int,\n", - " threshold: float = 0.5,\n", - " top_k: Optional[int] = None,\n", - " ) -> None:\n", - " \"\"\"Initialize the metric.\"\"\"\n", - " super().__init__(\n", - " num_labels=num_labels,\n", - " threshold=threshold,\n", - " top_k=top_k,\n", - " labelwise=True,\n", - " )\n", - "\n", - " def compute(self): # type: ignore[override]\n", - " \"\"\"Compute the recall score from the state.\"\"\"\n", - " tp, fp, tn, fn = self._final_state()\n", - " return tp / (tp + fp)\n", - "\n", - "\n", - "class MultilabelNegativePredictiveValue(\n", - " MultilabelStatScores,\n", - " registry_key=\"negative_predictive_value\",\n", - "):\n", - " \"\"\"Compute the recall score for multilabel classification tasks.\"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " num_labels: int,\n", - " threshold: float = 0.5,\n", - " top_k: Optional[int] = None,\n", - " ) -> None:\n", - " \"\"\"Initialize the metric.\"\"\"\n", - " super().__init__(\n", - " num_labels=num_labels,\n", - " threshold=threshold,\n", - " top_k=top_k,\n", - " labelwise=True,\n", - " )\n", - "\n", - " def compute(self): # type: ignore[override]\n", - " \"\"\"Compute the recall score from the state.\"\"\"\n", - " tp, fp, tn, fn = self._final_state()\n", - " return tn / (tn + fn)\n", - "\n", - "\n", - "ppv = MultilabelPositivePredictiveValue(\n", - " num_labels=len(pathologies),\n", + "num_labels = len(pathologies)\n", + "ppv = create_metric(\n", + " metric_name=\"multilabel_ppv\",\n", + " experimental=True,\n", + " num_labels=num_labels,\n", + " average=None,\n", ")\n", "\n", - "npv = MultilabelNegativePredictiveValue(\n", - " num_labels=len(pathologies),\n", + "npv = create_metric(\n", + " metric_name=\"multilabel_npv\",\n", + " experimental=True,\n", + " num_labels=num_labels,\n", + " average=None,\n", ")\n", "\n", "specificity = create_metric(\n", - " metric_name=\"specificity\",\n", - " task=\"multilabel\",\n", - " num_labels=len(pathologies),\n", + " metric_name=\"multilabel_specificity\",\n", + " experimental=True,\n", + " num_labels=num_labels,\n", + " average=None,\n", ")\n", "\n", "sensitivity = create_metric(\n", - " metric_name=\"sensitivity\",\n", - " task=\"multilabel\",\n", - " num_labels=len(pathologies),\n", + " metric_name=\"multilabel_sensitivity\",\n", + " experimental=True,\n", + " num_labels=num_labels,\n", + " average=None,\n", ")\n", + "\n", "# create the slice functions\n", "slice_spec = SliceSpec(spec_list=slices)\n", "\n", @@ -479,15 +435,15 @@ "for name, metric in results_flat.items():\n", " split, name = name.split(\"/\") # noqa: PLW2901\n", " descriptions = {\n", - " \"MultilabelPositivePredictiveValue\": \"The proportion of correctly predicted positive instances among all instances predicted as positive. Also known as precision.\",\n", - " \"MultilabelNegativePredictiveValue\": \"The proportion of correctly predicted negative instances among all instances predicted as negative.\",\n", + " \"MultilabelPPV\": \"The proportion of correctly predicted positive instances among all instances predicted as positive. Also known as precision.\",\n", + " \"MultilabelNPV\": \"The proportion of correctly predicted negative instances among all instances predicted as negative.\",\n", " \"MultilabelSensitivity\": \"The proportion of actual positive instances that are correctly predicted. 
Also known as recall or true positive rate.\",\n", " \"MultilabelSpecificity\": \"The proportion of actual negative instances that are correctly predicted.\",\n", " }\n", " report.log_quantitative_analysis(\n", " \"performance\",\n", " name=name,\n", - " value=metric,\n", + " value=metric.tolist() if isinstance(metric, np.generic) else metric,\n", " description=descriptions[name],\n", " metric_slice=split,\n", " pass_fail_thresholds=0.7,\n", diff --git a/docs/source/tutorials/nihcxr/generate_nihcxr_report.py b/docs/source/tutorials/nihcxr/generate_nihcxr_report.py index 516b276c3..584bf8be8 100644 --- a/docs/source/tutorials/nihcxr/generate_nihcxr_report.py +++ b/docs/source/tutorials/nihcxr/generate_nihcxr_report.py @@ -3,10 +3,9 @@ # get args from command line import argparse from functools import partial -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import numpy as np -import numpy.typing as npt import plotly.express as px from torchvision.transforms import Compose from torchxrayvision.models import DenseNet @@ -20,7 +19,6 @@ from cyclops.data.utils import apply_transforms from cyclops.evaluate import evaluator from cyclops.evaluate.metrics.factory import create_metric -from cyclops.evaluate.metrics.stat_scores import MultilabelStatScores from cyclops.models.wrappers import PTModel # type: ignore[attr-defined] from cyclops.report import ModelCardReport # type: ignore[attr-defined] @@ -92,80 +90,40 @@ {"Patient Gender": {"value": "F"}}, ] +num_labels = len(pathologies) +ppv = create_metric( + metric_name="multilabel_ppv", + experimental=True, + num_labels=num_labels, + average=None, +) -class MultilabelPositivePredictiveValue( - MultilabelStatScores, - registry_key="positive_predictive_value", -): - """Compute the recall score for multilabel classification tasks.""" - - def __init__( - self, - num_labels: int, - threshold: float = 0.5, - top_k: Optional[int] = None, - ) -> None: - """Initialize the metric.""" - super().__init__( - num_labels=num_labels, - threshold=threshold, - top_k=top_k, - labelwise=True, - ) - - def compute(self) -> npt.NDArray[np.int_]: - """Compute the recall score from the state.""" - tp, fp, tn, fn = self._final_state() - return tp / (tp + fp) # type: ignore[return-value] - - -class MultilabelNegativePredictiveValue( - MultilabelStatScores, - registry_key="negative_predictive_value", -): - """Compute the recall score for multilabel classification tasks.""" - - def __init__( - self, - num_labels: int, - threshold: float = 0.5, - top_k: Optional[int] = None, - ) -> None: - """Initialize the metric.""" - super().__init__( - num_labels=num_labels, - threshold=threshold, - top_k=top_k, - labelwise=True, - ) - - def compute(self) -> npt.NDArray[np.int_]: - """Compute the recall score from the state.""" - tp, fp, tn, fn = self._final_state() - return tn / (tn + fn) # type: ignore[return-value] - - -ppv = MultilabelPositivePredictiveValue(num_labels=len(pathologies)) - -npv = MultilabelNegativePredictiveValue(num_labels=len(pathologies)) +npv = create_metric( + metric_name="multilabel_npv", + experimental=True, + num_labels=num_labels, + average=None, +) specificity = create_metric( - metric_name="specificity", - task="multilabel", - num_labels=len(pathologies), + metric_name="multilabel_specificity", + experimental=True, + num_labels=num_labels, + average=None, ) sensitivity = create_metric( - metric_name="sensitivity", - task="multilabel", - num_labels=len(pathologies), + metric_name="multilabel_sensitivity", + experimental=True, + 
num_labels=num_labels, + average=None, ) # create the slice functions slice_spec = SliceSpec(spec_list=slices_sex) nih_eval_results_gender = evaluator.evaluate( dataset=nih_ds, - metrics=[ppv, npv, sensitivity, specificity], + metrics=[ppv, npv, sensitivity, specificity], # type: ignore[list-item] target_columns=pathologies, prediction_columns="predictions.densenet", ignore_columns="image", @@ -208,7 +166,7 @@ def compute(self) -> npt.NDArray[np.int_]: nih_eval_results_age = evaluator.evaluate( dataset=nih_ds, - metrics=[ppv, npv, sensitivity, specificity], + metrics=[ppv, npv, sensitivity, specificity], # type: ignore[list-item] target_columns=pathologies, prediction_columns="predictions.densenet", ignore_columns="image", @@ -286,15 +244,15 @@ def compute(self) -> npt.NDArray[np.int_]: for name, metric in results_flat.items(): split, name = name.split("/") # noqa: PLW2901 descriptions = { - "MultilabelPositivePredictiveValue": "The proportion of correctly predicted positive instances among all instances predicted as positive. Also known as precision.", - "MultilabelNegativePredictiveValue": "The proportion of correctly predicted negative instances among all instances predicted as negative.", + "MultilabelPPV": "The proportion of correctly predicted positive instances among all instances predicted as positive. Also known as precision.", + "MultilabelNPV": "The proportion of correctly predicted negative instances among all instances predicted as negative.", "MultilabelSensitivity": "The proportion of actual positive instances that are correctly predicted. Also known as recall or true positive rate.", "MultilabelSpecificity": "The proportion of actual negative instances that are correctly predicted.", } report.log_quantitative_analysis( "performance", name=name, - value=metric, + value=metric.tolist() if isinstance(metric, np.generic) else metric, description=descriptions[name], metric_slice=split, pass_fail_thresholds=0.7, diff --git a/docs/source/tutorials/synthea/los_prediction.ipynb b/docs/source/tutorials/synthea/los_prediction.ipynb index 1c0bf4166..9f32ddaf7 100644 --- a/docs/source/tutorials/synthea/los_prediction.ipynb +++ b/docs/source/tutorials/synthea/los_prediction.ipynb @@ -55,7 +55,8 @@ "from cyclops.data.df.feature import TabularFeatures\n", "from cyclops.data.slicer import SliceSpec\n", "from cyclops.evaluate.fairness import FairnessConfig # noqa: E402\n", - "from cyclops.evaluate.metrics import MetricCollection, create_metric\n", + "from cyclops.evaluate.metrics import create_metric\n", + "from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict\n", "from cyclops.models.catalog import create_model\n", "from cyclops.report import ModelCardReport\n", "from cyclops.report.plot.classification import ClassificationPlotter\n", @@ -908,7 +909,7 @@ "\n", "Evaluation is done using various evaluation metrics that provide different perspectives on the model's predictive abilities i.e. standard performance metrics and fairness metrics.\n", "\n", - "The standard performance metrics can be created using the `MetricCollection` object." + "The standard performance metrics can be created using the `MetricDict` object." 
] }, { @@ -921,17 +922,19 @@ "outputs": [], "source": [ "metric_names = [\n", - " \"accuracy\",\n", - " \"precision\",\n", - " \"recall\",\n", - " \"f1_score\",\n", - " \"auroc\",\n", - " \"roc_curve\",\n", - " \"precision_recall_curve\",\n", - " \"stat_scores\",\n", + " \"binary_accuracy\",\n", + " \"binary_precision\",\n", + " \"binary_recall\",\n", + " \"binary_f1_score\",\n", + " \"binary_auroc\",\n", + " \"binary_roc_curve\",\n", + " \"binary_precision_recall_curve\",\n", + " \"binary_confusion_matrix\",\n", "]\n", - "metrics = [create_metric(metric_name, task=\"binary\") for metric_name in metric_names]\n", - "metric_collection = MetricCollection(metrics)" + "metrics = [\n", + " create_metric(metric_name, experimental=True) for metric_name in metric_names\n", + "]\n", + "metric_collection = MetricDict(metrics)" ] }, { @@ -979,7 +982,7 @@ "id": "67bd7806-c480-4c47-8e33-6612c2ede93e", "metadata": {}, "source": [ - "A `MetricCollection` can also be defined for the fairness metrics." + "A `MetricDict` can also be defined for the fairness metrics." ] }, { @@ -991,18 +994,14 @@ }, "outputs": [], "source": [ - "specificity = create_metric(\n", - " metric_name=\"specificity\",\n", - " task=\"binary\",\n", - ")\n", - "sensitivity = create_metric(\n", - " metric_name=\"sensitivity\",\n", - " task=\"binary\",\n", - ")\n", - "fpr = 1 - specificity\n", - "fnr = 1 - sensitivity\n", + "specificity = create_metric(metric_name=\"binary_specificity\", experimental=True)\n", + "sensitivity = create_metric(metric_name=\"binary_sensitivity\", experimental=True)\n", + "fpr = (\n", + " -specificity + 1\n", + ") # rsub is not supported due to limitations in the array API standard\n", + "fnr = -sensitivity + 1\n", "ber = (fpr + fnr) / 2\n", - "fairness_metric_collection = MetricCollection(\n", + "fairness_metric_collection = MetricDict(\n", " {\n", " \"Sensitivity\": sensitivity,\n", " \"Specificity\": specificity,\n", @@ -1095,7 +1094,7 @@ "model_name = f\"model_for_preds.{model_name}\"\n", "results_flat = flatten_results_dict(\n", " results=results,\n", - " remove_metrics=[\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\"],\n", + " remove_metrics=[\"BinaryROC\", \"BinaryPrecisionRecallCurve\"],\n", " model_name=model_name,\n", ")" ] @@ -1111,7 +1110,7 @@ "source": [ "for name, metric in results_flat.items():\n", " split, name = name.split(\"/\") # noqa: PLW2901\n", - " if name == \"BinaryStatScores\":\n", + " if name == \"BinaryConfusionMatrix\":\n", " continue\n", " descriptions = {\n", " \"BinaryPrecision\": \"The proportion of predicted positive instances that are correctly predicted.\",\n", @@ -1123,7 +1122,7 @@ " report.log_quantitative_analysis(\n", " \"performance\",\n", " name=name,\n", - " value=metric,\n", + " value=metric.tolist(),\n", " description=descriptions[name],\n", " metric_slice=split,\n", " pass_fail_thresholds=0.7,\n", @@ -1163,7 +1162,7 @@ "source": [ "# extracting the ROC curves and AUROC results for all the slices\n", "roc_curves = {\n", - " slice_name: slice_results[\"BinaryROCCurve\"]\n", + " slice_name: slice_results[\"BinaryROC\"]\n", " for slice_name, slice_results in results[model_name].items()\n", "}\n", "aurocs = {\n", @@ -1181,8 +1180,7 @@ "outputs": [], "source": [ "# Plot confusion matrix\n", - "tp, fp, tn, fn, _ = results[model_name][\"overall\"][\"BinaryStatScores\"]\n", - "confusion_matrix = np.array([[tn, fp], [fn, tp]])\n", + "confusion_matrix = results[model_name][\"overall\"][\"BinaryConfusionMatrix\"]\n", "conf_plot = plotter.plot_confusion_matrix(\n", " 
confusion_matrix,\n", ")\n", @@ -1225,7 +1223,7 @@ " metric_name: metric_value\n", " for metric_name, metric_value in results[model_name][\"overall\"].items()\n", " if metric_name\n", - " not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\", \"BinaryStatScores\"]\n", + " not in [\"BinaryROC\", \"BinaryPrecisionRecallCurve\", \"BinaryConfusionMatrix\"]\n", "}" ] }, @@ -1262,7 +1260,7 @@ " metric_name: metric_value\n", " for metric_name, metric_value in slice_results.items()\n", " if metric_name\n", - " not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\", \"BinaryStatScores\"]\n", + " not in [\"BinaryROCCurve\", \"BinaryPrecisionRecallCurve\", \"BinaryConfusionMatrix\"]\n", " }\n", " for slice_name, slice_results in results[model_name].items()\n", "}" diff --git a/poetry.lock b/poetry.lock index a6e32588f..102922b3c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiofiles" @@ -2408,7 +2408,6 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, - {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -3156,16 +3155,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = 
"MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -3678,13 +3667,13 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>= [[package]] name = "nbconvert" -version = "7.13.0" +version = "7.14.2" description = "Converting Jupyter Notebooks" optional = false python-versions = ">=3.8" files = [ - {file = "nbconvert-7.13.0-py3-none-any.whl", hash = "sha256:22521cfcc10ba5755e44acb6a70d2bd8a891ce7aed6746481e10cd548b169e19"}, - {file = "nbconvert-7.13.0.tar.gz", hash = "sha256:c6f61c86fca5b28bd17f4f9a308248e59fa2b54919e1589f6cc3575c5dfec2bd"}, + {file = "nbconvert-7.14.2-py3-none-any.whl", hash = "sha256:db28590cef90f7faf2ebbc71acd402cbecf13d29176df728c0a9025a49345ea1"}, + {file = "nbconvert-7.14.2.tar.gz", hash = "sha256:a7f8808fd4e082431673ac538400218dd45efd076fbeb07cc6e5aa5a3a4e949e"}, ] [package.dependencies] @@ -3766,13 +3755,13 @@ toolchain = ["black", "blacken-docs", "flake8", "isort", "jupytext", "mypy", "py [[package]] name = "nbsphinx" -version = "0.8.12" +version = "0.9.3" description = "Jupyter Notebook Tools for Sphinx" optional = false python-versions = ">=3.6" files = [ - {file = "nbsphinx-0.8.12-py3-none-any.whl", hash = "sha256:c15b681c7fce287000856f91fe1edac50d29f7b0c15bbc746fbe55c8eb84750b"}, - {file = "nbsphinx-0.8.12.tar.gz", hash = "sha256:76570416cdecbeb21dbf5c3d6aa204ced6c1dd7ebef4077b5c21b8c6ece9533f"}, + {file = "nbsphinx-0.9.3-py3-none-any.whl", hash = "sha256:6e805e9627f4a358bd5720d5cbf8bf48853989c79af557afd91a5f22e163029f"}, + {file = "nbsphinx-0.9.3.tar.gz", hash = "sha256:ec339c8691b688f8676104a367a4b8cf3ea01fd089dc28d24dec22d563b11562"}, ] [package.dependencies] @@ -4081,12 +4070,10 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, - {version = ">=1.21.2", markers = "python_version >= \"3.10\""}, + {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""}, - {version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""}, - {version = ">=1.17.0", markers = "python_version >= \"3.7\""}, - {version = ">=1.17.3", markers = "python_version >= \"3.8\""}, + {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\""}, + {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, ] [[package]] @@ -4503,8 +4490,6 @@ files = [ {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, 
{file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, - {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, - {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, @@ -4982,7 +4967,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -4990,15 +4974,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - 
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -5015,7 +4992,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -5023,7 +4999,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -6387,7 +6362,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\")"} +greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"} [package.extras] aiomysql = 
["aiomysql (>=0.2.0)", "greenlet (!=0.4.17)"] @@ -7691,4 +7666,4 @@ xgboost = ["xgboost"] [metadata] lock-version = "2.0" python-versions = ">=3.9, <3.11" -content-hash = "61c6f31520b5669b67e41f7842662f90df98333a7b3dcd3117e87e2d988c7d48" +content-hash = "0c2ee6790e7c45da8e7b9f7d8dfb30f3579081be4da9df5600be577be1eeac07" diff --git a/pyproject.toml b/pyproject.toml index ef989e98c..589f7be4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ xgboost = { version = "^1.5.2", optional = true } alibi = { version = "^0.9.4", optional = true, extras = ["shap"] } alibi-detect = { version = "^0.11.0", optional = true, extras = ["torch"] } llvmlite = { version = "^0.40.0", optional = true } +nbsphinx = "^0.9.3" [tool.poetry.group.xgboost] optional = true @@ -116,7 +117,7 @@ sphinx-autodoc-typehints = "^1.24.0" myst-parser = "^2.0.0" sphinx-copybutton = "^0.5.0" sphinx-autoapi = "^2.0.0" -nbsphinx = "^0.8.11" +nbsphinx = "^0.9.3" ipython = "^8.8.0" ipykernel = "^6.23.0" kaggle = "^1.5.13" diff --git a/tests/cyclops/evaluate/metrics/experimental/test_average_precision.py b/tests/cyclops/evaluate/metrics/experimental/test_average_precision.py new file mode 100644 index 000000000..5d3d7704e --- /dev/null +++ b/tests/cyclops/evaluate/metrics/experimental/test_average_precision.py @@ -0,0 +1,503 @@ +"""Test average precision metric.""" +from functools import partial + +import array_api_compat as apc +import array_api_compat.torch +import numpy.array_api as anp +import pytest +import torch.utils.dlpack +from torchmetrics.functional.classification import ( + binary_average_precision as tm_binary_average_precision, +) +from torchmetrics.functional.classification import ( + multiclass_average_precision as tm_multiclass_average_precision, +) +from torchmetrics.functional.classification import ( + multilabel_average_precision as tm_multilabel_average_precision, +) + +from cyclops.evaluate.metrics.experimental.average_precision import ( + BinaryAveragePrecision, + MulticlassAveragePrecision, + MultilabelAveragePrecision, +) +from cyclops.evaluate.metrics.experimental.functional.average_precision import ( + binary_average_precision, + multiclass_average_precision, + multilabel_average_precision, +) +from cyclops.evaluate.metrics.experimental.utils.ops import to_int +from cyclops.evaluate.metrics.experimental.utils.validation import is_floating_point + +from ..conftest import NUM_CLASSES, NUM_LABELS +from .inputs import _binary_cases, _multiclass_cases, _multilabel_cases, _thresholds +from .testers import MetricTester, _inject_ignore_index + + +def _binary_average_precision_reference( + target, + preds, + thresholds, + ignore_index, +) -> torch.Tensor: + """Return the reference binary average precision.""" + return tm_binary_average_precision( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, + ) + + +class TestBinaryAveragePrecision(MetricTester): + """Test binary average precision function and class.""" + + atol = 2e-7 + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_average_precision_function_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test function for binary average precision using array_api arrays.""" + target, preds = 
inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=binary_average_precision, + metric_args={ + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_average_precision_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _binary_cases(xp=anp)[3:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_average_precision_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test class for binary average precision using array_api arrays.""" + target, preds = inputs + + if ( + preds.shape[1] == 1 + and is_floating_point(preds) + and not anp.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryAveragePrecision, + metric_args={ + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_average_precision_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _binary_cases(xp=array_api_compat.torch)[3:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_binary_average_precision_with_torch_tensors( + self, + inputs, + thresholds, + ignore_index, + ) -> None: + """Test binary average precision class with torch tensors.""" + target, preds = inputs + + if ( + preds.shape[1] == 1 + and is_floating_point(preds) + and not torch.all(to_int((preds >= 0)) * to_int((preds <= 1))) + ): + pytest.skip( + "When using 0-D logits, batch result will be different from local " + "result because the `sigmoid` operation may not be applied to each " + "batch (some values may be in [0, 1] and some may not).", + ) + + if ignore_index is not None: + if target.shape[1] == 1 and torch.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + 
self.run_metric_class_implementation_test( + target, + preds, + metric_class=BinaryAveragePrecision, + metric_args={ + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _binary_average_precision_reference, + thresholds=thresholds, + ignore_index=ignore_index, + ), + device=device, + use_device_for_ref=True, + ) + + +def _multiclass_average_precision_reference( + target, + preds, + num_classes=NUM_CLASSES, + thresholds=None, + average="macro", + ignore_index=None, +) -> torch.Tensor: + """Return the reference multiclass average precision.""" + if preds.ndim == 1 and is_floating_point(preds): + xp = apc.array_namespace(preds) + preds = xp.argmax(preds, axis=0) + + return tm_multiclass_average_precision( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_classes, + average=average, + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, + ) + + +class TestMulticlassAveragePrecision(MetricTester): + """Test multiclass average precision function and class.""" + + atol = 3e-8 + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multiclass_average_precision_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test function for multiclass average precision.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multiclass_average_precision, + metric_args={ + "num_classes": NUM_CLASSES, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + reference_metric=partial( + _multiclass_average_precision_reference, + average=average, + thresholds=thresholds, + ignore_index=ignore_index, + ), + ) + + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)[4:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_average_precision_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multiclass average precision.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and anp.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassAveragePrecision, + reference_metric=partial( + _multiclass_average_precision_reference, + average=average, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "num_classes": 
NUM_CLASSES, + "average": average, + "thresholds": thresholds, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=array_api_compat.torch)[4:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("average", [None, "none", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 1, -1]) + def test_multiclass_average_precision_class_with_torch_tensors( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multiclass average precision.""" + target, preds = inputs + + if ignore_index is not None: + if target.shape[1] == 1 and torch.any(target == ignore_index): + pytest.skip( + "When targets are single elements and 'ignore_index' in target " + "the function will raise an error because it will receive an " + "empty array after filtering out the 'ignore_index' values.", + ) + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MulticlassAveragePrecision, + reference_metric=partial( + _multiclass_average_precision_reference, + average=average, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "num_classes": NUM_CLASSES, + "thresholds": thresholds, + "average": average, + "ignore_index": ignore_index, + }, + device=device, + use_device_for_ref=True, + ) + + +def _multilabel_average_precision_reference( + preds, + target, + num_labels=NUM_LABELS, + thresholds=None, + average="macro", + ignore_index=None, +) -> torch.Tensor: + """Return the reference multilabel average precision.""" + return tm_multilabel_average_precision( + torch.utils.dlpack.from_dlpack(preds), + torch.utils.dlpack.from_dlpack(target), + num_labels, + average=average, + thresholds=torch.utils.dlpack.from_dlpack(thresholds) + if apc.is_array_api_obj(thresholds) + else thresholds, + ignore_index=ignore_index, + ) + + +class TestMultilabelAveragePrecision(MetricTester): + """Test multilabel average precision function and class.""" + + atol = 2e-7 + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_average_precision_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test function for multilabel average precision.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_function_implementation_test( + target, + preds, + metric_function=multilabel_average_precision, + reference_metric=partial( + _multilabel_average_precision_reference, + average=average, + thresholds=thresholds, + ignore_index=ignore_index, + ), + metric_args={ + "thresholds": thresholds, + "average": average, + "num_labels": NUM_LABELS, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=anp)[2:]) + @pytest.mark.parametrize("thresholds", _thresholds(xp=anp)) + @pytest.mark.parametrize("average", [None, "none", "micro", "macro", "weighted"]) + 
@pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_average_precision_class_with_numpy_array_api_arrays( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multilabel average precision.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelAveragePrecision, + reference_metric=partial( + _multilabel_average_precision_reference, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "thresholds": thresholds, + "average": average, + "num_labels": NUM_LABELS, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.integration_test() # machine for integration tests has GPU + @pytest.mark.parametrize("inputs", _multilabel_cases(xp=array_api_compat.torch)[2:]) + @pytest.mark.parametrize( + "thresholds", + _thresholds(xp=array_api_compat.torch), + ) + @pytest.mark.parametrize("average", [None, "none", "micro", "macro", "weighted"]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + def test_multilabel_average_precision_class_with_torch_tensors( + self, + inputs, + thresholds, + average, + ignore_index, + ) -> None: + """Test class for multilabel average precision.""" + target, preds = inputs + + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + + device = "cuda" if torch.cuda.is_available() else "cpu" + if isinstance(thresholds, torch.Tensor): + thresholds = thresholds.to(device) + + self.run_metric_class_implementation_test( + target, + preds, + metric_class=MultilabelAveragePrecision, + reference_metric=partial( + _multilabel_average_precision_reference, + thresholds=thresholds, + average=average, + ignore_index=ignore_index, + ), + metric_args={ + "thresholds": thresholds, + "average": average, + "num_labels": NUM_LABELS, + "ignore_index": ignore_index, + }, + device=device, + use_device_for_ref=True, + ) diff --git a/tests/cyclops/evaluate/metrics/experimental/test_precision_recall.py b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall.py index 8eb2a8e84..14c3c3a96 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_precision_recall.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_precision_recall.py @@ -328,6 +328,8 @@ def _multiclass_precision_recall_reference( class TestMulticlassPrecision(MetricTester): """Test multiclass precision metric class and function.""" + atol = 6e-8 + @pytest.mark.parametrize("inputs", _multiclass_cases(xp=anp)) @pytest.mark.parametrize("top_k", [1, 2]) @pytest.mark.parametrize("average", [None, "micro", "macro", "weighted"])