Skip to content

Commit 6b9737a

Browse files
brkyvz authored and mengxr committed
[SPARK-7388] [SPARK-7383] wrapper for VectorAssembler in Python
The wrapper required the implementation of the `ArrayParam`, because `Array[T]` is hard to obtain from Python. `ArrayParam` has an extra function called `wCast` which is an internal function to obtain `Array[T]` from `Seq[T]` Author: Burak Yavuz <brkyvz@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes apache#5930 from brkyvz/ml-feat and squashes the following commits: 73e745f [Burak Yavuz] Merge pull request #3 from mengxr/SPARK-7388 c221db9 [Xiangrui Meng] overload StringArrayParam.w c81072d [Burak Yavuz] addressed comments 99c2ebf [Burak Yavuz] add to python_shared_params 39ecb07 [Burak Yavuz] fix scalastyle 7f7ea2a [Burak Yavuz] [SPARK-7388][SPARK-7383] wrapper for VectorAssembler in Python (cherry picked from commit 9e2ffb1) Signed-off-by: Xiangrui Meng <meng@databricks.com>
1 parent 84ee348 commit 6b9737a

File tree

8 files changed

+105
-13
lines changed

8 files changed

+105
-13
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
3030

3131
/**
3232
* :: AlphaComponent ::
33-
* A feature transformer than merge multiple columns into a vector column.
33+
* A feature transformer that merges multiple columns into a vector column.
3434
*/
3535
@AlphaComponent
3636
class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.util.NoSuchElementException
2222

2323
import scala.annotation.varargs
2424
import scala.collection.mutable
25+
import scala.collection.JavaConverters._
2526

2627
import org.apache.spark.annotation.AlphaComponent
2728
import org.apache.spark.ml.util.Identifiable
@@ -218,6 +219,19 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV
218219
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
219220
}
220221

222+
/** Specialized version of [[Param[Array[T]]]] for Java. */
223+
class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
224+
extends Param[Array[String]](parent, name, doc, isValid) {
225+
226+
def this(parent: Params, name: String, doc: String) =
227+
this(parent, name, doc, ParamValidators.alwaysTrue)
228+
229+
override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)
230+
231+
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
232+
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
233+
}
234+
221235
/**
222236
* A param amd its value.
223237
*/
@@ -310,9 +324,7 @@ trait Params extends Identifiable with Serializable {
310324
* Sets a parameter in the embedded param map.
311325
*/
312326
protected final def set[T](param: Param[T], value: T): this.type = {
313-
shouldOwn(param)
314-
paramMap.put(param.asInstanceOf[Param[Any]], value)
315-
this
327+
set(param -> value)
316328
}
317329

318330
/**
@@ -322,6 +334,15 @@ trait Params extends Identifiable with Serializable {
322334
set(getParam(param), value)
323335
}
324336

337+
/**
338+
* Sets a parameter in the embedded param map.
339+
*/
340+
protected final def set(paramPair: ParamPair[_]): this.type = {
341+
shouldOwn(paramPair.param)
342+
paramMap.put(paramPair)
343+
this
344+
}
345+
325346
/**
326347
* Optionally returns the user-supplied value of a param.
327348
*/

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ private[shared] object SharedParamsCodeGen {
8585
case _ if c == classOf[Float] => "FloatParam"
8686
case _ if c == classOf[Double] => "DoubleParam"
8787
case _ if c == classOf[Boolean] => "BooleanParam"
88+
case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
8889
case _ => s"Param[${getTypeString(c)}]"
8990
}
9091
}

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params {
178178
* Param for input column names.
179179
* @group param
180180
*/
181-
final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names")
181+
final val inputCols: StringArrayParam = new StringArrayParam(this, "inputCols", "input column names")
182182

183183
/** @group getParam */
184184
final def getInputCols: Array[String] = $(inputCols)

python/pyspark/ml/feature.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
#
1717

1818
from pyspark.rdd import ignore_unicode_prefix
19-
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
19+
from pyspark.ml.param.shared import HasInputCol, HasInputCols, HasOutputCol, HasNumFeatures
2020
from pyspark.ml.util import keyword_only
2121
from pyspark.ml.wrapper import JavaTransformer
2222
from pyspark.mllib.common import inherit_doc
2323

24-
__all__ = ['Tokenizer', 'HashingTF']
24+
__all__ = ['Tokenizer', 'HashingTF', 'VectorAssembler']
2525

2626

2727
@inherit_doc
@@ -112,6 +112,45 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
112112
return self._set(**kwargs)
113113

114114

115+
@inherit_doc
116+
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
117+
"""
118+
A feature transformer that merges multiple columns into a vector column.
119+
120+
>>> from pyspark.sql import Row
121+
>>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF()
122+
>>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
123+
>>> vecAssembler.transform(df).head().features
124+
SparseVector(3, {0: 1.0, 2: 3.0})
125+
>>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
126+
SparseVector(3, {0: 1.0, 2: 3.0})
127+
>>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
128+
>>> vecAssembler.transform(df, params).head().vector
129+
SparseVector(2, {1: 1.0})
130+
"""
131+
132+
_java_class = "org.apache.spark.ml.feature.VectorAssembler"
133+
134+
@keyword_only
135+
def __init__(self, inputCols=None, outputCol=None):
136+
"""
137+
__init__(self, inputCols=None, outputCol=None)
138+
"""
139+
super(VectorAssembler, self).__init__()
140+
self._setDefault()
141+
kwargs = self.__init__._input_kwargs
142+
self.setParams(**kwargs)
143+
144+
@keyword_only
145+
def setParams(self, inputCols=None, outputCol=None):
146+
"""
147+
setParams(self, inputCols=None, outputCol=None)
148+
Sets params for this VectorAssembler.
149+
"""
150+
kwargs = self.setParams._input_kwargs
151+
return self._set(**kwargs)
152+
153+
115154
if __name__ == "__main__":
116155
import doctest
117156
from pyspark.context import SparkContext

python/pyspark/ml/param/_shared_params_code_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def get$Name(self):
9595
("predictionCol", "prediction column name", "'prediction'"),
9696
("rawPredictionCol", "raw prediction column name", "'rawPrediction'"),
9797
("inputCol", "input column name", None),
98+
("inputCols", "input column names", None),
9899
("outputCol", "output column name", None),
99100
("numFeatures", "number of features", None)]
100101
code = []

python/pyspark/ml/param/shared.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,35 @@ def getInputCol(self):
223223
return self.getOrDefault(self.inputCol)
224224

225225

226+
class HasInputCols(Params):
227+
"""
228+
Mixin for param inputCols: input column names.
229+
"""
230+
231+
# a placeholder to make it appear in the generated doc
232+
inputCols = Param(Params._dummy(), "inputCols", "input column names")
233+
234+
def __init__(self):
235+
super(HasInputCols, self).__init__()
236+
#: param for input column names
237+
self.inputCols = Param(self, "inputCols", "input column names")
238+
if None is not None:
239+
self._setDefault(inputCols=None)
240+
241+
def setInputCols(self, value):
242+
"""
243+
Sets the value of :py:attr:`inputCols`.
244+
"""
245+
self.paramMap[self.inputCols] = value
246+
return self
247+
248+
def getInputCols(self):
249+
"""
250+
Gets the value of inputCols or its default value.
251+
"""
252+
return self.getOrDefault(self.inputCols)
253+
254+
226255
class HasOutputCol(Params):
227256
"""
228257
Mixin for param outputCol: output column name.

python/pyspark/ml/wrapper.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ def _transfer_params_to_java(self, params, java_obj):
6767
paramMap = self.extractParamMap(params)
6868
for param in self.params:
6969
if param in paramMap:
70-
java_obj.set(param.name, paramMap[param])
70+
value = paramMap[param]
71+
java_param = java_obj.getParam(param.name)
72+
java_obj.set(java_param.w(value))
7173

7274
def _empty_java_param_map(self):
7375
"""
@@ -79,7 +81,8 @@ def _create_java_param_map(self, params, java_obj):
7981
paramMap = self._empty_java_param_map()
8082
for param, value in params.items():
8183
if param.parent is self:
82-
paramMap.put(java_obj.getParam(param.name), value)
84+
java_param = java_obj.getParam(param.name)
85+
paramMap.put(java_param.w(value))
8386
return paramMap
8487

8588

@@ -126,10 +129,8 @@ class JavaTransformer(Transformer, JavaWrapper):
126129

127130
def transform(self, dataset, params={}):
128131
java_obj = self._java_obj()
129-
self._transfer_params_to_java({}, java_obj)
130-
java_param_map = self._create_java_param_map(params, java_obj)
131-
return DataFrame(java_obj.transform(dataset._jdf, java_param_map),
132-
dataset.sql_ctx)
132+
self._transfer_params_to_java(params, java_obj)
133+
return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx)
133134

134135

135136
@inherit_doc

0 commit comments

Comments
 (0)