Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 47 additions & 46 deletions src/multicalibration/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import sys

from collections.abc import Callable
from typing import Any, Dict, Protocol, Tuple
from typing import Any, Protocol

import numpy as np
import pandas as pd
Expand All @@ -27,6 +27,12 @@
CALIBRATION_ERROR_EPSILON = 0.0000001
DEFAULT_PRECISION_DTYPE = np.float64

# Kuiper distribution constants
# KUIPER_STATISTIC_MAX: Maximum statistic value before CDF is effectively 1.0
# KUIPER_STATISTIC_MIN: Minimum statistic value below which p-value is 1.0
KUIPER_STATISTIC_MAX: float = 8.26732673
KUIPER_STATISTIC_MIN: float = 1e-20


def _calibration_error(
labels: npt.NDArray,
Expand Down Expand Up @@ -195,17 +201,17 @@ def calibration_ratio(
1 - labels if adjust_unjoined else np.ones_like(predicted_scores)
)

calibration_ratio = np.sum(
ratio = np.sum(
predicted_scores * sample_weight * unjoined_adjustment_weights
) / np.sum(labels * sample_weight)
return calibration_ratio
return ratio


def recall(
labels: npt.NDArray,
predicted_labels: npt.NDArray,
sample_weight: npt.NDArray | None = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> float:
return skmetrics.recall_score(
y_true=labels.astype(int),
Expand All @@ -218,7 +224,7 @@ def precision(
labels: npt.NDArray,
predicted_labels: npt.NDArray,
precision_weight: npt.NDArray | None = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> float:
return skmetrics.precision_score(
y_true=labels.astype(int),
Expand Down Expand Up @@ -257,8 +263,8 @@ def fpr_with_mask(
if denominator == 0:
return None
fp_sr_idx = (y_pred & ~y_true & y_mask).astype(int) * sample_weight
fpr = 1.0 * fp_sr_idx.sum() / denominator
return fpr
false_positive_rate = 1.0 * fp_sr_idx.sum() / denominator
return false_positive_rate


def _dcg_sample_scores(
Expand Down Expand Up @@ -468,9 +474,9 @@ def fpr_at_precision(
threshold_at_precision_target = np.min(thresholds_at_target_precision)

false_positives = np.sum(y_scores[y_true == 0] >= threshold_at_precision_target)
fpr = false_positives / negatives
false_positive_rate = false_positives / negatives

return fpr
return false_positive_rate


def predictions_to_labels(
Expand Down Expand Up @@ -521,7 +527,7 @@ def multicalibration_error(

# Handle the case when there are no segmentation columns, in which case
# we compute the error for the entire dataset as a single segment
if len(segmentation_cols) == 0:
if not segmentation_cols:
return metric(
labels=labels,
predicted_scores=predictions,
Expand Down Expand Up @@ -594,10 +600,8 @@ def multi_cg_score(
k cannot be smaller than 1 and cannot be larger than the number of samples.
:return: an array of size n_segments with the metric score for each segment.
"""
assert metric in [
ndcg_score,
dcg_score,
], "Only ndcg_score and dcg_score are supported"
if metric not in (ndcg_score, dcg_score):
raise ValueError("Only ndcg_score and dcg_score are supported")
segments_df = segments_df.copy()
segmentation_cols = list(segments_df.columns)
segments_df["label"] = labels
Expand Down Expand Up @@ -931,11 +935,12 @@ def kuiper_distribution(x: float) -> float:
:return: cumulative distribution function evaluated at x
:rtype: float
"""
assert (
x > 0
), f"Can only evaluate cumulative Kuiper distribution at positive x, not at {x}"
if x <= 0:
raise ValueError(
f"Can only evaluate cumulative Kuiper distribution at positive x, not at {x}"
)
# If x goes to infinity, c tends to 1.0
if x >= 8.26732673:
if x >= KUIPER_STATISTIC_MAX:
return 1.0 - sys.float_info.epsilon

# Compute the machine precision assuming binary numerical representations.
Expand Down Expand Up @@ -967,7 +972,7 @@ def _normalization_method_assignment(
}
if method not in methods:
raise ValueError(
f"Unknown normalization method {method}. Available methods are {methods.keys()}"
f"Unknown normalization method {method}. Available methods are {list(methods)}"
)
return methods[method]

Expand All @@ -976,7 +981,7 @@ def kuiper_test(
labels: npt.NDArray,
predicted_scores: npt.NDArray,
sample_weight: npt.NDArray | None = None,
) -> Tuple[float, float]:
) -> tuple[float, float]:
"""
Calculates the Kuiper test statistic and p-value for the Kuiper calibration
distance. This test is used to assess how well the predicted probabilities
Expand All @@ -988,23 +993,20 @@ def kuiper_test(
:return: A tuple containing the Kuiper statistic and the corresponding p-value.
"""

KUIPER_MAX: float = 8.26732673
KUIPER_MIN: float = 1e-20

kuiper_statistic = kuiper_calibration(
kuiper_stat = kuiper_calibration(
labels,
predicted_scores,
sample_weight,
normalization_method="kuiper_standard_deviation",
)
if kuiper_statistic < KUIPER_MIN:
if kuiper_stat < KUIPER_STATISTIC_MIN:
pval = 1.0
elif kuiper_statistic > KUIPER_MAX:
elif kuiper_stat > KUIPER_STATISTIC_MAX:
pval = sys.float_info.epsilon
else:
pval = 1 - kuiper_distribution(kuiper_statistic)
pval = 1 - kuiper_distribution(kuiper_stat)

return kuiper_statistic, pval
return kuiper_stat, pval


def kuiper_statistic(
Expand Down Expand Up @@ -1316,7 +1318,8 @@ def _group_rank_calibration_error(
)
segment_RCE["weight"] = segment_RCE["segment_total_weight"] / len(labels)
segment_RCE["weighted_error"] = segment_RCE["error"] * segment_RCE["weight"]
assert np.allclose(segment_RCE["weight"].sum(), 1.0)
if not np.allclose(segment_RCE["weight"].sum(), 1.0):
raise AssertionError("Segment weights do not sum to 1.0")

return segment_RCE["weighted_error"].sum()

Expand Down Expand Up @@ -1397,8 +1400,8 @@ def normalized_entropy(
labels, baseline_predictions, sample_weight=sample_weight
)

normalized_entropy = prediction_log_loss / baseline_logloss
return normalized_entropy
ne = prediction_log_loss / baseline_logloss
return ne


def calibration_free_normalized_entropy(
Expand All @@ -1415,9 +1418,8 @@ def calibration_free_normalized_entropy(
:param predicted_scores: Predicted probabilities, as returned by a classifier's predict_proba method.
:returns: the calibration-free NE.
"""
assert (
len(labels.shape) == 1
), "y_pred must be the predicted probability for class 1 only."
if len(labels.shape) != 1:
raise ValueError("y_pred must be the predicted probability for class 1 only.")

current_calibration = calibration_ratio(labels, predicted_scores, sample_weight)

Expand Down Expand Up @@ -1770,18 +1772,17 @@ def wrap_multicalibration_error_metric(
max_n_segments: int | None = DEFAULT_MULTI_KUIPER_N_SEGMENTS,
metric_version: str = "mce",
) -> ScoreFunctionInterface:
assert (
categorical_segment_columns is not None or numerical_segment_columns is not None
), "No segment columns provided. Please provide either categorical_segment_columns or numerical_segment_columns."
assert (
metric_version
in [
"mce",
"mce_sigma_scale",
"mce_absolute",
"p_value",
]
), f"`metric_version` has to be one of ['mce', 'mce_sigma_scale', 'mce_absolute', 'p_value']. Got `{metric_version}`."
if categorical_segment_columns is None and numerical_segment_columns is None:
raise ValueError(
"No segment columns provided. Please provide either "
"categorical_segment_columns or numerical_segment_columns."
)
valid_versions = ("mce", "mce_sigma_scale", "mce_absolute", "p_value")
if metric_version not in valid_versions:
raise ValueError(
f"`metric_version` has to be one of {list(valid_versions)}. "
f"Got `{metric_version}`."
)

class WrappedFuncMCE(ScoreFunctionInterface):
name = f"Multicalibration Error<br>({metric_version})"
Expand Down
1 change: 0 additions & 1 deletion tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,6 @@ def test_multi_cg_gives_same_result_as_cg_per_segment(rank_discount, rng):
df["segment_2"] = rng.choice(["C", "D"], size=len(df))

min_segments_size = df.groupby(by=["segment_1", "segment_2"]).count().values.min()
print(f"{min_segments_size=}")
k = min(k, min_segments_size)

multi_cg_scores = metrics.multi_cg_score(
Expand Down