|
16 | 16 | #
|
17 | 17 |
|
18 | 18 | from pyspark.rdd import ignore_unicode_prefix
|
19 |
| -from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures |
| 19 | +from pyspark.ml.param.shared import HasInputCol, HasInputCols, HasOutputCol, HasNumFeatures |
20 | 20 | from pyspark.ml.util import keyword_only
|
21 | 21 | from pyspark.ml.wrapper import JavaTransformer
|
22 | 22 | from pyspark.mllib.common import inherit_doc
|
@@ -112,6 +112,45 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
|
112 | 112 | return self._set(**kwargs)
|
113 | 113 |
|
114 | 114 |
|
@inherit_doc
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
    """
    Transformer that combines several input columns into one vector column.

    The columns named by ``inputCols`` are merged, in the order given, into
    a single vector stored under the column named by ``outputCol``.

    >>> from pyspark.sql import Row
    >>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF()
    >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
    >>> vecAssembler.transform(df).head().features
    SparseVector(3, {0: 1.0, 2: 3.0})
    >>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
    SparseVector(3, {0: 1.0, 2: 3.0})
    >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
    >>> vecAssembler.transform(df, params).head().vector
    SparseVector(2, {1: 1.0})
    """

    # Fully qualified name of the JVM class this Python wrapper stands for.
    _java_class = "org.apache.spark.ml.feature.VectorAssembler"

    @keyword_only
    def __init__(self, inputCols=None, outputCol=None):
        """
        __init__(self, inputCols=None, outputCol=None)
        """
        super(VectorAssembler, self).__init__()
        self._setDefault()
        # Forward only the keyword arguments captured in ``_input_kwargs``
        # for this call, rather than the full signature with its defaults.
        self.setParams(**self.__init__._input_kwargs)

    @keyword_only
    def setParams(self, inputCols=None, outputCol=None):
        """
        setParams(self, inputCols=None, outputCol=None)
        Sets params for this VectorAssembler.
        """
        # Same pattern as the constructor: hand the captured keyword
        # arguments to ``_set`` and return its result to the caller.
        return self._set(**self.setParams._input_kwargs)
| 152 | + |
| 153 | + |
115 | 154 | if __name__ == "__main__":
|
116 | 155 | import doctest
|
117 | 156 | from pyspark.context import SparkContext
|
|
0 commit comments