
Commit 998bc87

check buckets
1 parent 4024cf1 commit 998bc87

1 file changed

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 21 additions & 3 deletions
@@ -34,12 +34,30 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
 @AlphaComponent
 final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
 
+  /**
+   * The given buckets must satisfy two conditions: 1) the array is
+   * non-empty; 2) it is sorted in non-descending order.
+   */
+  private def checkBuckets(buckets: Array[Double]): Boolean = {
+    if (buckets.size == 0) false
+    else if (buckets.size == 1) true
+    else {
+      buckets.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
+        if (validator && prevValue <= currValue) {
+          (true, currValue)
+        } else {
+          (false, currValue)
+        }
+      }._1
+    }
+  }
+
   /**
    * Parameter for mapping continuous features into buckets.
    * @group param
    */
   val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets",
-    "Map continuous features into buckets.")
+    "Split points for mapping continuous features into buckets.", checkBuckets)
 
   /** @group getParam */
   def getBuckets: Array[Double] = $(buckets)
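
The new validator threads a (stillValid, previousValue) pair through the array with foldLeft: a single out-of-order pair flips the flag to false, and it never recovers. A minimal standalone sketch of the same check (the demo object and example inputs are illustrative, not part of the commit):

// Standalone sketch of checkBuckets' foldLeft validation.
object CheckBucketsDemo {
  def checkBuckets(buckets: Array[Double]): Boolean = {
    if (buckets.isEmpty) false          // rule 1: must be non-empty
    else if (buckets.length == 1) true  // a single split point is trivially ordered
    else {
      // Carry (stillValid, previousValue); once stillValid is false it stays false.
      buckets.foldLeft((true, Double.MinValue)) { case ((valid, prev), curr) =>
        (valid && prev <= curr, curr)
      }._1
    }
  }

  def main(args: Array[String]): Unit = {
    println(checkBuckets(Array(0.0, 0.5, 1.0)))  // true: non-descending
    println(checkBuckets(Array(0.0, 0.0, 1.0)))  // true: ties are allowed
    println(checkBuckets(Array(1.0, 0.5)))       // false: descending pair
    println(checkBuckets(Array.empty[Double]))   // false: empty array rejected
  }
}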
@@ -55,7 +73,7 @@ final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema)
-    val bucketizer = udf { feature: Double => binarySearchForBins($(buckets), feature) }
+    val bucketizer = udf { feature: Double => binarySearchForBuckets($(buckets), feature) }
     val outputColName = $(outputCol)
     val metadata = NominalAttribute.defaultAttr
       .withName(outputColName).withValues($(buckets).map(_.toString)).toMetadata()
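
For context, the per-row search is wrapped in a Spark SQL UDF so it can be applied to a whole column. A hedged sketch of that pattern outside the class, assuming a locally visible search function like the one sketched after the next hunk; the DataFrame, column names, and splits are illustrative, not from the commit:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.udf

// Sketch: apply a per-row bucketing function to a Double column.
// Assumes `df` has a Double column named "feature" (illustrative name).
def bucketize(df: DataFrame, splits: Array[Double]): DataFrame = {
  val toBucket = udf { feature: Double => binarySearchForBuckets(splits, feature) }
  df.withColumn("bucket", toBucket(df("feature")))
}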
@@ -65,7 +83,7 @@ final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
   /**
    * Binary search over the split points to place each data point into a bucket.
    */
-  private def binarySearchForBins(splits: Array[Double], feature: Double): Double = {
+  private def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
     val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
     var left = 0
     var right = wrappedSplits.length - 2
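
The hunk is truncated before the search loop itself. For illustration, one way such a search could be completed, assuming bucket i covers the half-open interval [wrappedSplits(i), wrappedSplits(i + 1)) -- a sketch, not the commit's actual loop body:

// Sketch: binary search over split points padded with sentinels.
// Assumption: bucket i covers [wrappedSplits(i), wrappedSplits(i + 1)).
def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
  val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
  var left = 0
  var right = wrappedSplits.length - 2
  while (left < right) {
    val mid = (left + right + 1) / 2  // round up so `left = mid` always progresses
    if (feature >= wrappedSplits(mid)) left = mid else right = mid - 1
  }
  left.toDouble
}

// With splits = Array(0.0, 10.0): -5.0 -> bucket 0.0, 3.0 -> 1.0, 42.0 -> 2.0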

0 commit comments
