Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,8 +1917,7 @@ def mean(self):


@inherit_doc
class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
JavaMLWritable):
class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
A label indexer that maps a string column of labels to an ML column of label indices.
If the input column is numeric, we cast it to string and index the string values.
Expand All @@ -1936,6 +1935,14 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
>>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
... key=lambda x: x[0])
[(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
>>> testData2 = sc.parallelize([Row(id=0, label="a"), Row(id=1, label="d"),
... Row(id=2, label=None)], 2)
>>> dfKeep= spark.createDataFrame(testData2)
>>> modelKeep = stringIndexer.setHandleInvalid("keep").fit(stringIndDf)
>>> tdK = modelKeep.transform(dfKeep)
>>> sorted(set([(i[0], i[1]) for i in tdK.select(tdK.id, tdK.indexed).collect()]),
... key=lambda x: x[0])
[(0, 0.0), (1, 3.0), (2, 3.0)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you move the newly added test to tests.py? We keep the basic doctests here because they serve as both tests and examples; other tests should be placed in tests.py. Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay. Sure @yanboliang @holdenk

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VinceShieh Do you have time to update this PR? We would like to get this in 2.2. Thanks.

>>> stringIndexerPath = temp_path + "/string-indexer"
>>> stringIndexer.save(stringIndexerPath)
>>> loadedIndexer = StringIndexer.load(stringIndexerPath)
Expand All @@ -1955,6 +1962,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
.. versionadded:: 1.4.0
"""

# Param controlling how invalid data is handled; its description string
# lists the documented options, and values are converted with
# TypeConverters.toString.
handleInvalid = Param(
    Params._dummy(),
    "handleInvalid",
    "how to handle invalid data (unseen labels or NULL values). "
    "Options are 'skip' (filter out rows with invalid data), "
    "error (throw an error), or 'keep' (put invalid data in a "
    "special additional bucket, at index numLabels).",
    typeConverter=TypeConverters.toString)

@keyword_only
def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"):
"""
Expand All @@ -1979,6 +1992,20 @@ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"):
def _create_model(self, java_model):
    """
    Wraps ``java_model`` in a :py:class:`StringIndexerModel` and returns it.
    """
    model = StringIndexerModel(java_model)
    return model

@since("2.2.0")
def setHandleInvalid(self, value):
    """
    Sets the value of :py:attr:`handleInvalid`.

    Returns whatever the underlying ``_set`` call returns.
    """
    # Build the keyword mapping explicitly, then forward it to ``_set``.
    updates = {"handleInvalid": value}
    return self._set(**updates)

@since("2.2.0")
def getHandleInvalid(self):
    """
    Gets the value of :py:attr:`handleInvalid` or its default value.
    """
    # Look the param up first, then resolve it through ``getOrDefault``.
    param = self.handleInvalid
    return self.getOrDefault(param)


class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
Expand Down