Skip to content

Commit e112394

Browse files
committed
make StringIndexerModel silent if input column does not exist
1 parent ad06727 commit e112394

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -112,6 +112,12 @@ class StringIndexerModel private[ml] (
112112
def setOutputCol(value: String): this.type = set(outputCol, value)
113113

114114
override def transform(dataset: DataFrame): DataFrame = {
115+
if (!dataset.schema.fieldNames.contains($(inputCol))) {
116+
logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
117+
"Skip StringIndexerModel.")
118+
return dataset
119+
}
120+
115121
val indexer = udf { label: String =>
116122
if (labelToIndex.contains(label)) {
117123
labelToIndex(label)
@@ -128,6 +134,11 @@ class StringIndexerModel private[ml] (
128134
}
129135

130136
override def transformSchema(schema: StructType): StructType = {
131-
validateAndTransformSchema(schema)
137+
if (schema.fieldNames.contains($(inputCol))) {
138+
validateAndTransformSchema(schema)
139+
} else {
140+
// If the input column does not exist during transformation, we skip StringIndexerModel.
141+
schema
142+
}
132143
}
133144
}

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,12 @@
1717

1818
package org.apache.spark.ml.feature
1919

20-
import org.apache.spark.SparkFunSuite
20+
import org.scalatest.FunSuite
21+
2122
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
2223
import org.apache.spark.mllib.util.MLlibTestSparkContext
2324

24-
class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
25+
class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
2526

2627
test("StringIndexer") {
2728
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
@@ -60,4 +61,12 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
6061
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
6162
assert(output === expected)
6263
}
64+
65+
test("StringIndexerModel should keep silent if the input column does not exist.") {
66+
val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
67+
.setInputCol("label")
68+
.setOutputCol("labelIndex")
69+
val df = sqlContext.range(0L, 10L)
70+
assert(indexerModel.transform(df).eq(df))
71+
}
6372
}

0 commit comments

Comments
 (0)