@@ -36,15 +36,21 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
36
36
(1.0 , 1.0 ), (1.0 , 1.0 ), (2.0 , 2.0 ), (2.0 , 0.0 )), 2 )
37
37
val metrics = new MulticlassMetrics (scoreAndLabels)
38
38
val delta = 0.0000001
39
- val precision0 = 2.0 / (2.0 + 1.0 )
40
- val precision1 = 3.0 / (3.0 + 1.0 )
41
- val precision2 = 1.0 / (1.0 + 1.0 )
42
- val recall0 = 2.0 / (2.0 + 2.0 )
43
- val recall1 = 3.0 / (3.0 + 1.0 )
44
- val recall2 = 1.0 / (1.0 + 0.0 )
39
+ val fpRate0 = 1.0 / (9 - 4 )
40
+ val fpRate1 = 1.0 / (9 - 4 )
41
+ val fpRate2 = 1.0 / (9 - 1 )
42
+ val precision0 = 2.0 / (2 + 1 )
43
+ val precision1 = 3.0 / (3 + 1 )
44
+ val precision2 = 1.0 / (1 + 1 )
45
+ val recall0 = 2.0 / (2 + 2 )
46
+ val recall1 = 3.0 / (3 + 1 )
47
+ val recall2 = 1.0 / (1 + 0 )
45
48
val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
46
49
val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
47
50
val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
51
+ assert(math.abs(metrics.falsePositiveRate(0.0 ) - fpRate0) < delta)
52
+ assert(math.abs(metrics.falsePositiveRate(1.0 ) - fpRate1) < delta)
53
+ assert(math.abs(metrics.falsePositiveRate(2.0 ) - fpRate2) < delta)
48
54
assert(math.abs(metrics.precision(0.0 ) - precision0) < delta)
49
55
assert(math.abs(metrics.precision(1.0 ) - precision1) < delta)
50
56
assert(math.abs(metrics.precision(2.0 ) - precision2) < delta)
@@ -55,16 +61,16 @@ class MulticlassMetricsSuite extends FunSuite with LocalSparkContext {
55
61
assert(math.abs(metrics.fMeasure(1.0 ) - f1measure1) < delta)
56
62
assert(math.abs(metrics.fMeasure(2.0 ) - f1measure2) < delta)
57
63
assert(math.abs(metrics.recall -
58
- (2.0 + 3.0 + 1.0 ) / ((2.0 + 3.0 + 1.0 ) + (1.0 + 1.0 + 1.0 ))) < delta)
64
+ (2.0 + 3.0 + 1.0 ) / ((2 + 3 + 1 ) + (1 + 1 + 1 ))) < delta)
59
65
assert(math.abs(metrics.recall - metrics.precision) < delta)
60
66
assert(math.abs(metrics.recall - metrics.fMeasure) < delta)
61
67
assert(math.abs(metrics.recall - metrics.weightedRecall) < delta)
62
68
assert(math.abs(metrics.weightedPrecision -
63
- ((4.0 / 9.0 ) * precision0 + (4.0 / 9.0 ) * precision1 + (1.0 / 9.0 ) * precision2)) < delta)
69
+ ((4.0 / 9 ) * precision0 + (4.0 / 9 ) * precision1 + (1.0 / 9 ) * precision2)) < delta)
64
70
assert(math.abs(metrics.weightedRecall -
65
- ((4.0 / 9.0 ) * recall0 + (4.0 / 9.0 ) * recall1 + (1.0 / 9.0 ) * recall2)) < delta)
66
- assert(math.abs(metrics.weightedF1Measure -
67
- ((4.0 / 9.0 ) * f1measure0 + (4.0 / 9.0 ) * f1measure1 + (1.0 / 9.0 ) * f1measure2)) < delta)
71
+ ((4.0 / 9 ) * recall0 + (4.0 / 9 ) * recall1 + (1.0 / 9 ) * recall2)) < delta)
72
+ assert(math.abs(metrics.weightedFMeasure -
73
+ ((4.0 / 9 ) * f1measure0 + (4.0 / 9 ) * f1measure1 + (1.0 / 9 ) * f1measure2)) < delta)
68
74
assert(metrics.labels.sameElements(labels))
69
75
}
70
76
}
0 commit comments