@@ -39,12 +39,12 @@ private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
39
39
40
40
/**
 * Parameter holding the split points used to map continuous features into buckets.
 * With n splits, there are n+1 buckets.
 * A bucket defined by splits x,y holds values in the range [x,y), i.e. the lower split is
 * included and the upper split is excluded.
 *
 * `Bucketizer.checkSplits` is passed to the `Param` constructor after the name and the
 * description (presumably as the validity check for the array - confirm against `Param`).
 * @group param
 */
45
45
val splits : Param [Array [Double ]] = new Param [Array [Double ]](this , " splits" ,
46
46
" Split points for mapping continuous features into buckets. With n splits, there are n+1" +
47
- " buckets. A bucket defined by splits x,y holds values in the range ( x,y] ." ,
47
+ " buckets. A bucket defined by splits x,y holds values in the range [ x,y) ." ,
48
48
Bucketizer .checkSplits)
49
49
50
50
/** @group getParam */
@@ -85,7 +85,8 @@ private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
85
85
transformSchema(dataset.schema)
86
86
val wrappedSplits = Array (Double .MinValue ) ++ $(splits) ++ Array (Double .MaxValue )
87
87
val bucketizer = udf { feature : Double =>
88
- Bucketizer .binarySearchForBuckets(wrappedSplits, feature) }
88
+ Bucketizer
89
+ .binarySearchForBuckets(wrappedSplits, feature, $(lowerInclusive), $(upperInclusive)) }
89
90
val newCol = bucketizer(dataset($(inputCol)))
90
91
val newField = prepOutputField(dataset.schema)
91
92
dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
@@ -95,7 +96,6 @@ private[ml] final class Bucketizer(override val parent: Estimator[Bucketizer])
95
96
val attr = new NominalAttribute (
96
97
name = Some ($(outputCol)),
97
98
isOrdinal = Some (true ),
98
- numValues = Some ($(splits).size),
99
99
values = Some ($(splits).map(_.toString)))
100
100
101
101
attr.toStructField()
@@ -131,20 +131,27 @@ object Bucketizer {
131
131
/**
 * Binary search over the bucket boundaries to find the bucket a data point falls into.
 *
 * Buckets are half-open: bucket `i` holds the values in `[splits(i), splits(i + 1))`.
 * Values outside `[splits.head, splits.last)` are either rejected or clamped into the
 * nearest edge bucket, depending on the two inclusive flags.
 *
 * @param splits bucket boundaries; at least two entries are required. The search assumes
 *               they are sorted in increasing order.
 * @param feature the value to place into a bucket.
 * @param lowerInclusive if true, a feature below `splits.head` is placed in the first bucket;
 *                       if false, such a feature is rejected.
 * @param upperInclusive if true, a feature at or above `splits.last` is placed in the last
 *                       bucket; if false, such a feature is rejected.
 * @return the zero-based index of the matching bucket, as a Double.
 * @throws IllegalArgumentException if `splits` has fewer than two entries.
 * @throws Exception if the feature is out of bounds and the corresponding inclusive flag is
 *                   false, or if no bucket matches (e.g. the feature is NaN).
 */
private[feature] def binarySearchForBuckets(
    splits: Array[Double],
    feature: Double,
    lowerInclusive: Boolean,
    upperInclusive: Boolean): Double = {
  require(splits.length >= 2, "At least two split points are needed to define a bucket.")
  val lastBucket = splits.length - 2

  def outOfBound(): Nothing = throw new Exception(
    s"Feature $feature out of bound, check your features or loosen the " +
      "lower/upper bound constraint.")

  if (feature < splits.head) {
    // Below every bucket: only acceptable when the lower bound is open-ended.
    if (lowerInclusive) 0.0 else outOfBound()
  } else if (feature >= splits.last) {
    // splits.last itself is excluded from the half-open last bucket, so it is handled here
    // together with everything above it.
    if (upperInclusive) lastBucket.toDouble else outOfBound()
  } else {
    var left = 0
    var right = lastBucket
    var found = -1
    while (found < 0 && left <= right) {
      val mid = left + (right - left) / 2
      val split = splits(mid)
      // The membership test comes first on purpose: NaN fails every comparison, keeps
      // moving `left` to the right, and ends up in the failure branch below.
      if (feature >= split && feature < splits(mid + 1)) {
        found = mid
      } else if (feature < split) {
        right = mid - 1
      } else {
        left = mid + 1
      }
    }
    if (found < 0) {
      throw new Exception(s"Failed to find a bucket for feature $feature.")
    }
    found.toDouble
  }
}
150
157
}
0 commit comments