Skip to content

[SPARK-5886][ML] Add StringIndexer as a feature transformer #4735

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.collection.OpenHashMap

/**
 * Shared params and schema handling for [[StringIndexer]] and [[StringIndexerModel]].
 */
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {

  /**
   * Validates the input schema and derives the output schema.
   *
   * The input column is checked via `checkInputColumn` against [[StringType]], the output column
   * name must not clash with an existing field, and the returned schema is the input schema with
   * one extra field built from a default [[NominalAttribute]] named after the output column.
   *
   * @param schema input schema
   * @param paramMap extra params, combined with this instance's own param map via `++`
   * @return the input fields followed by the new output field
   */
  protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val merged = this.paramMap ++ paramMap
    checkInputColumn(schema, merged(inputCol), StringType)
    val existingFields = schema.fields
    val outputColName = merged(outputCol)
    require(!existingFields.exists(_.name == outputColName),
      s"Output column $outputColName already exists.")
    val outputField = NominalAttribute.defaultAttr.withName(outputColName).toStructField()
    StructType(existingFields :+ outputField)
  }
}

/**
 * :: AlphaComponent ::
 * A label indexer that maps a string column of labels to an ML column of label indices.
 * The indices are in [0, numLabels), ordered by label frequencies.
 * So the most frequent label gets index 0.
 * Labels with the same frequency are ordered by their natural string order (a null label, if
 * present, comes first among its ties), so the fitted indices are reproducible.
 */
@AlphaComponent
class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase {

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  // TODO: handle unseen labels

  /**
   * Counts every distinct value of the input column and builds a model whose label array is
   * sorted by descending count.
   *
   * @param dataset dataset containing the string input column
   * @param paramMap extra params, combined with this instance's own param map via `++`
   * @return a [[StringIndexerModel]] holding the ordered labels
   */
  override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = {
    val map = this.paramMap ++ paramMap
    val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue()
    // Sorting on the count alone leaves equally frequent labels in the iteration order of the
    // counts map, which is not stable. The label is used as a secondary key; wrapping it in
    // Option keeps the comparison safe when the column contains a null value.
    val labels = counts.toSeq
      .sortBy { case (label, count) => (-count, Option(label)) }
      .map(_._1)
      .toArray
    val model = new StringIndexerModel(this, map, labels)
    Params.inheritValues(map, this, model)
    model
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    validateAndTransformSchema(schema, paramMap)
  }
}

/**
 * :: AlphaComponent ::
 * Model fitted by [[StringIndexer]].
 *
 * @param parent the estimator that produced this model
 * @param fittingParamMap the param map used for fitting
 * @param labels labels in index order: the label at array position `i` is mapped to index `i`
 */
@AlphaComponent
class StringIndexerModel private[ml] (
    override val parent: StringIndexer,
    override val fittingParamMap: ParamMap,
    labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {

  /** Lookup table from label to its index, stored as a Double. */
  private val labelToIndex: OpenHashMap[String, Double] = {
    // NOTE(review): OpenHashSet has historically required initialCapacity >= 1 (relaxed by
    // SPARK-18200), so an empty label array would fail here. Clamp the capacity so that a model
    // fitted on an empty dataset can still be constructed - confirm against the Spark version
    // this is built with.
    val map = new OpenHashMap[String, Double](math.max(labels.length, 1))
    labels.iterator.zipWithIndex.foreach { case (label, index) =>
      map.update(label, index.toDouble)
    }
    map
  }

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Appends the output column holding the index of each input label.
   *
   * The output column carries nominal-attribute metadata with the column name and the label
   * values. The indexing UDF throws a [[SparkException]] for a label that is not in `labels`.
   *
   * @param dataset dataset containing the string input column
   * @param paramMap extra params, combined with this instance's own param map via `++`
   * @return all columns of `dataset` plus the indexed output column
   */
  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
    val map = this.paramMap ++ paramMap
    val indexer = udf { label: String =>
      if (labelToIndex.contains(label)) {
        labelToIndex(label)
      } else {
        // TODO: handle unseen labels
        throw new SparkException(s"Unseen label: $label.")
      }
    }
    val outputColName = map(outputCol)
    val metadata = NominalAttribute.defaultAttr
      .withName(outputColName).withValues(labels).toStructField().metadata
    dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata))
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    validateAndTransformSchema(schema, paramMap)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.scalatest.FunSuite

import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.SQLContext

class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
  private var sqlContext: SQLContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sqlContext = new SQLContext(sc)
  }

  test("StringIndexer") {
    // Label frequencies: a -> 3, c -> 2, b -> 1 (no ties).
    val rows = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
    val df = sqlContext.createDataFrame(sc.parallelize(rows, 2)).toDF("id", "label")

    val estimator = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
    val model = estimator.fit(df)
    val transformed = model.transform(df)

    // The output column metadata must list the labels in descending frequency order.
    val labelIndexField = transformed.schema("labelIndex")
    val attr = Attribute.fromStructField(labelIndexField).asInstanceOf[NominalAttribute]
    assert(attr.values.get === Array("a", "c", "b"))

    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
    val output = transformed.select("id", "labelIndex")
      .map(row => (row.getInt(0), row.getDouble(1)))
      .collect()
      .toSet
    assert(output === expected)
  }
}