Skip to content

Commit 5fe190e

Browse files
committed
add bucketizer
1 parent 01187f5 commit 5fe190e

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
import org.apache.spark.annotation.AlphaComponent
21+
import org.apache.spark.ml.Transformer
22+
import org.apache.spark.ml.attribute.{NominalAttribute, BinaryAttribute}
23+
import org.apache.spark.ml.param._
24+
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
25+
import org.apache.spark.ml.util.SchemaUtils
26+
import org.apache.spark.sql._
27+
import org.apache.spark.sql.functions._
28+
import org.apache.spark.sql.types.{DoubleType, StructType}
29+
30+
/**
 * :: AlphaComponent ::
 * Maps a column of continuous features to a column of feature buckets, given a set of
 * split points. With n split points there are n + 1 buckets; a feature x is assigned
 * bucket index i when splits(i - 1) < x <= splits(i), and the two outermost buckets
 * are unbounded below/above the first/last split.
 */
@AlphaComponent
final class Bucketizer extends Transformer with HasInputCol with HasOutputCol {

  /**
   * Param for the split points used to map continuous features into buckets.
   * Splits must be provided in strictly increasing order.
   * @group param
   */
  val buckets: Param[Array[Double]] = new Param[Array[Double]](this, "buckets",
    "Split points for mapping continuous features into buckets")

  /** @group getParam */
  def getBuckets: Array[Double] = $(buckets)

  /** @group setParam */
  def setBuckets(value: Array[Double]): this.type = set(buckets, value)

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: DataFrame): DataFrame = {
    transformSchema(dataset.schema)
    val bucketizer = udf { feature: Double => binarySearchForBins($(buckets), feature) }
    val outputColName = $(outputCol)
    // Attach nominal metadata so downstream stages treat the bucket index as categorical.
    val metadata = NominalAttribute.defaultAttr
      .withName(outputColName).withValues($(buckets).map(_.toString)).toMetadata()
    dataset.select(col("*"), bucketizer(dataset($(inputCol))).as(outputColName, metadata))
  }

  /**
   * Binary search over the bucket boundaries to find the bucket index for a data point.
   *
   * @param splits the user-supplied split points, in increasing order
   * @param feature the feature value to place
   * @return the zero-based bucket index as a Double, or -1.0 if no bucket matches
   *         (only possible for NaN features, for which no comparison holds)
   */
  private def binarySearchForBins(splits: Array[Double], feature: Double): Double = {
    // Bracket the splits with infinities so the outermost buckets are unbounded.
    // Double.MinValue is a finite value in Scala, so using it here would wrongly
    // exclude a feature exactly equal to Double.MinValue from the first bucket.
    val wrappedSplits = Array(Double.NegativeInfinity) ++ splits ++ Array(Double.PositiveInfinity)
    var result = -1.0
    var left = 0
    var right = wrappedSplits.length - 2
    while (result < 0 && left <= right) {
      val mid = left + (right - left) / 2
      if ((feature > wrappedSplits(mid)) && (feature <= wrappedSplits(mid + 1))) {
        result = mid
      } else if (feature <= wrappedSplits(mid)) {
        right = mid - 1
      } else {
        left = mid + 1
      }
    }
    result
  }

  override def transformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)

    val inputFields = schema.fields
    val outputColName = $(outputCol)

    require(inputFields.forall(_.name != outputColName),
      s"Output column $outputColName already exists.")

    val attr = NominalAttribute.defaultAttr.withName(outputColName)
    val outputFields = inputFields :+ attr.toStructField()
    StructType(outputFields)
  }
}

0 commit comments

Comments
 (0)