Skip to content

Commit

Permalink
[SPARK-19852][PYSPARK][ML] Python StringIndexer supports 'keep' to ha…
Browse files Browse the repository at this point in the history
…ndle invalid data

## What changes were proposed in this pull request?
This PR maintains API parity with the changes made in SPARK-17498 by adding PySpark support
for the new 'keep' option in StringIndexer, which handles unseen labels or NULL values.

Note: This is an updated version of apache#17237; the primary author of this PR is VinceShieh.
## How was this patch tested?
Unit tests.

Author: VinceShieh <vincent.xie@intel.com>
Author: Yanbo Liang <ybliang8@gmail.com>

Closes apache#18453 from yanboliang/spark-19852.
  • Loading branch information
yanboliang committed Jul 2, 2017
1 parent c605fee commit c19680b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2132,6 +2132,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
"frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.",
typeConverter=TypeConverters.toString)

# Param selecting how invalid input (unseen labels or NULL values) is handled.
# Per the help text below, the options are 'skip', 'error' and 'keep';
# the supplied value is converted with TypeConverters.toString.
handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " +
"labels or NULL values). Options are 'skip' (filter out rows with " +
"invalid data), error (throw an error), or 'keep' (put invalid data " +
"in a special additional bucket, at index numLabels).",
typeConverter=TypeConverters.toString)

@keyword_only
def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
stringOrderType="frequencyDesc"):
Expand Down
21 changes: 21 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,27 @@ def test_rformula_string_indexer_order_type(self):
for i in range(0, len(expected)):
self.assertTrue(all(observed[i]["features"].toArray() == expected[i]))

def test_string_indexer_handle_invalid(self):
    """Check StringIndexer output for handleInvalid='keep' and 'skip'.

    The input frame holds two string labels and one None label. With
    'keep' the None row is expected at index 2.0 (after the two labels);
    with 'skip' that row is expected to be absent from the output.
    """
    rows = [(0, "a"), (1, "d"), (2, None)]
    df = self.spark.createDataFrame(rows, ["id", "label"])

    # 'keep': every row survives, the None label lands in the extra bucket.
    keep_indexer = StringIndexer(inputCol="label", outputCol="indexed",
                                 handleInvalid="keep", stringOrderType="alphabetAsc")
    kept_rows = keep_indexer.fit(df).transform(df).select("id", "indexed").collect()
    self.assertEqual(
        kept_rows,
        [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)])

    # 'skip': reuse the same estimator via its setter; the None row is dropped.
    skip_indexer = keep_indexer.setHandleInvalid("skip")
    skipped_rows = skip_indexer.fit(df).transform(df).select("id", "indexed").collect()
    self.assertEqual(
        skipped_rows,
        [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)])


class HasInducedError(Params):

Expand Down

0 comments on commit c19680b

Please sign in to comment.