Commit 7f7ea2a

[SPARK-7388][SPARK-7383] wrapper for VectorAssembler in Python
1 parent: fec7b29

7 files changed (+96, -9)

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala (+1, -1)

@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
 
 /**
  * :: AlphaComponent ::
- * A feature transformer than merge multiple columns into a vector column.
+ * A feature transformer that merges multiple columns into a vector column.
  */
 @AlphaComponent
 class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {

mllib/src/main/scala/org/apache/spark/ml/param/params.scala (+18, -1)

@@ -22,6 +22,7 @@ import java.util.NoSuchElementException
 
 import scala.annotation.varargs
 import scala.collection.mutable
+import scala.reflect.ClassTag
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.util.Identifiable

@@ -218,6 +219,18 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV
   override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
 }
 
+/** Specialized version of [[Param[Array[T]]]] for Java. */
+class ArrayParam[T : ClassTag](parent: Params, name: String, doc: String, isValid: Array[T] => Boolean)
+  extends Param[Array[T]](parent, name, doc, isValid) {
+
+  def this(parent: Params, name: String, doc: String) =
+    this(parent, name, doc, ParamValidators.alwaysTrue)
+
+  override def w(value: Array[T]): ParamPair[Array[T]] = super.w(value)
+
+  private[param] def wCast(value: Seq[T]): ParamPair[Array[T]] = w(value.toArray)
+}
+
 /**
  * A param and its value.
  */

@@ -311,7 +324,11 @@ trait Params extends Identifiable with Serializable {
    */
   protected final def set[T](param: Param[T], value: T): this.type = {
     shouldOwn(param)
-    paramMap.put(param.asInstanceOf[Param[Any]], value)
+    if (param.isInstanceOf[ArrayParam[_]] && value.isInstanceOf[Seq[_]]) {
+      paramMap.put(param.asInstanceOf[ArrayParam[Any]].wCast(value.asInstanceOf[Seq[Any]]))
+    } else {
+      paramMap.put(param.w(value))
+    }
     this
   }
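The coercion in set() exists because values set from Python arrive on the JVM as Scala Seqs (the wrapper converts lists with PythonUtils.toSeq; see python/pyspark/ml/wrapper.py below), while params like inputCols are declared as Array[String]; wCast bridges the two. A toy Python mirror of the dispatch, purely illustrative (names here are made up):

    # Toy mirror of the Scala set() dispatch above -- illustrative only.
    # Array-typed params coerce sequence values into their array form
    # (a tuple stands in for Seq.toArray); other params store values as-is.
    def set_param(param_map, param, value, is_array_param):
        if is_array_param and isinstance(value, (list, tuple)):
            param_map[param] = tuple(value)  # plays the role of wCast
        else:
            param_map[param] = value
        return param_map

    print(set_param({}, "inputCols", ["a", "b"], True))  # {'inputCols': ('a', 'b')}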

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala (+1, -0)

@@ -83,6 +83,7 @@ private[shared] object SharedParamsCodeGen {
       case _ if c == classOf[Float] => "FloatParam"
       case _ if c == classOf[Double] => "DoubleParam"
       case _ if c == classOf[Boolean] => "BooleanParam"
+      case _ if c.isArray => s"ArrayParam[${getTypeString(c.getComponentType)}]"
       case _ => s"Param[${getTypeString(c)}]"
     }
   }
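With the new c.isArray case, shared params whose type is an array (such as inputCols: Array[String]) are generated with the specialized ArrayParam instead of a generic Param[Array[...]], as the sharedParams.scala diff below shows. A toy Python sketch of the same dispatch (illustrative, not Spark code):

    # Map a Scala type name to the param class the codegen would emit.
    def param_class(type_name):
        scalar = {"Int": "IntParam", "Float": "FloatParam",
                  "Double": "DoubleParam", "Boolean": "BooleanParam"}
        if type_name.startswith("Array[") and type_name.endswith("]"):
            return "ArrayParam[%s]" % type_name[len("Array["):-1]
        return scalar.get(type_name, "Param[%s]" % type_name)

    print(param_class("Array[String]"))  # ArrayParam[String]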

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala (+1, -1)

@@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params {
   * Param for input column names.
   * @group param
   */
-  final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names")
+  final val inputCols: ArrayParam[String] = new ArrayParam[String](this, "inputCols", "input column names")
 
   /** @group getParam */
   final def getInputCols: Array[String] = $(inputCols)

python/pyspark/ml/feature.py (+40, -1)

@@ -16,7 +16,7 @@
 #
 
 from pyspark.rdd import ignore_unicode_prefix
-from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
+from pyspark.ml.param.shared import HasInputCol, HasInputCols, HasOutputCol, HasNumFeatures
 from pyspark.ml.util import keyword_only
 from pyspark.ml.wrapper import JavaTransformer
 from pyspark.mllib.common import inherit_doc

@@ -112,6 +112,45 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
         return self._set(**kwargs)
 
 
+@inherit_doc
+class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
+    """
+    A feature transformer that merges multiple columns into a vector column.
+
+    >>> from pyspark.sql import Row
+    >>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF()
+    >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
+    >>> vecAssembler.transform(df).head().features
+    SparseVector(3, {0: 1.0, 2: 3.0})
+    >>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
+    SparseVector(3, {0: 1.0, 2: 3.0})
+    >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
+    >>> vecAssembler.transform(df, params).head().vector
+    SparseVector(2, {1: 1.0})
+    """
+
+    _java_class = "org.apache.spark.ml.feature.VectorAssembler"
+
+    @keyword_only
+    def __init__(self, inputCols=None, outputCol=None):
+        """
+        __init__(self, inputCols=None, outputCol=None)
+        """
+        super(VectorAssembler, self).__init__()
+        self._setDefault()
+        kwargs = self.__init__._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, inputCols=None, outputCol=None):
+        """
+        setParams(self, inputCols=None, outputCol=None)
+        Sets params for this VectorAssembler.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+
 if __name__ == "__main__":
     import doctest
     from pyspark.context import SparkContext
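Beyond the doctest, the typical use is assembling raw numeric columns into the single vector column an estimator expects. A hedged usage sketch (assumes a DataFrame trainingDF with columns a, b, c, and label; LogisticRegression lives in pyspark.ml.classification):

    from pyspark.ml import Pipeline
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.feature import VectorAssembler

    # Assemble the raw columns, then feed the vector column downstream.
    assembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
    lr = LogisticRegression(featuresCol="features", labelCol="label")
    pipeline = Pipeline(stages=[assembler, lr])
    # model = pipeline.fit(trainingDF)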

python/pyspark/ml/param/shared.py (+29, -0)

@@ -223,6 +223,35 @@ def getInputCol(self):
         return self.getOrDefault(self.inputCol)
 
 
+class HasInputCols(Params):
+    """
+    Mixin for param inputCols: input column names.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    inputCols = Param(Params._dummy(), "inputCols", "input column names")
+
+    def __init__(self):
+        super(HasInputCols, self).__init__()
+        #: param for input column names
+        self.inputCols = Param(self, "inputCols", "input column names")
+        if None is not None:
+            self._setDefault(inputCols=None)
+
+    def setInputCols(self, value):
+        """
+        Sets the value of :py:attr:`inputCols`.
+        """
+        self.paramMap[self.inputCols] = value
+        return self
+
+    def getInputCols(self):
+        """
+        Gets the value of inputCols or its default value.
+        """
+        return self.getOrDefault(self.inputCols)
+
+
 class HasOutputCol(Params):
     """
     Mixin for param outputCol: output column name.
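HasInputCols follows the generated-code template used for the shared-param mixins, which is why the dead `if None is not None:` branch appears: the generator emits a real default check only when a default value is declared, and inputCols has none. A minimal consumption sketch (illustrative; MyTransformer is made up):

    from pyspark.ml.param.shared import HasInputCols

    class MyTransformer(HasInputCols):
        """Toy Params subclass that picks up the inputCols param, setter, and getter."""

    t = MyTransformer()
    t.setInputCols(["x", "y"])
    print(t.getInputCols())  # ['x', 'y']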

python/pyspark/ml/wrapper.py (+6, -5)

@@ -67,7 +67,10 @@ def _transfer_params_to_java(self, params, java_obj):
         paramMap = self.extractParamMap(params)
         for param in self.params:
             if param in paramMap:
-                java_obj.set(param.name, paramMap[param])
+                value = paramMap[param]
+                if isinstance(value, list):
+                    value = _jvm().PythonUtils.toSeq(value)
+                java_obj.set(param.name, value)
 
     def _empty_java_param_map(self):
         """

@@ -126,10 +129,8 @@ class JavaTransformer(Transformer, JavaWrapper):
 
     def transform(self, dataset, params={}):
         java_obj = self._java_obj()
-        self._transfer_params_to_java({}, java_obj)
-        java_param_map = self._create_java_param_map(params, java_obj)
-        return DataFrame(java_obj.transform(dataset._jdf, java_param_map),
-                         dataset.sql_ctx)
+        self._transfer_params_to_java(params, java_obj)
+        return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx)
 
 
 @inherit_doc
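Two things change here: _transfer_params_to_java now converts list-valued params to Scala Seqs before calling the Java setter (Py4J would otherwise pass a java.util.ArrayList, which the Array-typed param cannot accept), and transform() transfers the merged params directly instead of building a Java-side ParamMap. The conversion rule extracted as a standalone helper, as a hedged sketch (assumes the _jvm helper already defined in wrapper.py; PythonUtils.toSeq is the Spark utility used in the diff):

    from pyspark.ml.wrapper import _jvm

    def _to_java_value(value):
        # Lists become Scala Seqs so ArrayParam.wCast can pick them up on
        # the JVM side; everything else uses Py4J's default conversions.
        if isinstance(value, list):
            return _jvm().PythonUtils.toSeq(value)
        return value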
