@@ -327,26 +327,34 @@ class BucketedRandomProjectionLSHModel(LSHModel, JavaMLReadable, JavaMLWritable)
327327
328328
329329@inherit_doc
330- class Bucketizer (JavaTransformer , HasInputCol , HasOutputCol , HasHandleInvalid ,
331- JavaMLReadable , JavaMLWritable ):
330+ class Bucketizer (JavaTransformer , HasInputCol , HasOutputCol , HasInputCols , HasOutputCols ,
331+ HasHandleInvalid , JavaMLReadable , JavaMLWritable ):
332332 """
333- Maps a column of continuous features to a column of feature buckets.
334-
335- >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
336- >>> df = spark.createDataFrame(values, ["values"])
333+ Maps a column of continuous features to a column of feature buckets. Since 3.0.0,
334+ :py:class:`Bucketizer` can map multiple columns at once by setting the :py:attr:`inputCols`
335+ parameter. Note that when both the :py:attr:`inputCol` and :py:attr:`inputCols` parameters
336+ are set, an Exception will be thrown. The :py:attr:`splits` parameter is only used for single
337+ column usage, and :py:attr:`splitsArray` is for multiple columns.
338+
339+ >>> values = [(0.1, 0.0), (0.4, 1.0), (1.2, 1.3), (1.5, float("nan")),
340+ ... (float("nan"), 1.0), (float("nan"), 0.0)]
341+ >>> df = spark.createDataFrame(values, ["values1", "values2"])
337342 >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
338- ... inputCol="values ", outputCol="buckets")
343+ ... inputCol="values1", outputCol="buckets")
339344 >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect()
340- >>> len(bucketed)
341- 6
342- >>> bucketed[0].buckets
343- 0.0
344- >>> bucketed[1].buckets
345- 0.0
346- >>> bucketed[2].buckets
347- 1.0
348- >>> bucketed[3].buckets
349- 2.0
345+ >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df.select("values1"))
346+ >>> bucketed.show(truncate=False)
347+ +-------+-------+
348+ |values1|buckets|
349+ +-------+-------+
350+ |0.1 |0.0 |
351+ |0.4 |0.0 |
352+ |1.2 |1.0 |
353+ |1.5 |2.0 |
354+ |NaN |3.0 |
355+ |NaN |3.0 |
356+ +-------+-------+
357+ ...
350358 >>> bucketizer.setParams(outputCol="b").transform(df).head().b
351359 0.0
352360 >>> bucketizerPath = temp_path + "/bucketizer"
@@ -357,6 +365,22 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
357365 >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect()
358366 >>> len(bucketed)
359367 4
368+ >>> bucketizer2 = Bucketizer(splitsArray=
369+ ... [[-float("inf"), 0.5, 1.4, float("inf")], [-float("inf"), 0.5, float("inf")]],
370+ ... inputCols=["values1", "values2"], outputCols=["buckets1", "buckets2"])
371+ >>> bucketed2 = bucketizer2.setHandleInvalid("keep").transform(df)
372+ >>> bucketed2.show(truncate=False)
373+ +-------+-------+--------+--------+
374+ |values1|values2|buckets1|buckets2|
375+ +-------+-------+--------+--------+
376+ |0.1 |0.0 |0.0 |0.0 |
377+ |0.4 |1.0 |0.0 |1.0 |
378+ |1.2 |1.3 |1.0 |1.0 |
379+ |1.5 |NaN |2.0 |2.0 |
380+ |NaN |1.0 |3.0 |1.0 |
381+ |NaN |0.0 |3.0 |0.0 |
382+ +-------+-------+--------+--------+
383+ ...
360384
361385 .. versionadded:: 1.4.0
362386 """
@@ -374,14 +398,30 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, HasHandleInvalid,
374398 handleInvalid = Param (Params ._dummy (), "handleInvalid" , "how to handle invalid entries "
375399 "containing NaN values. Values outside the splits will always be treated "
376400 "as errors. Options are 'skip' (filter out rows with invalid values), " +
377- "'error' (throw an error), or 'keep' (keep invalid values in a special " +
378- "additional bucket)." ,
401+ "'error' (throw an error), or 'keep' (keep invalid values in a " +
402+ "special additional bucket). Note that in the multiple column " +
403+ "case, the invalid handling is applied to all columns. That said " +
404+ "for 'error' it will throw an error if any invalids are found in " +
405+ "any column, for 'skip' it will skip rows with any invalids in " +
406+ "any columns, etc." ,
379407 typeConverter = TypeConverters .toString )
380408
409+ splitsArray = Param (Params ._dummy (), "splitsArray" , "The array of split points for mapping " +
410+ "continuous features into buckets for multiple columns. For each input " +
411+ "column, with n+1 splits, there are n buckets. A bucket defined by " +
412+ "splits x,y holds values in the range [x,y) except the last bucket, " +
413+ "which also includes y. The splits should be of length >= 3 and " +
414+ "strictly increasing. Values at -inf, inf must be explicitly provided " +
415+ "to cover all Double values; otherwise, values outside the splits " +
416+ "specified will be treated as errors." ,
417+ typeConverter = TypeConverters .toListListFloat )
418+
381419 @keyword_only
382- def __init__ (self , splits = None , inputCol = None , outputCol = None , handleInvalid = "error" ):
420+ def __init__ (self , splits = None , inputCol = None , outputCol = None , handleInvalid = "error" ,
421+ splitsArray = None , inputCols = None , outputCols = None ):
383422 """
384- __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
423+ __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
424+ splitsArray=None, inputCols=None, outputCols=None)
385425 """
386426 super (Bucketizer , self ).__init__ ()
387427 self ._java_obj = self ._new_java_obj ("org.apache.spark.ml.feature.Bucketizer" , self .uid )
@@ -391,9 +431,11 @@ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="er
391431
392432 @keyword_only
393433 @since ("1.4.0" )
394- def setParams (self , splits = None , inputCol = None , outputCol = None , handleInvalid = "error" ):
434+ def setParams (self , splits = None , inputCol = None , outputCol = None , handleInvalid = "error" ,
435+ splitsArray = None , inputCols = None , outputCols = None ):
395436 """
396- setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
437+ setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error", \
438+ splitsArray=None, inputCols=None, outputCols=None)
397439 Sets params for this Bucketizer.
398440 """
399441 kwargs = self ._input_kwargs
@@ -413,6 +455,20 @@ def getSplits(self):
413455 """
414456 return self .getOrDefault (self .splits )
415457
458+ @since ("3.0.0" )
459+ def setSplitsArray (self , value ):
460+ """
461+ Sets the value of :py:attr:`splitsArray`.
462+ """
463+ return self ._set (splitsArray = value )
464+
465+ @since ("3.0.0" )
466+ def getSplitsArray (self ):
467+ """
468+ Gets the value of :py:attr:`splitsArray` or its default value.
469+ """
470+ return self .getOrDefault (self .splitsArray )
471+
416472
417473class _CountVectorizerParams (JavaParams , HasInputCol , HasOutputCol ):
418474 """
0 commit comments