Commit fb30d79

fix and test binary search
1 parent 2466322 commit fb30d79

2 files changed: +38 -10 lines changed


mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 15 additions & 8 deletions
@@ -39,12 +39,12 @@ private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
 
   /**
    * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
-   * A bucket defined by splits x,y holds values in the range (x,y].
+   * A bucket defined by splits x,y holds values in the range [x,y).
    * @group param
    */
   val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
     "Split points for mapping continuous features into buckets. With n splits, there are n+1" +
-      " buckets. A bucket defined by splits x,y holds values in the range (x,y].",
+      " buckets. A bucket defined by splits x,y holds values in the range [x,y).",
     Bucketizer.checkSplits)
 
   /** @group getParam */
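
Note on the doc change above: flipping (x,y] to [x,y) moves a value sitting exactly on a split into the bucket above it rather than the one below. A minimal sketch of the new inclusion rule (standalone Scala; bucketOf is local to this sketch and mirrors the suite's linear reference, not the patched API):

    // With splits (-0.5, 0.0, 0.5) padded by (MinValue, MaxValue), the buckets
    // under the new rule are [MinValue, -0.5), [-0.5, 0.0), [0.0, 0.5),
    // [0.5, MaxValue) -- a feature exactly on a split belongs upward.
    val splits = Array(-0.5, 0.0, 0.5)
    def bucketOf(v: Double): Int = {
      val i = splits.indexWhere(v < _)
      if (i == -1) splits.length else i
    }
    assert(bucketOf(-0.5) == 1) // on a split: upper bucket under [x, y)
    assert(bucketOf(-0.9) == 0) // below every split: lowest bucket
    assert(bucketOf(0.8) == 3)  // above every split: highest bucket

This is also why the expected bucket for -0.5 changes from 0.0 to 1.0 in the test diff below.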
@@ -85,7 +85,8 @@ private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
     transformSchema(dataset.schema)
     val wrappedSplits = Array(Double.MinValue) ++ $(splits) ++ Array(Double.MaxValue)
     val bucketizer = udf { feature: Double =>
-      Bucketizer.binarySearchForBuckets(wrappedSplits, feature) }
+      Bucketizer
+        .binarySearchForBuckets(wrappedSplits, feature, $(lowerInclusive), $(upperInclusive)) }
     val newCol = bucketizer(dataset($(inputCol)))
     val newField = prepOutputField(dataset.schema)
     dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
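
For context on the unchanged wrappedSplits line: padding the user splits with Double.MinValue and Double.MaxValue gives features outside [splits.head, splits.last] an outermost bucket to land in, subject to the new bounds check added below. ($(lowerInclusive) and $(upperInclusive) are Params presumably defined elsewhere in Bucketizer; they are not part of this diff.) A quick illustration of the padding arithmetic:

    val userSplits = Array(-0.5, 0.0, 0.5)
    val wrappedSplits = Array(Double.MinValue) ++ userSplits ++ Array(Double.MaxValue)
    // n = 3 user splits => n + 1 = 4 buckets:
    // [MinValue, -0.5), [-0.5, 0.0), [0.0, 0.5), [0.5, MaxValue)
    assert(wrappedSplits.length - 1 == userSplits.length + 1)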
@@ -95,7 +96,6 @@ private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
     val attr = new NominalAttribute(
       name = Some($(outputCol)),
       isOrdinal = Some(true),
-      numValues = Some($(splits).size),
       values = Some($(splits).map(_.toString)))
 
     attr.toStructField()
@@ -131,20 +131,27 @@ object Bucketizer {
   /**
    * Binary searching in several buckets to place each data point.
    */
-  private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
+  private[feature] def binarySearchForBuckets(
+      splits: Array[Double],
+      feature: Double,
+      lowerInclusive: Boolean,
+      upperInclusive: Boolean): Double = {
+    if ((feature < splits.head && !lowerInclusive) || (feature > splits.last && !upperInclusive))
+      throw new Exception(s"Feature $feature out of bounds, check your features or loosen the" +
+        s" lower/upper bound constraints.")
     var left = 0
     var right = splits.length - 2
     while (left <= right) {
       val mid = left + (right - left) / 2
       val split = splits(mid)
-      if ((feature > split) && (feature <= splits(mid + 1))) {
+      if ((feature >= split) && (feature < splits(mid + 1))) {
         return mid
-      } else if (feature <= split) {
+      } else if (feature < split) {
         right = mid - 1
       } else {
         left = mid + 1
       }
     }
-    throw new Exception("Failed to find a bucket.")
+    throw new Exception(s"Failed to find a bucket for feature $feature.")
   }
 }
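
In the patched search, left and right range over bucket indices 0 to splits.length - 2, because bucket i is the half-open interval [splits(i), splits(i + 1)). A self-contained sketch (script-style Scala; searchBucket is local to this sketch, with the new bounds check stripped) that traces the suite's wrapped splits:

    // wrapped = (MinValue, -0.5, 0.0, 0.5, MaxValue): 4 buckets, indices 0..3.
    // feature = 0.2: left = 0, right = 3
    //   mid = 1: 0.2 >= -0.5 but !(0.2 < 0.0) -> left = 2
    //   mid = 2: 0.2 >= 0.0 and 0.2 < 0.5     -> return 2
    def searchBucket(wrapped: Array[Double], feature: Double): Int = {
      var left = 0
      var right = wrapped.length - 2
      while (left <= right) {
        val mid = left + (right - left) / 2
        if (feature >= wrapped(mid) && feature < wrapped(mid + 1)) return mid
        else if (feature < wrapped(mid)) right = mid - 1
        else left = mid + 1
      }
      sys.error(s"Failed to find a bucket for feature $feature.")
    }
    val wrapped = Array(Double.MinValue, -0.5, 0.0, 0.5, Double.MaxValue)
    assert(searchBucket(wrapped, 0.2) == 2)
    assert(searchBucket(wrapped, -0.5) == 1) // split values land upward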

mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala

Lines changed: 23 additions & 2 deletions
@@ -17,18 +17,22 @@
 
 package org.apache.spark.ml.feature
 
+import scala.util.Random
+
 import org.scalatest.FunSuite
 
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
 
   test("Bucket continuous features with setter") {
     val sqlContext = new SQLContext(sc)
-    val data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4)
+    val data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4, -0.9)
     val buckets = Array(-0.5, 0.0, 0.5)
-    val bucketizedData = Array(2.0, 0.0, 2.0, 1.0, 3.0, 3.0, 1.0, 1.0)
+    val bucketizedData = Array(2.0, 1.0, 2.0, 1.0, 3.0, 3.0, 1.0, 1.0, 0.0)
     val dataFrame: DataFrame = sqlContext.createDataFrame(
       data.zip(bucketizedData)).toDF("feature", "expected")
 
@@ -44,6 +48,23 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
   }
 
   test("Binary search for finding buckets") {
+    val data = Array.fill[Double](100)(Random.nextDouble())
+    val splits = Array.fill[Double](10)(Random.nextDouble()).sorted
+    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
+    val bsResult = Vectors.dense(
+      data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
+    val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
+    assert(bsResult ~== lsResult absTol 1e-5)
+  }
+}
 
+object BucketizerSuite {
+  private def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
+    var i = 0
+    while (i < splits.size) {
+      if (feature < splits(i)) return i
+      i += 1
+    }
+    i
   }
 }
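
The new test cross-checks the O(log n) binary search against this O(n) linear reference on random draws. The two agree because linearSearchForBuckets returns the first index i with feature < splits(i), and that i is exactly the wrapped bucket [wrapped(i), wrapped(i + 1)) = [splits(i - 1), splits(i)). A standalone restatement of the property (script-style Scala, reusing the hypothetical searchBucket sketch above):

    import scala.util.Random

    // Linear reference over raw splits: first i with feature < splits(i),
    // or splits.length when the feature is above every split.
    def linearBucket(splits: Array[Double], feature: Double): Int = {
      var i = 0
      while (i < splits.length && feature >= splits(i)) i += 1
      i
    }
    val userSplits = Array.fill(10)(Random.nextDouble()).sorted
    val padded = Array(Double.MinValue) ++ userSplits ++ Array(Double.MaxValue)
    Array.fill(100)(Random.nextDouble()).foreach { x =>
      // searchBucket is the sketch accompanying the Bucketizer hunk above
      assert(searchBucket(padded, x) == linearBucket(userSplits, x))
    }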
