@@ -34,12 +34,30 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
 @AlphaComponent
 final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
 
+  /**
+   * The given buckets must satisfy two conditions: 1) the array is non-empty; 2) its values are
+   * ordered in a non-descending way.
+   */
+  private def checkBuckets(buckets: Array[Double]): Boolean = {
+    if (buckets.size == 0) false
+    else if (buckets.size == 1) true
+    else {
+      buckets.foldLeft((true, Double.MinValue)) { case ((validator, prevValue), currValue) =>
+        if (validator && prevValue <= currValue) {
+          (true, currValue)
+        } else {
+          (false, currValue)
+        }
+      }._1
+    }
+  }
+
   /**
    * Parameter for mapping continuous features into buckets.
    * @group param
    */
   val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets",
-    "Map continuous features into buckets.")
+    "Split points for mapping continuous features into buckets.", checkBuckets)
 
   /** @group getParam */
   def getBuckets: Array[Double] = $(buckets)
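The validator fold threads a (stillValid, previousValue) pair through the array, so a single descending pair flips the flag to false for the remainder of the traversal. A minimal standalone sketch of the same contract, under an illustrative name that is not part of this patch:

  // Sketch only: non-empty and non-descending, mirroring checkBuckets above.
  def isNonDescending(buckets: Array[Double]): Boolean =
    buckets.nonEmpty && buckets.sliding(2).forall {
      case Array(a, b) => a <= b
      case _ => true // single-element window when buckets.length == 1
    }

  isNonDescending(Array(0.1, 0.5, 0.9)) // true
  isNonDescending(Array(0.5, 0.1))      // false: descending pair
  isNonDescending(Array.empty[Double])  // false: empty array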
@@ -55,7 +73,7 @@ final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema)
-    val bucketizer = udf { feature: Double => binarySearchForBins($(buckets), feature) }
+    val bucketizer = udf { feature: Double => binarySearchForBuckets($(buckets), feature) }
     val outputColName = $(outputCol)
     val metadata = NominalAttribute.defaultAttr
       .withName(outputColName).withValues($(buckets).map(_.toString)).toMetadata()
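The renamed UDF is what maps each input value to its bucket index when the transformer runs. A hypothetical end-to-end call, assuming the usual ML setter pattern (setInputCol/setOutputCol plus a setter for buckets, none of which appear in these hunks):

  // Hypothetical usage; setBuckets is assumed here and is not shown in this diff.
  val bucketizer = new Bucketizer()
    .setInputCol("feature")
    .setOutputCol("bucketed")
    .setBuckets(Array(0.0, 0.5, 1.0))
  val bucketed = bucketizer.transform(dataset) // appends a "bucketed" column of bucket indices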
@@ -65,7 +83,7 @@ final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {
   /**
    * Binary searching in several buckets to place each data point.
    */
-  private def binarySearchForBins(splits: Array[Double], feature: Double): Double = {
+  private def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
     val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
     var left = 0
     var right = wrappedSplits.length - 2
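The hunk ends before the search loop, so for reference here is a self-contained sketch of a binary search over the wrapped splits, under the assumption that bucket i covers the half-open range [wrappedSplits(i), wrappedSplits(i + 1)); the loop body is an illustration, not the code from this commit:

  // Sketch: locate the bucket whose half-open range holds the feature.
  def findBucket(splits: Array[Double], feature: Double): Double = {
    val wrappedSplits = Array(Double.MinValue) ++ splits ++ Array(Double.MaxValue)
    var left = 0
    var right = wrappedSplits.length - 2
    while (left < right) {
      val mid = (left + right) / 2
      if (feature < wrappedSplits(mid)) {
        right = mid - 1                  // bucket lies below mid
      } else if (feature >= wrappedSplits(mid + 1)) {
        left = mid + 1                   // bucket lies above mid
      } else {
        left = mid; right = mid          // feature falls inside bucket mid
      }
    }
    left.toDouble
  }

  findBucket(Array(0.25, 0.5, 0.75), 0.6) // 2.0: third bucket, [0.5, 0.75)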