Commit 34f124a (parent: eacfcfa)

Removed lowerInclusive, upperInclusive params from Bucketizer, and used splits instead.

3 files changed: 139 additions, 128 deletions

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 37 additions & 83 deletions
@@ -39,14 +39,17 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
 
   /**
    * Parameter for mapping continuous features into buckets. With n splits, there are n+1 buckets.
-   * A bucket defined by splits x,y holds values in the range [x,y). Note that the splits should be
-   * strictly increasing.
+   * A bucket defined by splits x,y holds values in the range [x,y). Splits should be strictly
+   * increasing. Values at -inf, inf must be explicitly provided to cover all Double values;
+   * otherwise, values outside the splits specified will be treated as errors.
    * @group param
    */
   val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
     "Split points for mapping continuous features into buckets. With n splits, there are n+1 " +
       "buckets. A bucket defined by splits x,y holds values in the range [x,y). The splits " +
-      "should be strictly increasing.",
+      "should be strictly increasing. Values at -inf, inf must be explicitly provided to cover" +
+      " all Double values; otherwise, values outside the splits specified will be treated as" +
+      " errors.",
     Bucketizer.checkSplits)
 
   /** @group getParam */
@@ -55,40 +58,6 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
   /** @group setParam */
   def setSplits(value: Array[Double]): this.type = set(splits, value)
 
-  /**
-   * An indicator of the inclusiveness of negative infinite. If true, then use implicit bin
-   * (-inf, getSplits.head). If false, then throw exception if values < getSplits.head are
-   * encountered.
-   * @group Param */
-  val lowerInclusive: BooleanParam = new BooleanParam(this, "lowerInclusive",
-    "An indicator of the inclusiveness of negative infinite. If true, then use implicit bin " +
-      "(-inf, getSplits.head). If false, then throw exception if values < getSplits.head are " +
-      "encountered.")
-  setDefault(lowerInclusive -> true)
-
-  /** @group getParam */
-  def getLowerInclusive: Boolean = $(lowerInclusive)
-
-  /** @group setParam */
-  def setLowerInclusive(value: Boolean): this.type = set(lowerInclusive, value)
-
-  /**
-   * An indicator of the inclusiveness of positive infinite. If true, then use implicit bin
-   * [getSplits.last, inf). If false, then throw exception if values > getSplits.last are
-   * encountered.
-   * @group Param */
-  val upperInclusive: BooleanParam = new BooleanParam(this, "upperInclusive",
-    "An indicator of the inclusiveness of positive infinite. If true, then use implicit bin " +
-      "[getSplits.last, inf). If false, then throw exception if values > getSplits.last are " +
-      "encountered.")
-  setDefault(upperInclusive -> true)
-
-  /** @group getParam */
-  def getUpperInclusive: Boolean = $(upperInclusive)
-
-  /** @group setParam */
-  def setUpperInclusive(value: Boolean): this.type = set(upperInclusive, value)
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
@@ -97,81 +66,66 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema)
-    val wrappedSplits = Array(Double.MinValue) ++ $(splits) ++ Array(Double.MaxValue)
     val bucketizer = udf { feature: Double =>
-      Bucketizer
-        .binarySearchForBuckets(wrappedSplits, feature, $(lowerInclusive), $(upperInclusive)) }
+      Bucketizer.binarySearchForBuckets($(splits), feature)
+    }
     val newCol = bucketizer(dataset($(inputCol)))
     val newField = prepOutputField(dataset.schema)
     dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
   }
 
   private def prepOutputField(schema: StructType): StructField = {
-    val innerRanges = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
-    val values = ($(lowerInclusive), $(upperInclusive)) match {
-      case (true, true) =>
-        Array(s"-inf, ${$(splits).head}") ++ innerRanges ++ Array(s"${$(splits).last}, inf")
-      case (true, false) => Array(s"-inf, ${$(splits).head}") ++ innerRanges
-      case (false, true) => innerRanges ++ Array(s"${$(splits).last}, inf")
-      case _ => innerRanges
-    }
-    val attr =
-      new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true), values = Some(values))
+    val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
+    val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true),
+      values = Some(buckets))
     attr.toStructField()
   }
 
   override def transformSchema(schema: StructType): StructType = {
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
-    require(schema.fields.forall(_.name != $(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    StructType(schema.fields :+ prepOutputField(schema))
+    SchemaUtils.appendColumn(schema, prepOutputField(schema))
   }
 }
 
 private[feature] object Bucketizer {
-  /**
-   * The given splits should match 1) its size is larger than zero; 2) it is ordered in a strictly
-   * increasing way.
-   */
-  private def checkSplits(splits: Array[Double]): Boolean = {
-    if (splits.size == 0) false
-    else if (splits.size == 1) true
-    else {
-      splits.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
-        if (validator && prevValue < currValue) {
-          (true, currValue)
-        } else {
-          (false, currValue)
-        }
-      }._1
+  /** We require splits to be of length >= 3 and to be in strictly increasing order. */
+  def checkSplits(splits: Array[Double]): Boolean = {
+    if (splits.length < 3) {
+      false
+    } else {
+      var i = 0
+      while (i < splits.length - 1) {
+        if (splits(i) >= splits(i + 1)) return false
+        i += 1
+      }
+      true
     }
   }
 
   /**
    * Binary searching in several buckets to place each data point.
+   * @throws RuntimeException if a feature is < splits.head or >= splits.last
    */
-  private[feature] def binarySearchForBuckets(
+  def binarySearchForBuckets(
       splits: Array[Double],
-      feature: Double,
-      lowerInclusive: Boolean,
-      upperInclusive: Boolean): Double = {
-    if ((feature < splits.head && !lowerInclusive) || (feature > splits.last && !upperInclusive)) {
-      throw new RuntimeException(s"Feature $feature out of bound, check your features or loosen " +
-        s"the lower/upper bound constraint.")
+      feature: Double): Double = {
+    // Check bounds. We make an exception for +inf so that it can exist in some bin.
+    if ((feature < splits.head) || (feature >= splits.last && feature != Double.PositiveInfinity)) {
+      throw new RuntimeException(s"Feature value $feature out of Bucketizer bounds" +
+        s" [${splits.head}, ${splits.last}). Check your features, or loosen " +
+        s"the lower/upper bound constraints.")
     }
     var left = 0
    var right = splits.length - 2
-    while (left <= right) {
-      val mid = left + (right - left) / 2
-      val split = splits(mid)
-      if ((feature >= split) && (feature < splits(mid + 1))) {
-        return mid
-      } else if (feature < split) {
-        right = mid - 1
+    while (left < right) {
+      val mid = (left + right) / 2
+      val split = splits(mid + 1)
+      if (feature < split) {
+        right = mid
      } else {
        left = mid + 1
      }
    }
-    throw new RuntimeException(s"Unexpected error: failed to find a bucket for feature $feature.")
+    left
  }
}
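
For reference, a minimal usage sketch of the reworked API, mirroring the updated test suite below. The SparkContext `sc`, the sample data, and the column names are illustrative assumptions:

    import org.apache.spark.ml.feature.Bucketizer
    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)
    // Splits must now include the infinities explicitly if every Double should land
    // in a bucket; otherwise, out-of-range values throw at transform time.
    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
    val df = sqlContext.createDataFrame(Seq(-0.9, 0.2, 0.9).map(Tuple1.apply)).toDF("feature")
    val bucketizer = new Bucketizer()
      .setInputCol("feature")
      .setOutputCol("result")
      .setSplits(splits)
    // Buckets are [-inf, -0.5), [-0.5, 0.0), [0.0, 0.5), [0.5, inf):
    // -0.9 -> 0.0, 0.2 -> 2.0, 0.9 -> 3.0
    bucketizer.transform(df).show()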

mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala

Lines changed: 11 additions & 0 deletions
@@ -58,4 +58,15 @@ object SchemaUtils {
     val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
     StructType(outputFields)
   }
+
+  /**
+   * Appends a new column to the input schema. This fails if the given output column already exists.
+   * @param schema input schema
+   * @param col New column schema
+   * @return new schema with the input column appended
+   */
+  def appendColumn(schema: StructType, col: StructField): StructType = {
+    require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.")
+    StructType(schema.fields :+ col)
+  }
 }
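
A quick sketch of the new helper's contract. The schema values here are hypothetical; the real call site is Bucketizer.transformSchema above:

    import org.apache.spark.ml.util.SchemaUtils
    import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

    val schema = StructType(Seq(StructField("feature", DoubleType, nullable = false)))
    // Appends the new field at the end of the schema.
    val appended = SchemaUtils.appendColumn(
      schema, StructField("result", DoubleType, nullable = false))
    // appended.fieldNames == Array("feature", "result")
    // Appending "result" a second time fails the require() with
    // "Column result already exists."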

mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala

Lines changed: 91 additions & 45 deletions
@@ -21,20 +21,28 @@ import scala.util.Random
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
 
-  test("Bucket continuous features with setter") {
-    val sqlContext = new SQLContext(sc)
-    val data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4, -0.9)
+  @transient private var sqlContext: SQLContext = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+  }
+
+  test("Bucket continuous features, without -inf,inf") {
+    // Check a set of valid feature values.
     val splits = Array(-0.5, 0.0, 0.5)
-    val bucketizedData = Array(2.0, 1.0, 2.0, 1.0, 3.0, 3.0, 1.0, 1.0, 0.0)
-    val dataFrame: DataFrame = sqlContext.createDataFrame(
-      data.zip(bucketizedData)).toDF("feature", "expected")
+    val validData = Array(-0.5, -0.3, 0.0, 0.2)
+    val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0)
+    val dataFrame: DataFrame =
+      sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
 
     val bucketizer: Bucketizer = new Bucketizer()
       .setInputCol("feature")
@@ -43,58 +51,96 @@ class BucketizerSuite extends FunSuite with MLlibTestSparkContext {
 
     bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
       case Row(x: Double, y: Double) =>
-        assert(x === y, "The feature value is not correct after bucketing.")
+        assert(x === y,
+          s"The feature value is not correct after bucketing. Expected $y but found $x")
     }
-  }
 
-  test("Binary search correctness in contrast with linear search") {
-    val data = Array.fill(100)(Random.nextDouble())
-    val splits = Array.fill(10)(Random.nextDouble()).sorted
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
-    val bsResult = Vectors.dense(
-      data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
-    val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
-    assert(bsResult ~== lsResult absTol 1e-5)
+    // Check for exceptions when using a set of invalid feature values.
+    val invalidData1: Array[Double] = Array(-0.9) ++ validData
+    val invalidData2 = Array(0.5) ++ validData
+    val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
+    intercept[RuntimeException]{
+      bucketizer.transform(badDF1).collect()
+      println("Invalid feature value -0.9 was not caught as an invalid feature!")
+    }
+    val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
+    intercept[RuntimeException]{
+      bucketizer.transform(badDF2).collect()
+      println("Invalid feature value 0.5 was not caught as an invalid feature!")
+    }
   }
 
-  test("Binary search of features at splits") {
-    val splits = Array.fill(10)(Random.nextDouble()).sorted
-    val data = splits
-    val expected = Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
-    val result = Vectors.dense(
-      data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
-    assert(result ~== expected absTol 1e-5)
+  test("Bucket continuous features, with -inf,inf") {
+    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
+    val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
+    val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
+    val dataFrame: DataFrame =
+      sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")
+
+    val bucketizer: Bucketizer = new Bucketizer()
+      .setInputCol("feature")
+      .setOutputCol("result")
+      .setSplits(splits)
+
+    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
+      case Row(x: Double, y: Double) =>
+        assert(x === y,
+          s"The feature value is not correct after bucketing. Expected $y but found $x")
+    }
   }
 
-  test("Binary search of features between splits") {
-    val data = Array.fill(10)(Random.nextDouble())
-    val splits = Array(-0.1, 1.1)
-    val expected = Vectors.dense(Array.fill(10)(1.0))
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
-    val result = Vectors.dense(
-      data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
-    assert(result ~== expected absTol 1e-5)
+  test("Binary search correctness on hand-picked examples") {
+    import BucketizerSuite.checkBinarySearch
+    // length 3, with -inf
+    checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0))
+    // length 4
+    checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0))
+    // length 5
+    checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0, 1.5))
+    // length 3, with inf
+    checkBinarySearch(Array(0.0, 1.0, Double.PositiveInfinity))
+    // length 3, with -inf and inf
+    checkBinarySearch(Array(Double.NegativeInfinity, 1.0, Double.PositiveInfinity))
  }
 
-  test("Binary search of features outside splits") {
-    val data = Array.fill(5)(Random.nextDouble() + 1.1) ++ Array.fill(5)(Random.nextDouble() - 1.1)
-    val splits = Array(0.0, 1.1)
-    val expected = Vectors.dense(Array.fill(5)(2.0) ++ Array.fill(5)(0.0))
-    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
-    val result = Vectors.dense(
-      data.map(x => Bucketizer.binarySearchForBuckets(wrappedSplits, x, true, true)))
-    assert(result ~== expected absTol 1e-5)
+  test("Binary search correctness in contrast with linear search, on random data") {
+    val data = Array.fill(100)(Random.nextDouble())
+    val splits: Array[Double] = Double.NegativeInfinity +:
+      Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity
+    val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x)))
+    val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
+    assert(bsResult ~== lsResult absTol 1e-5)
   }
 }
 
-private object BucketizerSuite {
-  private def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
+private object BucketizerSuite extends FunSuite {
+  /** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */
+  def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
+    require(feature >= splits.head)
     var i = 0
-    while (i < splits.size) {
-      if (feature < splits(i)) return i
+    while (i < splits.length - 1) {
+      if (feature < splits(i + 1)) return i
       i += 1
     }
-    i
+    throw new RuntimeException(
+      s"linearSearchForBuckets failed to find bucket for feature value $feature")
+  }
+
+  /** Check all values in splits, plus values between all splits. */
+  def checkBinarySearch(splits: Array[Double]): Unit = {
+    def testFeature(feature: Double, expectedBucket: Double): Unit = {
+      assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket,
+        s"Expected feature value $feature to be in bucket $expectedBucket with splits:" +
+        s" ${splits.mkString(", ")}")
+    }
+    var i = 0
+    while (i < splits.length - 1) {
+      testFeature(splits(i), i) // Split i should fall in bucket i.
+      testFeature((splits(i) + splits(i + 1)) / 2, i) // Value between splits i,i+1 should be in i.
+      i += 1
+    }
+    if (splits.last === Double.PositiveInfinity) {
+      testFeature(Double.PositiveInfinity, splits.length - 2)
+    }
   }
 }
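
Migration note: code that relied on the old lowerInclusive/upperInclusive defaults (both true) should either add Double.NegativeInfinity and Double.PositiveInfinity to its splits, as the tests above do, or drop out-of-range rows before transforming. A hedged sketch of the latter, assuming a DataFrame `df` with the input column "feature" and the `splits` and `bucketizer` values from the earlier sketch:

    // Keep only rows inside [splits.head, splits.last); anything else would make
    // Bucketizer.binarySearchForBuckets throw a RuntimeException at transform time.
    val inRange = df.filter(df("feature") >= splits.head && df("feature") < splits.last)
    val bucketed = bucketizer.transform(inRange)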
