Skip to content

Commit 49ea494

Browse files
Backport implementation of PR-gain AUC metrics based on http://people.cs.bris.ac.uk/~flach/PRGcurves/PRcurves.pdf
PiperOrigin-RevId: 798108522
1 parent c79cc0e commit 49ea494

File tree

3 files changed

+113
-2
lines changed

3 files changed

+113
-2
lines changed

tf_keras/metrics/confusion_metrics.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1612,13 +1612,45 @@ def result(self):
16121612
)
16131613
x = fp_rate
16141614
y = recall
1615-
else: # curve == 'PR'.
1615+
elif self.curve == metrics_utils.AUCCurve.PR:
16161616
precision = tf.math.divide_no_nan(
16171617
self.true_positives,
16181618
tf.math.add(self.true_positives, self.false_positives),
16191619
)
16201620
x = recall
16211621
y = precision
1622+
else: # curve == 'PR_GAIN'.
1623+
# Due to the hyperbolic transform, this formula is less robust than
1624+
# ROC or PR values. In particular
1625+
# 1) Both measures diverge when there are no negative examples;
1626+
# 2) Both measures diverge when there are no true positives;
1627+
# 3) Recall gain becomes negative when the recall is lower than the label
1628+
# average (i.e. when more negative examples are classified positive
1629+
# than real positives).
1630+
#
1631+
# We ignore case 1 as it is easily communicated. For case 2 we set
1632+
# recall_gain to 0 and precision_gain to 1. For case 3 we set the
1633+
# recall_gain to 0. These fixes will result in an overestimation of
1634+
# the AUC for estimators that are anti-correlated with the label (at
1635+
# some thresholds).
1636+
#
1637+
# The scaling factor $\frac{P}{N}$ that is used to form both
1638+
# gain values.
1639+
scaling_factor = tf.math.divide_no_nan(
1640+
tf.math.add(self.true_positives, self.false_negatives),
1641+
tf.math.add(self.false_positives, self.true_negatives),
1642+
)
1643+
1644+
recall_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(self.false_negatives, self.true_positives)
1645+
precision_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(self.false_positives, self.true_positives)
1646+
# Handle case 2.
1647+
recall_gain = tf.where(tf.equal(self.true_positives, 0.0), tf.zeros_like(recall_gain), recall_gain)
1648+
precision_gain = tf.where(tf.equal(self.true_positives, 0.0), tf.ones_like(precision_gain), precision_gain)
1649+
# Handle case 3.
1650+
recall_gain = tf.math.maximum(recall_gain, tf.zeros_like(recall_gain))
1651+
1652+
x = recall_gain
1653+
y = precision_gain
16221654

16231655
# Find the rectangle heights based on `summation_method`.
16241656
if (

tf_keras/metrics/confusion_metrics_test.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,6 +1605,82 @@ def test_weighted_pr_interpolation(self):
16051605
expected_result = 2.416 / 7 + 4 / 7
16061606
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
16071607

1608+
def test_weighted_pr_gain_majoring(self):
1609+
self.setup()
1610+
auc_obj = metrics.AUC(
1611+
num_thresholds=self.num_thresholds,
1612+
curve="PR_GAIN",
1613+
summation_method="majoring",
1614+
)
1615+
self.evaluate(tf.compat.v1.variables_initializer(auc_obj.variables))
1616+
result = auc_obj(
1617+
self.y_true, self.y_pred, sample_weight=self.sample_weight
1618+
)
1619+
1620+
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
1621+
# scaling_factor (P/N) = 7/3
1622+
# recall_gain = 1 - 7/3 [ 0/7, 3/4, 7/0 ] = [1, -3/4, -inf] -> [1, 0, 0]
1623+
# precision_gain = 1 - 7/3 [ 3/7, 0/4, 0/0 ] = [0, 1, 1]
1624+
# heights = [max(0, 1), max(1, 1)] = [1, 1]
1625+
# widths = [(1 - 0), (0 - 0)] = [1, 0]
1626+
expected_result = 1 * 1 + 0 * 1
1627+
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
1628+
1629+
def test_weighted_pr_gain_minoring(self):
1630+
self.setup()
1631+
auc_obj = metrics.AUC(
1632+
num_thresholds=self.num_thresholds,
1633+
curve="PR_GAIN",
1634+
summation_method="minoring",
1635+
)
1636+
self.evaluate(tf.compat.v1.variables_initializer(auc_obj.variables))
1637+
result = auc_obj(
1638+
self.y_true, self.y_pred, sample_weight=self.sample_weight
1639+
)
1640+
1641+
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
1642+
# scaling_factor (P/N) = 7/3
1643+
# recall_gain = 1 - 7/3 [ 0/7, 3/4, 7/0 ] = [1, -3/4, -inf] -> [1, 0, 0]
1644+
# precision_gain = 1 - 7/3 [ 3/7, 0/4, 0/0 ] = [0, 1, 1]
1645+
# heights = [min(0, 1), min(1, 1)] = [0, 1]
1646+
# widths = [(1 - 0), (0 - 0)] = [1, 0]
1647+
expected_result = 1 * 0 + 0 * 1
1648+
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
1649+
1650+
def test_weighted_pr_gain_interpolation(self):
1651+
self.setup()
1652+
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve="PR_GAIN")
1653+
self.evaluate(tf.compat.v1.variables_initializer(auc_obj.variables))
1654+
result = auc_obj(
1655+
self.y_true, self.y_pred, sample_weight=self.sample_weight
1656+
)
1657+
1658+
# tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
1659+
# scaling_factor (P/N) = 7/3
1660+
# recall_gain = 1 - 7/3 [ 0/7, 3/4, 7/0 ] = [1, -3/4, -inf] -> [1, 0, 0]
1661+
# precision_gain = 1 - 7/3 [ 3/7, 0/4, 0/0 ] = [0, 1, 1]
1662+
# heights = [(0+1)/2, (1+1)/2] = [0.5, 1]
1663+
# widths = [(1 - 0), (0 - 0)] = [1, 0]
1664+
expected_result = 1 * 0.5 + 0 * 1
1665+
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
1666+
1667+
def test_pr_gain_interpolation(self):
1668+
self.setup()
1669+
auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve="PR_GAIN")
1670+
self.evaluate(tf.compat.v1.variables_initializer(auc_obj.variables))
1671+
y_true = tf.constant([0, 0, 0, 1, 0, 1, 0, 1, 1, 1])
1672+
y_pred = tf.constant([0.1, 0.2, 0.3, 0.3, 0.4, 0.4, 0.6, 0.6, 0.8, 0.9])
1673+
result = auc_obj( y_true, y_pred)
1674+
1675+
# tp = [5, 3, 0], fp = [5, 1, 0], fn = [0, 2, 5], tn = [0, 4, 4]
1676+
# scaling_factor (P/N) = 5/5 = 1
1677+
# recall_gain = 1 - [ 0/5, 2/3, 5/0 ] = [1, 1/3, 0]
1678+
# precision_gain = 1 - [ 5/5, 1/3, 0/0 ] = [0, 2/3, 1]
1679+
# heights = [(0+2/3)/2, (2/3+1)/2] = [0.333, 0.833]
1680+
# widths = [(1 - 1/3), (1/3 - 0)] = [0.666, 0.333]
1681+
expected_result = 0.666666 * 0.3333333 + 0.3333333 * 0.8333333
1682+
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)
1683+
16081684
def test_invalid_num_thresholds(self):
16091685
with self.assertRaisesRegex(
16101686
ValueError, "Argument `num_thresholds` must be an integer > 1"

tf_keras/utils/metrics_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,17 +237,20 @@ class AUCCurve(Enum):
237237

238238
ROC = "ROC"
239239
PR = "PR"
240+
PR_GAIN = "PR_GAIN"
240241

241242
@staticmethod
242243
def from_str(key):
243244
if key in ("pr", "PR"):
244245
return AUCCurve.PR
245246
elif key in ("roc", "ROC"):
246247
return AUCCurve.ROC
248+
elif key in ("pr_gain", "prgain", "PR_GAIN", "PRGAIN"):
249+
return AUCCurve.PR_GAIN
247250
else:
248251
raise ValueError(
249252
f'Invalid AUC curve value: "{key}". '
250-
'Expected values are ["PR", "ROC"]'
253+
'Expected values are ["PR", "ROC", "PR_GAIN"]'
251254
)
252255

253256

0 commit comments

Comments
 (0)