[SPARK-20307][ML][SPARKR][FOLLOW-UP] RFormula should handle invalid values for both features and label columns

## What changes were proposed in this pull request?
```RFormula``` should handle invalid values for both the features and label columns.
apache#18496 only handled invalid values in the features column. This PR adds handling of invalid values in the label column as well, along with test cases.
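
A minimal sketch of the new behavior (assuming a running `SparkSession` in scope as `spark`; the data and column names are invented, mirroring the new test cases):

```scala
import org.apache.spark.ml.feature.RFormula
import spark.implicits._

// Label column "b" is a string; the value "zy" never appears in training.
val train = Seq((1, "foo", "zq"), (2, "bar", "zz")).toDF("id", "a", "b")
val test = Seq((3, "bar", "zy")).toDF("id", "a", "b")

val formula = new RFormula()
  .setFormula("b ~ a + id")
  .setHandleInvalid("keep") // now forwarded to the label indexer as well

// Before this change the label StringIndexer always used the default
// ("error"), so this transform would throw on the unseen label "zy";
// with "keep" the unseen label goes to the extra bucket at index numLabels.
formula.fit(train).transform(test).show()
```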

## How was this patch tested?
Add test cases.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes apache#18613 from yanboliang/spark-20307.
yanboliang committed Jul 15, 2017
1 parent 74ac1fb commit 69e5282
Showing 4 changed files with 57 additions and 8 deletions.
2 changes: 1 addition & 1 deletion R/pkg/tests/fulltests/test_mllib_tree.R
@@ -225,7 +225,7 @@ test_that("spark.randomForest", {
   expect_error(collect(predictions))
   model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
                               maxDepth = 10, maxBins = 10, numTrees = 10,
-                              handleInvalid = "skip")
+                              handleInvalid = "keep")
   predictions <- predict(model, testdf)
   expect_equal(class(collect(predictions)$clicked[1]), "character")

9 changes: 5 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala

@@ -134,16 +134,16 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   def getFormula: String = $(formula)
 
   /**
-   * Param for how to handle invalid data (unseen labels or NULL values).
-   * Options are 'skip' (filter out rows with invalid data),
+   * Param for how to handle invalid data (unseen or NULL values) in features and label column
+   * of string type. Options are 'skip' (filter out rows with invalid data),
    * 'error' (throw an error), or 'keep' (put invalid data in a special additional
    * bucket, at index numLabels).
    * Default: "error"
    * @group param
    */
   @Since("2.3.0")
-  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
-    "How to handle invalid data (unseen labels or NULL values). " +
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to " +
+    "handle invalid data (unseen or NULL values) in features and label column of string type. " +
     "Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
     "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
     ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
@@ -265,6 +265,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
       encoderStages += new StringIndexer()
         .setInputCol(resolvedFormula.label)
         .setOutputCol($(labelCol))
+        .setHandleInvalid($(handleInvalid))
     }
 
     val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset)
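
To make the param doc above concrete, here is a hedged sketch of how the three `handleInvalid` options now behave for a string label (not part of the diff; assumes a `SparkSession` named `spark` and invented toy data):

```scala
import org.apache.spark.SparkException
import org.apache.spark.ml.feature.RFormula
import spark.implicits._

val train = Seq(("foo", "zq"), ("bar", "zz")).toDF("a", "b")
val test = Seq(("bar", "zy")).toDF("a", "b") // label "zy" is unseen at fit time

val formula = new RFormula().setFormula("b ~ a")

// "error" (the default): evaluating the transform fails on the unseen label.
try formula.fit(train).transform(test).collect()
catch { case e: SparkException => println("error mode: " + e.getMessage) }

// "skip": the offending row is filtered out, so the result is empty here.
formula.setHandleInvalid("skip").fit(train).transform(test).show()

// "keep": the unseen label lands in the extra bucket, at index numLabels (2.0).
formula.setHandleInvalid("keep").fit(train).transform(test).show()
```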
49 changes: 48 additions & 1 deletion mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala

@@ -17,7 +17,7 @@

 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamsSuite
@@ -501,4 +501,51 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept)
     }
   }
+
+  test("handle unseen features or labels") {
+    val df1 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")).toDF("id", "a", "b")
+    val df2 = Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zy")).toDF("id", "a", "b")
+
+    // Handle unseen features.
+    val formula1 = new RFormula().setFormula("id ~ a + b")
+    intercept[SparkException] {
+      formula1.fit(df1).transform(df2).collect()
+    }
+    val result1 = formula1.setHandleInvalid("skip").fit(df1).transform(df2)
+    val result2 = formula1.setHandleInvalid("keep").fit(df1).transform(df2)
+
+    val expected1 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0), 1.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 1.0), 2.0)
+    ).toDF("id", "a", "b", "features", "label")
+    val expected2 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0, 0.0), 1.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 0.0, 1.0, 0.0), 2.0),
+      (3, "bar", "zy", Vectors.dense(1.0, 0.0, 0.0, 0.0), 3.0)
+    ).toDF("id", "a", "b", "features", "label")
+
+    assert(result1.collect() === expected1.collect())
+    assert(result2.collect() === expected2.collect())
+
+    // Handle unseen labels.
+    val formula2 = new RFormula().setFormula("b ~ a + id")
+    intercept[SparkException] {
+      formula2.fit(df1).transform(df2).collect()
+    }
+    val result3 = formula2.setHandleInvalid("skip").fit(df1).transform(df2)
+    val result4 = formula2.setHandleInvalid("keep").fit(df1).transform(df2)
+
+    val expected3 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0), 0.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 2.0), 0.0)
+    ).toDF("id", "a", "b", "features", "label")
+    val expected4 = Seq(
+      (1, "foo", "zq", Vectors.dense(0.0, 1.0, 1.0), 0.0),
+      (2, "bar", "zq", Vectors.dense(1.0, 0.0, 2.0), 0.0),
+      (3, "bar", "zy", Vectors.dense(1.0, 0.0, 3.0), 2.0)
+    ).toDF("id", "a", "b", "features", "label")
+
+    assert(result3.collect() === expected3.collect())
+    assert(result4.collect() === expected4.collect())
+  }
 }
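
The label expectations in the test above follow directly from `StringIndexer` semantics: string labels are indexed by descending frequency, and with `"keep"` an unseen value is assigned index `numLabels`. A standalone sketch of just that piece (assumes `spark`; the data mirrors column `b` of `df1`/`df2` above):

```scala
import org.apache.spark.ml.feature.StringIndexer
import spark.implicits._

// In the training data "zq" occurs twice and "zz" once,
// so "zq" -> 0.0 and "zz" -> 1.0.
val train = Seq("zq", "zq", "zz").toDF("b")
val test = Seq("zq", "zz", "zy").toDF("b")

val indexer = new StringIndexer()
  .setInputCol("b")
  .setOutputCol("label")
  .setHandleInvalid("keep")

// The unseen "zy" maps to 2.0 (numLabels), matching expected4 above.
indexer.fit(train).transform(test).show()
```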
5 changes: 3 additions & 2 deletions python/pyspark/ml/feature.py
@@ -2107,8 +2107,9 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
                        typeConverter=TypeConverters.toString)
 
     handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
-                          "labels or NULL values). Options are 'skip' (filter out rows with " +
-                          "invalid data), error (throw an error), or 'keep' (put invalid data " +
+                          "or NULL values) in features and label column of string type. " +
+                          "Options are 'skip' (filter out rows with invalid data), " +
+                          "error (throw an error), or 'keep' (put invalid data " +
                           "in a special additional bucket, at index numLabels).",
                           typeConverter=TypeConverters.toString)

