Skip to content

[SPARK-23871][ML][PYTHON]add python api for VectorAssembler handleInvalid #21003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
*/
@Since("2.4.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"""Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with
|invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the
|output). Column lengths are taken from the size of ML Attribute Group, which can be set using
|`VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred
|from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'.
|""".stripMargin.replaceAll("\n", " "),
"""Param for how to handle invalid data (NULL and NaN values). Options are 'skip' (filter out
|rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN
|in the output). Column lengths are taken from the size of ML Attribute Group, which can be
|set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also
|be inferred from first rows of the data since it is safe to do so but only in case of 'error'
|or 'skip'.""".stripMargin.replaceAll("\n", " "),
ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))

setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
Expand Down
42 changes: 37 additions & 5 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2701,7 +2701,8 @@ def setParams(self, inputCol=None, outputCol=None):


@inherit_doc
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable):
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable,
JavaMLWritable):
"""
A feature transformer that merges multiple columns into a vector column.

Expand All @@ -2719,25 +2720,56 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl
>>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath)
>>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs
True
>>> dfWithNullsAndNaNs = spark.createDataFrame(
... [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"])
>>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features",
... handleInvalid="keep")
>>> vecAssembler2.transform(dfWithNullsAndNaNs).show()
+---+---+----+-------------+
| a| b| c| features|
+---+---+----+-------------+
|1.0|2.0|null|[1.0,2.0,NaN]|
|3.0|NaN| 4.0|[3.0,NaN,4.0]|
|5.0|6.0| 7.0|[5.0,6.0,7.0]|
+---+---+----+-------------+
...
>>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show()
+---+---+---+-------------+
| a| b| c| features|
+---+---+---+-------------+
|5.0|6.0|7.0|[5.0,6.0,7.0]|
+---+---+---+-------------+
...

.. versionadded:: 1.4.0
"""

handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data (NULL " +
"and NaN values). Options are 'skip' (filter out rows with invalid " +
"data), 'error' (throw an error), or 'keep' (return relevant number " +
"of NaN in the output). Column lengths are taken from the size of ML " +
"Attribute Group, which can be set using `VectorSizeHint` in a " +
"pipeline before `VectorAssembler`. Column lengths can also be " +
"inferred from first rows of the data since it is safe to do so but " +
"only in case of 'error' or 'skip').",
typeConverter=TypeConverters.toString)

@keyword_only
def __init__(self, inputCols=None, outputCol=None):
def __init__(self, inputCols=None, outputCol=None, handleInvalid="error"):
"""
__init__(self, inputCols=None, outputCol=None)
__init__(self, inputCols=None, outputCol=None, handleInvalid="error")
"""
super(VectorAssembler, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid)
self._setDefault(handleInvalid="error")
kwargs = self._input_kwargs
self.setParams(**kwargs)

@keyword_only
@since("1.4.0")
def setParams(self, inputCols=None, outputCol=None):
def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"):
"""
setParams(self, inputCols=None, outputCol=None)
setParams(self, inputCols=None, outputCol=None, handleInvalid="error")
Sets params for this VectorAssembler.
"""
kwargs = self._input_kwargs
Expand Down