@@ -21,20 +21,28 @@ import scala.util.Random
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
 
-  test("Bucket continuous features with setter") {
-    val sqlContext = new SQLContext(sc)
-    val data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4, -0.9)
+  @transient private var sqlContext: SQLContext = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+  }
+
+  test("Bucket continuous features, without -inf,inf") {
+    // Check a set of valid feature values.
     val splits = Array(-0.5, 0.0, 0.5)
-    val bucketizedData = Array(2.0, 1.0, 2.0, 1.0, 3.0, 3.0, 1.0, 1.0, 0.0)
-    val dataFrame: DataFrame = sqlContext.createDataFrame(
-      data.zip(bucketizedData)).toDF("feature", "expected")
+    val validData = Array(-0.5, -0.3, 0.0, 0.2)
+    val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0)
+    val dataFrame: DataFrame =
+      sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
 
     val bucketizer: Bucketizer = new Bucketizer()
       .setInputCol("feature")
@@ -43,58 +51,96 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
 
     bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
       case Row(x: Double, y: Double) =>
-        assert(x === y, "The feature value is not correct after bucketing.")
+        assert(x === y,
+          s"The feature value is not correct after bucketing. Expected $y but found $x")
     }
-  }
 
-  test("Binary search correctness in contrast with linear search") {
-    val data = Array.fill(100)(Random.nextDouble())
-    val splits = Array.fill(10)(Random.nextDouble()).sorted
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
-    val bsResult = Vectors.dense(
-      data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
-    val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
-    assert(bsResult ~== lsResult absTol 1e-5)
+    // Check for exceptions when using a set of invalid feature values.
+    val invalidData1: Array[Double] = Array(-0.9) ++ validData
+    val invalidData2 = Array(0.5) ++ validData
+    val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
+    intercept[RuntimeException] {
+      bucketizer.transform(badDF1).collect()
+      println("Invalid feature value -0.9 was not caught as an invalid feature!")
+    }
+    val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
+    intercept[RuntimeException] {
+      bucketizer.transform(badDF2).collect()
+      println("Invalid feature value 0.5 was not caught as an invalid feature!")
+    }
   }
 
-  test("Binary search of features at splits") {
-    val splits = Array.fill(10)(Random.nextDouble()).sorted
-    val data = splits
-    val expected = Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
-    val result = Vectors.dense(
-      data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
-    assert(result ~== expected absTol 1e-5)
+  test("Bucket continuous features, with -inf,inf") {
+    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
+    val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
+    val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
+    val dataFrame: DataFrame =
+      sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
+
+    val bucketizer: Bucketizer = new Bucketizer()
+      .setInputCol("feature")
+      .setOutputCol("result")
+      .setSplits(splits)
+
+    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
+      case Row(x: Double, y: Double) =>
+        assert(x === y,
+          s"The feature value is not correct after bucketing. Expected $y but found $x")
+    }
   }
 
-  test("Binary search of features between splits") {
-    val data = Array.fill(10)(Random.nextDouble())
-    val splits = Array(-0.1, 1.1)
-    val expected = Vectors.dense(Array.fill(10)(1.0))
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
-    val result = Vectors.dense(
-      data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
-    assert(result ~== expected absTol 1e-5)
+  test("Binary search correctness on hand-picked examples") {
+    import BucketizerSuite.checkBinarySearch
+    // length 3, with -inf
+    checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0))
+    // length 4
+    checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0))
+    // length 5
+    checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0, 1.5))
+    // length 3, with inf
+    checkBinarySearch(Array(0.0, 1.0, Double.PositiveInfinity))
+    // length 3, with -inf and inf
+    checkBinarySearch(Array(Double.NegativeInfinity, 1.0, Double.PositiveInfinity))
   }
 
-  test("Binary search of features outside splits") {
-    val data = Array.fill(5)(Random.nextDouble() + 1.1) ++ Array.fill(5)(Random.nextDouble() - 1.1)
-    val splits = Array(0.0, 1.1)
-    val expected = Vectors.dense(Array.fill(5)(2.0) ++ Array.fill(5)(0.0))
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
-    val result = Vectors.dense(
-      data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
-    assert(result ~== expected absTol 1e-5)
+  test("Binary search correctness in contrast with linear search, on random data") {
+    val data = Array.fill(100)(Random.nextDouble())
+    val splits: Array[Double] = Double.NegativeInfinity +:
+      Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity
+    val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x)))
+    val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
+    assert(bsResult ~== lsResult absTol 1e-5)
   }
 }
 
-private object BucketizerSuite {
-  private def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
+private object BucketizerSuite extends FunSuite {
+  /** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */
+  def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
+    require(feature >= splits.head)
     var i = 0
-    while (i < splits.size) {
-      if (feature < splits(i)) return i
+    while (i < splits.length - 1) {
+      if (feature < splits(i + 1)) return i
       i += 1
     }
-    i
+    throw new RuntimeException(
+      s"linearSearchForBuckets failed to find bucket for feature value $feature")
+  }
+
+  /** Check all values in splits, plus values between all splits. */
+  def checkBinarySearch(splits: Array[Double]): Unit = {
+    def testFeature(feature: Double, expectedBucket: Double): Unit = {
+      assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket,
+        s"Expected feature value $feature to be in bucket $expectedBucket with splits: " +
+          s"${splits.mkString(", ")}")
+    }
+    var i = 0
+    while (i < splits.length - 1) {
+      testFeature(splits(i), i)  // Split i should fall in bucket i.
+      testFeature((splits(i) + splits(i + 1)) / 2, i)  // Value between splits i,i+1 should be in i.
+      i += 1
+    }
+    if (splits.last === Double.PositiveInfinity) {
+      testFeature(Double.PositiveInfinity, splits.length - 2)
+    }
   }
 }
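
For reference, a minimal usage sketch of the semantics the new tests pin down: n+1 sorted splits define n buckets, with bucket i covering [splits(i), splits(i+1)), and feature values outside [splits.head, splits.last) raising an error unless the outer splits are -inf/+inf. The sketch assumes only what the diff itself exercises (the Bucketizer setters, SQLContext.createDataFrame on tuples, and a SparkContext `sc` such as MLlibTestSparkContext provides); the column names and values are illustrative, not part of the commit.

    // Sketch only: three buckets from four splits.
    val sqlContext = new SQLContext(sc)
    val df = sqlContext.createDataFrame(
      Seq((-0.7, 0), (0.2, 1), (0.6, 2))).toDF("feature", "idx")
    val bucketed = new Bucketizer()
      .setInputCol("feature")
      .setOutputCol("bucket")
      .setSplits(Array(Double.NegativeInfinity, 0.0, 0.5, Double.PositiveInfinity))
      .transform(df)
    // Expected mapping: -0.7 -> 0.0, 0.2 -> 1.0, 0.6 -> 2.0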