Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion tf_keras/metrics/confusion_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,13 +1612,45 @@ def result(self):
)
x = fp_rate
y = recall
else: # curve == 'PR'.
elif self.curve == metrics_utils.AUCCurve.PR:
precision = tf.math.divide_no_nan(
self.true_positives,
tf.math.add(self.true_positives, self.false_positives),
)
x = recall
y = precision
else: # curve == 'PR_GAIN'.
# Due to the hyperbolic transform, this formula is less robust than
# ROC or PR values. In particular
# 1) Both measures diverge when there are no negative examples;
# 2) Both measures diverge when there are no true positives;
# 3) Recall gain becomes negative when the recall is lower than the label
# average (i.e. when more negative examples are classified positive
# than real positives).
#
# We ignore case 1 as it is easily communicated. For case 2 we set
# recall_gain to 0 and precision_gain to 1. For case 3 we set the
# recall_gain to 0. These fixes will result in an overestimation of
# the AUC for estimators that are anti-correlated with the label (at
# some thresholds).
#
# The scaling factor $\frac{P}{N}$ that is used to form both
# gain values.
scaling_factor = tf.math.divide_no_nan(
tf.math.add(self.true_positives, self.false_negatives),
tf.math.add(self.false_positives, self.true_negatives),
)

recall_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(self.false_negatives, self.true_positives)
precision_gain = 1.0 - scaling_factor * tf.math.divide_no_nan(self.false_positives, self.true_positives)
# Handle case 2.
recall_gain = tf.where(tf.equal(self.true_positives, 0.0), tf.zeros_like(recall_gain), recall_gain)
precision_gain = tf.where(tf.equal(self.true_positives, 0.0), tf.ones_like(precision_gain), precision_gain)
# Handle case 3.
recall_gain = tf.math.maximum(recall_gain, tf.zeros_like(recall_gain))

x = recall_gain
y = precision_gain

# Find the rectangle heights based on `summation_method`.
if (
Expand Down
76 changes: 76 additions & 0 deletions tf_keras/metrics/confusion_metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,6 +1605,82 @@ def test_weighted_pr_interpolation(self):
expected_result = 2.416 / 7 + 4 / 7
self.assertAllClose(self.evaluate(result), expected_result, 1e-3)

def test_weighted_pr_gain_majoring(self):
    """Weighted PR-gain AUC using the 'majoring' summation method.

    'Majoring' takes the larger of the two adjacent precision-gain
    values as the height of each rectangle.
    """
    self.setup()
    metric = metrics.AUC(
        num_thresholds=self.num_thresholds,
        curve="PR_GAIN",
        summation_method="majoring",
    )
    self.evaluate(tf.compat.v1.variables_initializer(metric.variables))
    output = metric(
        self.y_true, self.y_pred, sample_weight=self.sample_weight
    )

    # Confusion counts per threshold:
    #   tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
    # scaling factor P/N = 7/3
    # recall_gain    = 1 - 7/3 * [0/7, 3/4, 7/0]
    #                = [1, -3/4, -inf] -> clamped to [1, 0, 0]
    # precision_gain = 1 - 7/3 * [3/7, 0/4, 0/0] = [0, 1, 1]
    # heights = [max(0, 1), max(1, 1)] = [1, 1]
    # widths  = [(1 - 0), (0 - 0)]    = [1, 0]
    expected = 1 * 1 + 0 * 1
    self.assertAllClose(self.evaluate(output), expected, 1e-3)

def test_weighted_pr_gain_minoring(self):
    """Weighted PR-gain AUC using the 'minoring' summation method.

    'Minoring' takes the smaller of the two adjacent precision-gain
    values as the height of each rectangle.
    """
    self.setup()
    metric = metrics.AUC(
        num_thresholds=self.num_thresholds,
        curve="PR_GAIN",
        summation_method="minoring",
    )
    self.evaluate(tf.compat.v1.variables_initializer(metric.variables))
    output = metric(
        self.y_true, self.y_pred, sample_weight=self.sample_weight
    )

    # Confusion counts per threshold:
    #   tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
    # scaling factor P/N = 7/3
    # recall_gain    = 1 - 7/3 * [0/7, 3/4, 7/0]
    #                = [1, -3/4, -inf] -> clamped to [1, 0, 0]
    # precision_gain = 1 - 7/3 * [3/7, 0/4, 0/0] = [0, 1, 1]
    # heights = [min(0, 1), min(1, 1)] = [0, 1]
    # widths  = [(1 - 0), (0 - 0)]    = [1, 0]
    expected = 1 * 0 + 0 * 1
    self.assertAllClose(self.evaluate(output), expected, 1e-3)

def test_weighted_pr_gain_interpolation(self):
    """Weighted PR-gain AUC with the default (interpolation) summation."""
    self.setup()
    metric = metrics.AUC(
        num_thresholds=self.num_thresholds, curve="PR_GAIN"
    )
    self.evaluate(tf.compat.v1.variables_initializer(metric.variables))
    output = metric(
        self.y_true, self.y_pred, sample_weight=self.sample_weight
    )

    # Confusion counts per threshold:
    #   tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
    # scaling factor P/N = 7/3
    # recall_gain    = 1 - 7/3 * [0/7, 3/4, 7/0]
    #                = [1, -3/4, -inf] -> clamped to [1, 0, 0]
    # precision_gain = 1 - 7/3 * [3/7, 0/4, 0/0] = [0, 1, 1]
    # Trapezoid heights are midpoints of adjacent precision gains:
    # heights = [(0 + 1) / 2, (1 + 1) / 2] = [0.5, 1]
    # widths  = [(1 - 0), (0 - 0)]         = [1, 0]
    expected = 1 * 0.5 + 0 * 1
    self.assertAllClose(self.evaluate(output), expected, 1e-3)

def test_pr_gain_interpolation(self):
    """Unweighted PR-gain AUC with the default (interpolation) summation."""
    self.setup()
    metric = metrics.AUC(
        num_thresholds=self.num_thresholds, curve="PR_GAIN"
    )
    self.evaluate(tf.compat.v1.variables_initializer(metric.variables))
    labels = tf.constant([0, 0, 0, 1, 0, 1, 0, 1, 1, 1])
    scores = tf.constant(
        [0.1, 0.2, 0.3, 0.3, 0.4, 0.4, 0.6, 0.6, 0.8, 0.9]
    )
    output = metric(labels, scores)

    # Confusion counts per threshold:
    #   tp = [5, 3, 0], fp = [5, 1, 0], fn = [0, 2, 5], tn = [0, 4, 4]
    # scaling factor P/N = 5/5 = 1
    # recall_gain    = 1 - [0/5, 2/3, 5/0] = [1, 1/3, 0]
    # precision_gain = 1 - [5/5, 1/3, 0/0] = [0, 2/3, 1]
    # heights = [(0 + 2/3) / 2, (2/3 + 1) / 2] = [0.333, 0.833]
    # widths  = [(1 - 1/3), (1/3 - 0)]         = [0.666, 0.333]
    expected = 0.666666 * 0.3333333 + 0.3333333 * 0.8333333
    self.assertAllClose(self.evaluate(output), expected, 1e-3)

def test_invalid_num_thresholds(self):
with self.assertRaisesRegex(
ValueError, "Argument `num_thresholds` must be an integer > 1"
Expand Down
5 changes: 4 additions & 1 deletion tf_keras/utils/metrics_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,17 +237,20 @@ class AUCCurve(Enum):

ROC = "ROC"
PR = "PR"
PR_GAIN = "PR_GAIN"

@staticmethod
def from_str(key):
    """Return the `AUCCurve` member matching a string alias.

    Accepted aliases are "pr"/"PR", "roc"/"ROC", and
    "pr_gain"/"prgain"/"PR_GAIN"/"PRGAIN".

    Raises:
      ValueError: if `key` matches none of the known aliases.
    """
    # Tuple membership (rather than a dict lookup) keeps the original
    # behavior for unhashable keys: they simply fail to match.
    alias_table = (
        (("pr", "PR"), AUCCurve.PR),
        (("roc", "ROC"), AUCCurve.ROC),
        (("pr_gain", "prgain", "PR_GAIN", "PRGAIN"), AUCCurve.PR_GAIN),
    )
    for aliases, curve in alias_table:
        if key in aliases:
            return curve
    raise ValueError(
        f'Invalid AUC curve value: "{key}". '
        'Expected values are ["PR", "ROC", "PR_GAIN"]'
    )


Expand Down
Loading