@@ -1605,6 +1605,82 @@ def test_weighted_pr_interpolation(self):
16051605 expected_result = 2.416 / 7 + 4 / 7
16061606 self .assertAllClose (self .evaluate (result ), expected_result , 1e-3 )
16071607
1608+ def test_weighted_pr_gain_majoring (self ):
1609+ self .setup ()
1610+ auc_obj = metrics .AUC (
1611+ num_thresholds = self .num_thresholds ,
1612+ curve = "PR_GAIN" ,
1613+ summation_method = "majoring" ,
1614+ )
1615+ self .evaluate (tf .compat .v1 .variables_initializer (auc_obj .variables ))
1616+ result = auc_obj (
1617+ self .y_true , self .y_pred , sample_weight = self .sample_weight
1618+ )
1619+
1620+ # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
1621+ # scaling_factor (P/N) = 7/3
1622+ # recall_gain = 1 - 7/3 [ 0/7, 3/4, 7/0 ] = [1, -3/4, -inf] -> [1, 0, 0]
1623+ # precision_gain = 1 - 7/3 [ 3/7, 0/4, 0/0 ] = [0, 1, 1]
1624+ # heights = [max(0, 1), max(1, 1)] = [1, 1]
1625+ # widths = [(1 - 0), (0 - 0)] = [1, 0]
1626+ expected_result = 1 * 1 + 0 * 1
1627+ self .assertAllClose (self .evaluate (result ), expected_result , 1e-3 )
1628+
1629+ def test_weighted_pr_gain_minoring (self ):
1630+ self .setup ()
1631+ auc_obj = metrics .AUC (
1632+ num_thresholds = self .num_thresholds ,
1633+ curve = "PR_GAIN" ,
1634+ summation_method = "minoring" ,
1635+ )
1636+ self .evaluate (tf .compat .v1 .variables_initializer (auc_obj .variables ))
1637+ result = auc_obj (
1638+ self .y_true , self .y_pred , sample_weight = self .sample_weight
1639+ )
1640+
1641+ # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
1642+ # scaling_factor (P/N) = 7/3
1643+ # recall_gain = 1 - 7/3 [ 0/7, 3/4, 7/0 ] = [1, -3/4, -inf] -> [1, 0, 0]
1644+ # precision_gain = 1 - 7/3 [ 3/7, 0/4, 0/0 ] = [0, 1, 1]
1645+ # heights = [min(0, 1), min(1, 1)] = [0, 1]
1646+ # widths = [(1 - 0), (0 - 0)] = [1, 0]
1647+ expected_result = 1 * 0 + 0 * 1
1648+ self .assertAllClose (self .evaluate (result ), expected_result , 1e-3 )
1649+
1650+ def test_weighted_pr_gain_interpolation (self ):
1651+ self .setup ()
1652+ auc_obj = metrics .AUC (num_thresholds = self .num_thresholds , curve = "PR_GAIN" )
1653+ self .evaluate (tf .compat .v1 .variables_initializer (auc_obj .variables ))
1654+ result = auc_obj (
1655+ self .y_true , self .y_pred , sample_weight = self .sample_weight
1656+ )
1657+
1658+ # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]
1659+ # scaling_factor (P/N) = 7/3
1660+ # recall_gain = 1 - 7/3 [ 0/7, 3/4, 7/0 ] = [1, -3/4, -inf] -> [1, 0, 0]
1661+ # precision_gain = 1 - 7/3 [ 3/7, 0/4, 0/0 ] = [0, 1, 1]
1662+ # heights = [(0+1)/2, (1+1)/2] = [0.5, 1]
1663+ # widths = [(1 - 0), (0 - 0)] = [1, 0]
1664+ expected_result = 1 * 0.5 + 0 * 1
1665+ self .assertAllClose (self .evaluate (result ), expected_result , 1e-3 )
1666+
1667+ def test_pr_gain_interpolation (self ):
1668+ self .setup ()
1669+ auc_obj = metrics .AUC (num_thresholds = self .num_thresholds , curve = "PR_GAIN" )
1670+ self .evaluate (tf .compat .v1 .variables_initializer (auc_obj .variables ))
1671+ y_true = tf .constant ([0 , 0 , 0 , 1 , 0 , 1 , 0 , 1 , 1 , 1 ])
1672+ y_pred = tf .constant ([0.1 , 0.2 , 0.3 , 0.3 , 0.4 , 0.4 , 0.6 , 0.6 , 0.8 , 0.9 ])
1673+ result = auc_obj ( y_true , y_pred )
1674+
1675+ # tp = [5, 3, 0], fp = [5, 1, 0], fn = [0, 2, 5], tn = [0, 4, 4]
1676+ # scaling_factor (P/N) = 5/5 = 1
1677+ # recall_gain = 1 - [ 0/5, 2/3, 5/0 ] = [1, 1/3, 0]
1678+ # precision_gain = 1 - [ 5/5, 1/3, 0/0 ] = [0, 2/3, 1]
1679+ # heights = [(0+2/3)/2, (2/3+1)/2] = [0.333, 0.833]
1680+ # widths = [(1 - 1/3), (1/3 - 0)] = [0.666, 0.333]
1681+ expected_result = 0.666666 * 0.3333333 + 0.3333333 * 0.8333333
1682+ self .assertAllClose (self .evaluate (result ), expected_result , 1e-3 )
1683+
16081684 def test_invalid_num_thresholds (self ):
16091685 with self .assertRaisesRegex (
16101686 ValueError , "Argument `num_thresholds` must be an integer > 1"
0 commit comments