Commit 35fb42a

yinxusen authored and jkbradley committed
[SPARK-5893] [ML] Add bucketizer
JIRA issue [here](https://issues.apache.org/jira/browse/SPARK-5893).

One thing to make clear: the `buckets` parameter, an array of `Double`, serves as the split points. For example,

```scala
buckets = Array(-0.5, 0.0, 0.5)
```

splits the real number line into 4 ranges, (-inf, -0.5], (-0.5, 0.0], (0.0, 0.5], (0.5, +inf), which are encoded as 0, 1, 2, 3.

Author: Xusen Yin <yinxusen@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes apache#5980 from yinxusen/SPARK-5893 and squashes the following commits:

dc8c843 [Xusen Yin] Merge pull request #4 from jkbradley/yinxusen-SPARK-5893
1ca973a [Joseph K. Bradley] one more bucketizer test
34f124a [Joseph K. Bradley] Removed lowerInclusive, upperInclusive params from Bucketizer, and used splits instead.
eacfcfa [Xusen Yin] change ML attribute from splits into buckets
c3cc770 [Xusen Yin] add more unit test for binary search
3a16cc2 [Xusen Yin] refine comments and names
ac77859 [Xusen Yin] fix style error
fb30d79 [Xusen Yin] fix and test binary search
2466322 [Xusen Yin] refactor Bucketizer
11fb00a [Xusen Yin] change it into an Estimator
998bc87 [Xusen Yin] check buckets
4024cf1 [Xusen Yin] add test suite
5fe190e [Xusen Yin] add bucketizer
1 parent 87229c9 commit 35fb42a
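
Note that the code as merged uses a `splits` parameter with left-inclusive `[x, y)` buckets and explicit infinite endpoints (see commit 34f124a above), rather than the `(x, y]` semantics described in the message. A minimal usage sketch of the merged API, assuming an existing `SparkContext` named `sc` and the Spark 1.4-era `SQLContext`:

```scala
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc) // assumes an existing SparkContext `sc`
val df = sqlContext.createDataFrame(
  Seq(-0.9, -0.5, -0.3, 0.0, 0.2).map(Tuple1.apply)).toDF("feature")

// Explicit -inf/inf endpoints cover all Double values; each bucket is [x, y).
val bucketizer = new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("bucket")
  .setSplits(Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity))

// Expected buckets: -0.9 -> 0.0, -0.5 -> 1.0, -0.3 -> 1.0, 0.0 -> 2.0, 0.2 -> 2.0
bucketizer.transform(df).show()
```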

File tree

3 files changed, +290 -0 lines

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

/**
 * :: AlphaComponent ::
 * `Bucketizer` maps a column of continuous features to a column of feature buckets.
 */
@AlphaComponent
final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
  extends Model[Bucketizer] with HasInputCol with HasOutputCol {

  def this() = this(null)

  /**
   * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets.
   * A bucket defined by splits x,y holds values in the range [x,y). Splits should be strictly
   * increasing. Values at -inf, inf must be explicitly provided to cover all Double values;
   * otherwise, values outside the splits specified will be treated as errors.
   * @group param
   */
  val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
    "Split points for mapping continuous features into buckets. With n+1 splits, there are n " +
      "buckets. A bucket defined by splits x,y holds values in the range [x,y). The splits " +
      "should be strictly increasing. Values at -inf, inf must be explicitly provided to cover" +
      " all Double values; otherwise, values outside the splits specified will be treated as" +
      " errors.",
    Bucketizer.checkSplits)

  /** @group getParam */
  def getSplits: Array[Double] = $(splits)

  /** @group setParam */
  def setSplits(value: Array[Double]): this.type = set(splits, value)

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: DataFrame): DataFrame = {
    transformSchema(dataset.schema)
    val bucketizer = udf { feature: Double =>
      Bucketizer.binarySearchForBuckets($(splits), feature)
    }
    val newCol = bucketizer(dataset($(inputCol)))
    val newField = prepOutputField(dataset.schema)
    dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
  }

  private def prepOutputField(schema: StructType): StructField = {
    val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
    val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true),
      values = Some(buckets))
    attr.toStructField()
  }

  override def transformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
    SchemaUtils.appendColumn(schema, prepOutputField(schema))
  }
}

private[feature] object Bucketizer {
  /** We require splits to be of length >= 3 and to be in strictly increasing order. */
  def checkSplits(splits: Array[Double]): Boolean = {
    if (splits.length < 3) {
      false
    } else {
      var i = 0
      while (i < splits.length - 1) {
        if (splits(i) >= splits(i + 1)) return false
        i += 1
      }
      true
    }
  }

  /**
   * Binary searching in several buckets to place each data point.
   * @throws RuntimeException if a feature is < splits.head or >= splits.last
   */
  def binarySearchForBuckets(
      splits: Array[Double],
      feature: Double): Double = {
    // Check bounds. We make an exception for +inf so that it can exist in some bin.
    if ((feature < splits.head) || (feature >= splits.last && feature != Double.PositiveInfinity)) {
      throw new RuntimeException(s"Feature value $feature out of Bucketizer bounds" +
        s" [${splits.head}, ${splits.last}). Check your features, or loosen " +
        s"the lower/upper bound constraints.")
    }
    var left = 0
    var right = splits.length - 2
    while (left < right) {
      val mid = (left + right) / 2
      val split = splits(mid + 1)
      if (feature < split) {
        right = mid
      } else {
        left = mid + 1
      }
    }
    left
  }
}
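
To make the boundary behavior of `binarySearchForBuckets` concrete, here is a hedged, self-contained sketch of the same search (the real method is `private[feature]`, so this standalone copy exists only for illustration):

```scala
object BucketSearchSketch {
  // Same invariant as Bucketizer.binarySearchForBuckets: bucket i holds [splits(i), splits(i + 1)).
  def search(splits: Array[Double], feature: Double): Int = {
    require(feature >= splits.head && (feature < splits.last || feature == Double.PositiveInfinity),
      s"feature $feature outside [${splits.head}, ${splits.last})")
    var left = 0
    var right = splits.length - 2 // candidate bucket indices are 0 .. splits.length - 2
    while (left < right) {
      val mid = (left + right) / 2
      if (feature < splits(mid + 1)) right = mid else left = mid + 1
    }
    left
  }

  def main(args: Array[String]): Unit = {
    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
    // Left-inclusive boundaries: -0.5 falls in bucket 1, 0.0 in bucket 2.
    assert(search(splits, -0.9) == 0)
    assert(search(splits, -0.5) == 1)
    assert(search(splits, 0.0) == 2)
    // +inf is special-cased into the last bucket.
    assert(search(splits, Double.PositiveInfinity) == 3)
    println("all boundary checks passed")
  }
}
```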

mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala

Lines changed: 11 additions & 0 deletions
@@ -58,4 +58,15 @@ object SchemaUtils {
    val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
    StructType(outputFields)
  }

  /**
   * Appends a new column to the input schema. This fails if the given output column already exists.
   * @param schema input schema
   * @param col New column schema
   * @return new schema with the input column appended
   */
  def appendColumn(schema: StructType, col: StructField): StructType = {
    require(!schema.fieldNames.contains(col.name), s"Column ${col.name} already exists.")
    StructType(schema.fields :+ col)
  }
}
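
For illustration, a hedged sketch of the duplicate-column guard in `appendColumn` (note that `SchemaUtils` is internal to Spark, so calling it like this assumes code compiled inside Spark's own source tree):

```scala
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

val schema = StructType(Seq(StructField("feature", DoubleType, nullable = false)))

// Appending a fresh column succeeds.
val extended = SchemaUtils.appendColumn(schema, StructField("bucket", DoubleType, nullable = false))
assert(extended.fieldNames.sameElements(Array("feature", "bucket")))

// Appending a column that already exists trips the require() check:
// SchemaUtils.appendColumn(extended, StructField("bucket", DoubleType))
// => java.lang.IllegalArgumentException: requirement failed: Column bucket already exists.
```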
mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import scala.util.Random

import org.scalatest.FunSuite

import org.apache.spark.SparkException
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

class BucketizerSuite extends FunSuite with MLlibTestSparkContext {

  @transient private var sqlContext: SQLContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sqlContext = new SQLContext(sc)
  }

  test("Bucket continuous features, without -inf,inf") {
    // Check a set of valid feature values.
    val splits = Array(-0.5, 0.0, 0.5)
    val validData = Array(-0.5, -0.3, 0.0, 0.2)
    val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0)
    val dataFrame: DataFrame =
      sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")

    val bucketizer: Bucketizer = new Bucketizer()
      .setInputCol("feature")
      .setOutputCol("result")
      .setSplits(splits)

    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
      case Row(x: Double, y: Double) =>
        assert(x === y,
          s"The feature value is not correct after bucketing. Expected $y but found $x")
    }

    // Check for exceptions when using a set of invalid feature values.
    val invalidData1: Array[Double] = Array(-0.9) ++ validData
    val invalidData2 = Array(0.5) ++ validData
    val badDF1 = sqlContext.createDataFrame(invalidData1.zipWithIndex).toDF("feature", "idx")
    intercept[RuntimeException] {
      bucketizer.transform(badDF1).collect()
      println("Invalid feature value -0.9 was not caught as an invalid feature!")
    }
    val badDF2 = sqlContext.createDataFrame(invalidData2.zipWithIndex).toDF("feature", "idx")
    intercept[RuntimeException] {
      bucketizer.transform(badDF2).collect()
      println("Invalid feature value 0.5 was not caught as an invalid feature!")
    }
  }

  test("Bucket continuous features, with -inf,inf") {
    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
    val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
    val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
    val dataFrame: DataFrame =
      sqlContext.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")

    val bucketizer: Bucketizer = new Bucketizer()
      .setInputCol("feature")
      .setOutputCol("result")
      .setSplits(splits)

    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
      case Row(x: Double, y: Double) =>
        assert(x === y,
          s"The feature value is not correct after bucketing. Expected $y but found $x")
    }
  }

  test("Binary search correctness on hand-picked examples") {
    import BucketizerSuite.checkBinarySearch
    // length 3, with -inf
    checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0))
    // length 4
    checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0))
    // length 5
    checkBinarySearch(Array(-1.0, -0.5, 0.0, 1.0, 1.5))
    // length 3, with inf
    checkBinarySearch(Array(0.0, 1.0, Double.PositiveInfinity))
    // length 3, with -inf and inf
    checkBinarySearch(Array(Double.NegativeInfinity, 1.0, Double.PositiveInfinity))
    // length 4, with -inf and inf
    checkBinarySearch(Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity))
  }

  test("Binary search correctness in contrast with linear search, on random data") {
    val data = Array.fill(100)(Random.nextDouble())
    val splits: Array[Double] = Double.NegativeInfinity +:
      Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity
    val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x)))
    val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
    assert(bsResult ~== lsResult absTol 1e-5)
  }
}

private object BucketizerSuite extends FunSuite {
  /** Brute force search for buckets. Bucket i is defined by the range [split(i), split(i+1)). */
  def linearSearchForBuckets(splits: Array[Double], feature: Double): Double = {
    require(feature >= splits.head)
    var i = 0
    while (i < splits.length - 1) {
      if (feature < splits(i + 1)) return i
      i += 1
    }
    throw new RuntimeException(
      s"linearSearchForBuckets failed to find bucket for feature value $feature")
  }

  /** Check all values in splits, plus values between all splits. */
  def checkBinarySearch(splits: Array[Double]): Unit = {
    def testFeature(feature: Double, expectedBucket: Double): Unit = {
      assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket,
        s"Expected feature value $feature to be in bucket $expectedBucket with splits:" +
          s" ${splits.mkString(", ")}")
    }
    var i = 0
    while (i < splits.length - 1) {
      testFeature(splits(i), i) // Split i should fall in bucket i.
      testFeature((splits(i) + splits(i + 1)) / 2, i) // Value between splits i,i+1 should be in i.
      i += 1
    }
    if (splits.last === Double.PositiveInfinity) {
      testFeature(Double.PositiveInfinity, splits.length - 2)
    }
  }
}
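
Finally, a small hedged sketch of the split validation that `Bucketizer.checkSplits` enforces (the method is package-private, so these calls compile as written only from within `org.apache.spark.ml.feature`):

```scala
// Assumes this code lives in package org.apache.spark.ml.feature.
assert(Bucketizer.checkSplits(Array(-0.5, 0.0, 0.5)))        // strictly increasing, length >= 3
assert(!Bucketizer.checkSplits(Array(0.0, 0.5)))             // too short: fewer than 3 splits
assert(!Bucketizer.checkSplits(Array(-0.5, 0.0, 0.0, 0.5)))  // not strictly increasing
assert(Bucketizer.checkSplits(                               // explicit infinities are valid splits
  Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity)))
```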
