
Commit 90f79ee

recreate pr
1 parent 1b99d0c commit 90f79ee

File tree

3 files changed: +98 -23 lines changed


python/pyspark/ml/feature.py

Lines changed: 79 additions & 23 deletions
@@ -327,26 +327,34 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable)
 
 
 @inherit_doc
-class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
-                 JavaMLReadable, JavaMLWritable):
+class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasInputCols, HasOutputCols,
+                 HasHandleInvalid, JavaMLReadable, JavaMLWritable):
     """
-    Maps a column of continuous features to a column of feature buckets.
-
-    >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
-    >>> df = spark.createDataFrame(values, ["values"])
+    Maps a column of continuous features to a column of feature buckets. Since 2.3.0,
+    :py:class:`Bucketizer` can map multiple columns at once by setting the :py:attr:`inputCols`
+    parameter. Note that when both the :py:attr:`inputCol` and :py:attr:`inputCols` parameters
+    are set, an Exception will be thrown. The :py:attr:`splits` parameter is only used for single
+    column usage, and :py:attr:`splitsArray` is for multiple columns.
+
+    >>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")),
+    ...     (float("nan"), 1.0), (float("nan"), 0.0)]
+    >>> df = spark.createDataFrame(values, ["values1", "values2"])
     >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
-    ...     inputCol="values", outputCol="buckets")
+    ...     inputCol="values1", outputCol="buckets")
     >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect()
-    >>> len(bucketed)
-    6
-    >>> bucketed[0].buckets
-    0.0
-    >>> bucketed[1].buckets
-    0.0
-    >>> bucketed[2].buckets
-    1.0
-    >>> bucketed[3].buckets
-    2.0
+    >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1"))
+    >>> bucketed.show(truncate=False)
+    +-------+-------+
+    |values1|buckets|
+    +-------+-------+
+    |0.1    |0.0    |
+    |0.4    |0.0    |
+    |1.2    |1.0    |
+    |1.5    |2.0    |
+    |NaN    |3.0    |
+    |NaN    |3.0    |
+    +-------+-------+
+    ...
     >>> bucketizer.setParams(outputCol="b").transform(df).head().b
     0.0
     >>> bucketizerPath = temp_path + "/bucketizer"
@@ -357,6 +365,22 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
     >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect()
     >>> len(bucketed)
     4
+    >>> bucketizer2 = Bucketizer(splitsArray=
+    ...     [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]],
+    ...     inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"])
+    >>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df)
+    >>> bucketed2.show(truncate=False)
+    +-------+-------+--------+--------+
+    |values1|values2|buckets1|buckets2|
+    +-------+-------+--------+--------+
+    |0.1    |0.0    |0.0     |0.0     |
+    |0.4    |1.0    |0.0     |1.0     |
+    |1.2    |1.3    |1.0     |1.0     |
+    |1.5    |NaN    |2.0     |2.0     |
+    |NaN    |1.0    |3.0     |1.0     |
+    |NaN    |0.0    |3.0     |0.0     |
+    +-------+-------+--------+--------+
+    ...
 
     .. versionadded:: 1.4.0
     """
@@ -374,14 +398,30 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
     handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries "
                           "containing NaN values. Values outside the splits will always be treated "
                           "as errors. Options are 'skip' (filter out rows with invalid values), " +
-                          "'error' (throw an error), or 'keep' (keep invalid values in a special " +
-                          "additional bucket).",
+                          "'error' (throw an error), or 'keep' (keep invalid values in a " +
+                          "special additional bucket). Note that in the multiple column " +
+                          "case, the invalid handling is applied to all columns. That said " +
+                          "for 'error' it will throw an error if any invalids are found in " +
+                          "any column, for 'skip' it will skip rows with any invalids in " +
+                          "any columns, etc.",
                           typeConverter=TypeConverters.toString)
 
+    splitsArray = Param(Params._dummy(), "splitsArray", "The array of split points for mapping " +
+                        "continuous features into buckets for multiple columns. For each input " +
+                        "column, with n+1 splits, there are n buckets. A bucket defined by " +
+                        "splits x,y holds values in the range [x,y) except the last bucket, " +
+                        "which also includes y. The splits should be of length >= 3 and " +
+                        "strictly increasing. Values at -inf, inf must be explicitly provided " +
+                        "to cover all Double values; otherwise, values outside the splits " +
+                        "specified will be treated as errors.",
+                        typeConverter=TypeConverters.toListListFloat)
+
     @keyword_only
-    def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
+    def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error",
+                 splitsArray=None, inputCols=None, outputCols=None):
         """
-        __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
+        __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
+                 splitsArray=None, inputCols=None, outputCols=None)
         """
         super(Bucketizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid)
@@ -391,9 +431,11 @@ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="er
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
+    def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error",
+                  splitsArray=None, inputCols=None, outputCols=None):
         """
-        setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
+        setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
+                  splitsArray=None, inputCols=None, outputCols=None)
         Sets params for this Bucketizer.
         """
         kwargs = self._input_kwargs
@@ -413,6 +455,20 @@ def getSplits(self):
         """
         return self.getOrDefault(self.splits)
 
+    @since("3.0.0")
+    def setSplitsArray(self, value):
+        """
+        Sets the value of :py:attr:`splitsArray`.
+        """
+        return self._set(splitsArray=value)
+
+    @since("3.0.0")
+    def getSplitsArray(self):
+        """
+        Gets the array of split points or its default value.
+        """
+        return self.getOrDefault(self.splitsArray)
+
 
 class _CountVectorizerParams(JavaParams, HasInputCol, HasOutputCol):
     """

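For context on how the new multi-column path fits together end to end, here is a minimal usage sketch. It is not part of the patch: the local SparkSession, the sample rows, and the column names are illustrative assumptions; only Bucketizer, splitsArray, inputCols/outputCols, and handleInvalid come from the diff above.

# Hypothetical end-to-end check of the multi-column Bucketizer added in this diff.
# The SparkSession setup and sample data are assumptions for illustration only.
from pyspark.sql import SparkSession
from pyspark.ml.feature import Bucketizer

spark = SparkSession.builder.master("local[1]").appName("bucketizer-sketch").getOrCreate()

df = spark.createDataFrame(
    [(0.1, 0.0), (1.5, float("nan")), (float("nan"), 1.0)],
    ["values1", "values2"])

bucketizer = Bucketizer(
    splitsArray=[[-float("inf"), 0.5, 1.4, float("inf")],
                 [-float("inf"), 0.5, float("inf")]],
    inputCols=["values1", "values2"],
    outputCols=["buckets1", "buckets2"])

# 'keep' routes NaN values in every column into an extra bucket.
bucketizer.setHandleInvalid("keep").transform(df).show()

# Per the updated handleInvalid doc, 'skip' drops a row if any input
# column holds an invalid value, so only the first row remains here.
bucketizer.setHandleInvalid("skip").transform(df).show()

spark.stop()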
python/pyspark/ml/param/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -134,6 +134,16 @@ def toListFloat(value):
                 return [float(v) for v in value]
         raise TypeError("Could not convert %s to list of floats" % value)
 
+    @staticmethod
+    def toListListFloat(value):
+        """
+        Convert a value to list of list of floats, if possible.
+        """
+        if TypeConverters._can_convert_to_list(value):
+            value = TypeConverters.toList(value)
+            return [TypeConverters.toListFloat(v) for v in value]
+        raise TypeError("Could not convert %s to list of list of floats" % value)
+
     @staticmethod
     def toListInt(value):
         """

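As a quick illustration of the converter added above: it normalises any list-like of list-likes of numbers into a plain list of lists of Python floats, and otherwise raises TypeError. Only toListListFloat itself comes from this diff; the numpy and tuple inputs below are assumed examples, based on _can_convert_to_list accepting ndarrays and tuples. No SparkContext is needed for this check.

# Sketch of TypeConverters.toListListFloat behaviour (pure Python, no JVM).
import numpy as np
from pyspark.ml.param import TypeConverters

print(TypeConverters.toListListFloat([[-0.1, 0.5, 3], (1, 2)]))
# expected: [[-0.1, 0.5, 3.0], [1.0, 2.0]]

print(TypeConverters.toListListFloat(np.array([[0.0, 1.0], [2.0, 3.0]])))
# expected: [[0.0, 1.0], [2.0, 3.0]]

try:
    TypeConverters.toListListFloat([["a", 1.0]])  # non-numeric leaf value
except TypeError as e:
    print(e)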
python/pyspark/ml/tests/test_param.py

Lines changed: 9 additions & 0 deletions
@@ -87,6 +87,15 @@ def test_list_float(self):
         self.assertTrue(all([type(v) == float for v in b.getSplits()]))
         self.assertRaises(TypeError, lambda: Bucketizer(splits=["a", 1.0]))
 
+    def test_list_list_float(self):
+        b = Bucketizer(splitsArray=[[-0.1, 0.5, 3], [-5, 1.5]])
+        self.assertEqual(b.getSplitsArray(), [[-0.1, 0.5, 3.0], [-5.0, 1.5]])
+        self.assertTrue(all([type(v) == list for v in b.getSplitsArray()]))
+        self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[0]]))
+        self.assertTrue(all([type(v) == float for v in b.getSplitsArray()[1]]))
+        self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=["a", 1.0]))
+        self.assertRaises(TypeError, lambda: Bucketizer(splitsArray=[[-5, 1.5], ["a", 1.0]]))
+
     def test_list_string(self):
         for labels in [np.array(['a', u'b']), ['a', u'b'], np.array(['a', 'b'])]:
             idx_to_string = IndexToString(labels=labels)

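A rough interactive equivalent of the new test, for anyone who wants to poke at it in a shell. Note that Bucketizer construction goes through the JVM, so an active SparkSession is assumed here; everything else mirrors test_list_list_float above.

# Assumes a running local SparkSession (e.g. started as below or a pyspark shell).
from pyspark.sql import SparkSession
from pyspark.ml.feature import Bucketizer

spark = SparkSession.builder.master("local[1]").getOrCreate()

b = Bucketizer(splitsArray=[[-0.1, 0.5, 3], [-5, 1.5]])
print(b.getSplitsArray())  # ints coerced: [[-0.1, 0.5, 3.0], [-5.0, 1.5]]

try:
    Bucketizer(splitsArray=[[-5, 1.5], ["a", 1.0]])  # non-numeric entry
except TypeError as e:
    print("rejected:", e)

spark.stop()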
0 commit comments
