|
15 | 15 | # limitations under the License.
|
16 | 16 | #
|
17 | 17 |
|
18 |
| -from abc import ABCMeta, abstractmethod, abstractproperty |
| 18 | +from pyspark.ml.param import * |
| 19 | +from pyspark.ml.pipeline import * |
19 | 20 |
|
20 |
| -from pyspark import SparkContext |
21 |
| -from pyspark.sql import SchemaRDD, inherit_doc # TODO: move inherit_doc to Spark Core |
22 |
| -from pyspark.ml.param import Param, Params |
23 |
| -from pyspark.ml.util import Identifiable |
24 |
| - |
25 |
| -__all__ = ["Pipeline", "Transformer", "Estimator", "param", "feature", "classification"] |
26 |
| - |
27 |
| - |
28 |
| -def _jvm(): |
29 |
| - """ |
30 |
| - Returns the JVM view associated with SparkContext. Must be called |
31 |
| - after SparkContext is initialized. |
32 |
| - """ |
33 |
| - jvm = SparkContext._jvm |
34 |
| - if jvm: |
35 |
| - return jvm |
36 |
| - else: |
37 |
| - raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") |
38 |
| - |
39 |
| - |
40 |
| -@inherit_doc |
41 |
| -class PipelineStage(Params): |
42 |
| - """ |
43 |
| - A stage in a pipeline, either an :py:class:`Estimator` or a |
44 |
| - :py:class:`Transformer`. |
45 |
| - """ |
46 |
| - |
47 |
| - __metaclass__ = ABCMeta |
48 |
| - |
49 |
| - def __init__(self): |
50 |
| - super(PipelineStage, self).__init__() |
51 |
| - |
52 |
| - |
53 |
| -@inherit_doc |
54 |
| -class Estimator(PipelineStage): |
55 |
| - """ |
56 |
| - Abstract class for estimators that fit models to data. |
57 |
| - """ |
58 |
| - |
59 |
| - __metaclass__ = ABCMeta |
60 |
| - |
61 |
| - def __init__(self): |
62 |
| - super(Estimator, self).__init__() |
63 |
| - |
64 |
| - @abstractmethod |
65 |
| - def fit(self, dataset, params={}): |
66 |
| - """ |
67 |
| - Fits a model to the input dataset with optional parameters. |
68 |
| -
|
69 |
| - :param dataset: input dataset, which is an instance of |
70 |
| - :py:class:`pyspark.sql.SchemaRDD` |
71 |
| - :param params: an optional param map that overwrites embedded |
72 |
| - params |
73 |
| - :returns: fitted model |
74 |
| - """ |
75 |
| - raise NotImplementedError() |
76 |
| - |
77 |
| - |
78 |
| -@inherit_doc |
79 |
| -class Transformer(PipelineStage): |
80 |
| - """ |
81 |
| - Abstract class for transformers that transform one dataset into |
82 |
| - another. |
83 |
| - """ |
84 |
| - |
85 |
| - __metaclass__ = ABCMeta |
86 |
| - |
87 |
| - def __init__(self): |
88 |
| - super(Transformer, self).__init__() |
89 |
| - |
90 |
| - @abstractmethod |
91 |
| - def transform(self, dataset, params={}): |
92 |
| - """ |
93 |
| - Transforms the input dataset with optional parameters. |
94 |
| -
|
95 |
| - :param dataset: input dataset, which is an instance of |
96 |
| - :py:class:`pyspark.sql.SchemaRDD` |
97 |
| - :param params: an optional param map that overwrites embedded |
98 |
| - params |
99 |
| - :returns: transformed dataset |
100 |
| - """ |
101 |
| - raise NotImplementedError() |
102 |
| - |
103 |
| - |
104 |
| -@inherit_doc |
105 |
| -class Model(Transformer): |
106 |
| - """ |
107 |
| - Abstract class for models fitted by :py:class:`Estimator`s. |
108 |
| - """ |
109 |
| - |
110 |
| - ___metaclass__ = ABCMeta |
111 |
| - |
112 |
| - def __init__(self): |
113 |
| - super(Model, self).__init__() |
114 |
| - |
115 |
| - |
116 |
| -@inherit_doc |
117 |
| -class Pipeline(Estimator): |
118 |
| - """ |
119 |
| - A simple pipeline, which acts as an estimator. A Pipeline consists |
120 |
| - of a sequence of stages, each of which is either an |
121 |
| - :py:class:`Estimator` or a :py:class:`Transformer`. When |
122 |
| - :py:meth:`Pipeline.fit` is called, the stages are executed in |
123 |
| - order. If a stage is an :py:class:`Estimator`, its |
124 |
| - :py:meth:`Estimator.fit` method will be called on the input |
125 |
| - dataset to fit a model. Then the model, which is a transformer, |
126 |
| - will be used to transform the dataset as the input to the next |
127 |
| - stage. If a stage is a :py:class:`Transformer`, its |
128 |
| - :py:meth:`Transformer.transform` method will be called to produce |
129 |
| - the dataset for the next stage. The fitted model from a |
130 |
| - :py:class:`Pipeline` is an :py:class:`PipelineModel`, which |
131 |
| - consists of fitted models and transformers, corresponding to the |
132 |
| - pipeline stages. If there are no stages, the pipeline acts as an |
133 |
| - identity transformer. |
134 |
| - """ |
135 |
| - |
136 |
| - def __init__(self): |
137 |
| - super(Pipeline, self).__init__() |
138 |
| - #: Param for pipeline stages. |
139 |
| - self.stages = Param(self, "stages", "pipeline stages") |
140 |
| - |
141 |
| - def setStages(self, value): |
142 |
| - """ |
143 |
| - Set pipeline stages. |
144 |
| - :param value: a list of transformers or estimators |
145 |
| - :return: the pipeline instance |
146 |
| - """ |
147 |
| - self.paramMap[self.stages] = value |
148 |
| - return self |
149 |
| - |
150 |
| - def getStages(self): |
151 |
| - """ |
152 |
| - Get pipeline stages. |
153 |
| - """ |
154 |
| - if self.stages in self.paramMap: |
155 |
| - return self.paramMap[self.stages] |
156 |
| - |
157 |
| - def fit(self, dataset, params={}): |
158 |
| - paramMap = self._merge_params(params) |
159 |
| - stages = paramMap[self.stages] |
160 |
| - for stage in stages: |
161 |
| - if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): |
162 |
| - raise ValueError( |
163 |
| - "Cannot recognize a pipeline stage of type %s." % type(stage).__name__) |
164 |
| - indexOfLastEstimator = -1 |
165 |
| - for i, stage in enumerate(stages): |
166 |
| - if isinstance(stage, Estimator): |
167 |
| - indexOfLastEstimator = i |
168 |
| - transformers = [] |
169 |
| - for i, stage in enumerate(stages): |
170 |
| - if i <= indexOfLastEstimator: |
171 |
| - if isinstance(stage, Transformer): |
172 |
| - transformers.append(stage) |
173 |
| - dataset = stage.transform(dataset, paramMap) |
174 |
| - else: # must be an Estimator |
175 |
| - model = stage.fit(dataset, paramMap) |
176 |
| - transformers.append(model) |
177 |
| - if i < indexOfLastEstimator: |
178 |
| - dataset = model.transform(dataset, paramMap) |
179 |
| - else: |
180 |
| - transformers.append(stage) |
181 |
| - return PipelineModel(transformers) |
182 |
| - |
183 |
| - |
184 |
| -@inherit_doc |
185 |
| -class PipelineModel(Model): |
186 |
| - """ |
187 |
| - Represents a compiled pipeline with transformers and fitted models. |
188 |
| - """ |
189 |
| - |
190 |
| - def __init__(self, transformers): |
191 |
| - super(PipelineModel, self).__init__() |
192 |
| - self.transformers = transformers |
193 |
| - |
194 |
| - def transform(self, dataset, params={}): |
195 |
| - paramMap = self._merge_params(params) |
196 |
| - for t in self.transformers: |
197 |
| - dataset = t.transform(dataset, paramMap) |
198 |
| - return dataset |
199 |
| - |
200 |
| - |
201 |
| -@inherit_doc |
202 |
| -class JavaWrapper(Params): |
203 |
| - """ |
204 |
| - Utility class to help create wrapper classes from Java/Scala |
205 |
| - implementations of pipeline components. |
206 |
| - """ |
207 |
| - |
208 |
| - __metaclass__ = ABCMeta |
209 |
| - |
210 |
| - def __init__(self): |
211 |
| - super(JavaWrapper, self).__init__() |
212 |
| - |
213 |
| - @abstractproperty |
214 |
| - def _java_class(self): |
215 |
| - """ |
216 |
| - Fully-qualified class name of the wrapped Java component. |
217 |
| - """ |
218 |
| - raise NotImplementedError |
219 |
| - |
220 |
| - def _java_obj(self): |
221 |
| - """ |
222 |
| - Returns or creates a Java object. |
223 |
| - """ |
224 |
| - java_obj = _jvm() |
225 |
| - for name in self._java_class.split("."): |
226 |
| - java_obj = getattr(java_obj, name) |
227 |
| - return java_obj() |
228 |
| - |
229 |
| - def _transfer_params_to_java(self, params, java_obj): |
230 |
| - """ |
231 |
| - Transforms the embedded params and additional params to the |
232 |
| - input Java object. |
233 |
| - :param params: additional params (overwriting embedded values) |
234 |
| - :param java_obj: Java object to receive the params |
235 |
| - """ |
236 |
| - paramMap = self._merge_params(params) |
237 |
| - for param in self.params: |
238 |
| - if param in paramMap: |
239 |
| - java_obj.set(param.name, paramMap[param]) |
240 |
| - |
241 |
| - def _empty_java_param_map(self): |
242 |
| - """ |
243 |
| - Returns an empty Java ParamMap reference. |
244 |
| - """ |
245 |
| - return _jvm().org.apache.spark.ml.param.ParamMap() |
246 |
| - |
247 |
| - def _create_java_param_map(self, params, java_obj): |
248 |
| - paramMap = self._empty_java_param_map() |
249 |
| - for param, value in params.items(): |
250 |
| - if param.parent is self: |
251 |
| - paramMap.put(java_obj.getParam(param.name), value) |
252 |
| - return paramMap |
253 |
| - |
254 |
| - |
255 |
| -@inherit_doc |
256 |
| -class JavaEstimator(Estimator, JavaWrapper): |
257 |
| - """ |
258 |
| - Base class for :py:class:`Estimator`s that wrap Java/Scala |
259 |
| - implementations. |
260 |
| - """ |
261 |
| - |
262 |
| - __metaclass__ = ABCMeta |
263 |
| - |
264 |
| - def __init__(self): |
265 |
| - super(JavaEstimator, self).__init__() |
266 |
| - |
267 |
| - @abstractmethod |
268 |
| - def _create_model(self, java_model): |
269 |
| - """ |
270 |
| - Creates a model from the input Java model reference. |
271 |
| - """ |
272 |
| - raise NotImplementedError |
273 |
| - |
274 |
| - def _fit_java(self, dataset, params={}): |
275 |
| - """ |
276 |
| - Fits a Java model to the input dataset. |
277 |
| - :param dataset: input dataset, which is an instance of |
278 |
| - :py:class:`pyspark.sql.SchemaRDD` |
279 |
| - :param params: additional params (overwriting embedded values) |
280 |
| - :return: fitted Java model |
281 |
| - """ |
282 |
| - java_obj = self._java_obj() |
283 |
| - self._transfer_params_to_java(params, java_obj) |
284 |
| - return java_obj.fit(dataset._jschema_rdd, self._empty_java_param_map()) |
285 |
| - |
286 |
| - def fit(self, dataset, params={}): |
287 |
| - java_model = self._fit_java(dataset, params) |
288 |
| - return self._create_model(java_model) |
289 |
| - |
290 |
| - |
291 |
| -@inherit_doc |
292 |
| -class JavaTransformer(Transformer, JavaWrapper): |
293 |
| - """ |
294 |
| - Base class for :py:class:`Transformer`s that wrap Java/Scala |
295 |
| - implementations. |
296 |
| - """ |
297 |
| - |
298 |
| - __metaclass__ = ABCMeta |
299 |
| - |
300 |
| - def __init__(self): |
301 |
| - super(JavaTransformer, self).__init__() |
302 |
| - |
303 |
| - def transform(self, dataset, params={}): |
304 |
| - java_obj = self._java_obj() |
305 |
| - self._transfer_params_to_java({}, java_obj) |
306 |
| - java_param_map = self._create_java_param_map(params, java_obj) |
307 |
| - return SchemaRDD(java_obj.transform(dataset._jschema_rdd, java_param_map), |
308 |
| - dataset.sql_ctx) |
309 |
| - |
310 |
| - |
311 |
| -@inherit_doc |
312 |
| -class JavaModel(JavaTransformer): |
313 |
| - """ |
314 |
| - Base class for :py:class:`Model`s that wrap Java/Scala |
315 |
| - implementations. |
316 |
| - """ |
317 |
| - |
318 |
| - __metaclass__ = ABCMeta |
319 |
| - |
320 |
| - def __init__(self): |
321 |
| - super(JavaTransformer, self).__init__() |
322 |
| - |
323 |
| - def _java_obj(self): |
324 |
| - return self._java_model |
| 21 | +__all__ = ["Param", "Params", "Transformer", "Estimator", "Pipeline"] |
0 commit comments