26
26
27
27
28
28
def _jvm():
    """
    Returns the JVM view associated with SparkContext. Must be called
    after SparkContext is initialized.

    :raises AttributeError: if SparkContext has not been initialized,
                            i.e. its JVM view is not yet available
    """
    # Guard-clause form: fail fast when the JVM view is missing.
    jvm = SparkContext._jvm
    if not jvm:
        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
    return jvm
44
38
45
39
46
40
@inherit_doc
@@ -50,6 +44,8 @@ class PipelineStage(Params):
50
44
:py:class:`Transformer`.
51
45
"""
52
46
47
+ __metaclass__ = ABCMeta
48
+
53
49
def __init__ (self ):
54
50
super (PipelineStage , self ).__init__ ()
55
51
@@ -147,38 +143,54 @@ def getStages(self):
147
143
return self .paramMap [self .stages ]
148
144
149
145
def fit (self , dataset , params = {}):
150
- map = self ._merge_params (params )
151
- transformers = []
152
- for stage in self .getStages ():
153
- if isinstance (stage , Transformer ):
154
- transformers .append (stage )
155
- dataset = stage .transform (dataset , map )
156
- elif isinstance (stage , Estimator ):
157
- model = stage .fit (dataset , map )
158
- transformers .append (model )
159
- dataset = model .transform (dataset , map )
160
- else :
146
+ paramMap = self ._merge_params (params )
147
+ stages = paramMap (self .stages )
148
+ for stage in stages :
149
+ if not (isinstance (stage , Estimator ) or isinstance (stage , Transformer )):
161
150
raise ValueError (
162
151
"Cannot recognize a pipeline stage of type %s." % type (stage ).__name__ )
152
+ indexOfLastEstimator = - 1
153
+ for i , stage in enumerate (stages ):
154
+ if isinstance (stage , Estimator ):
155
+ indexOfLastEstimator = i
156
+ transformers = []
157
+ for i , stage in enumerate (stages ):
158
+ if i <= indexOfLastEstimator :
159
+ if isinstance (stage , Transformer ):
160
+ transformers .append (stage )
161
+ dataset = stage .transform (dataset , paramMap )
162
+ else : # must be an Estimator
163
+ model = stage .fit (dataset , paramMap )
164
+ transformers .append (model )
165
+ dataset = model .transform (dataset , paramMap )
166
+ else :
167
+ transformers .append (stage )
163
168
return PipelineModel (transformers )
164
169
165
170
166
171
@inherit_doc
class PipelineModel(Transformer):
    """
    Represents a compiled pipeline with transformers and fitted models.
    """

    def __init__(self, transformers):
        super(PipelineModel, self).__init__()
        # Ordered transformers/fitted models produced by Pipeline.fit.
        self.transformers = transformers

    def transform(self, dataset, params={}):
        """
        Applies each transformer in order to the input dataset and
        returns the final result.
        """
        merged = self._merge_params(params)
        for transformer in self.transformers:
            dataset = transformer.transform(dataset, merged)
        return dataset
178
186
179
187
180
188
@inherit_doc
181
- class JavaWrapper (object ):
189
+ class JavaWrapper (Params ):
190
+ """
191
+ Utility class to help create wrapper classes from Java/Scala
192
+ implementations of pipeline components.
193
+ """
182
194
183
195
__metaclass__ = ABCMeta
184
196
@@ -187,17 +199,45 @@ def __init__(self):
187
199
188
200
@abstractproperty
189
201
def _java_class (self ):
202
+ """
203
+ Fully-qualified class name of the wrapped Java component.
204
+ """
190
205
raise NotImplementedError
191
206
192
207
def _create_java_obj (self ):
208
+ """
209
+ Creates a new Java object and returns its reference.
210
+ """
193
211
java_obj = _jvm ()
194
212
for name in self ._java_class .split ("." ):
195
213
java_obj = getattr (java_obj , name )
196
214
return java_obj ()
197
215
216
+ def _transfer_params_to_java (self , params , java_obj ):
217
+ """
218
+ Transforms the embedded params and additional params to the
219
+ input Java object.
220
+ :param params: additional params (overwriting embedded values)
221
+ :param java_obj: Java object to receive the params
222
+ """
223
+ paramMap = self ._merge_params (params )
224
+ for param in self .params ():
225
+ if param in paramMap :
226
+ java_obj .set (param .name , paramMap [param ])
227
+
228
+ def _empty_java_param_map (self ):
229
+ """
230
+ Returns an empty Java ParamMap reference.
231
+ """
232
+ return _jvm ().org .apache .spark .ml .param .ParamMap ()
233
+
198
234
199
235
@inherit_doc
200
236
class JavaEstimator (Estimator , JavaWrapper ):
237
+ """
238
+ Base class for :py:class:`Estimator`s that wrap Java/Scala
239
+ implementations.
240
+ """
201
241
202
242
__metaclass__ = ABCMeta
203
243
@@ -206,12 +246,22 @@ def __init__(self):
206
246
207
247
@abstractmethod
208
248
def _create_model (self , java_model ):
249
+ """
250
+ Creates a model from the input Java model reference.
251
+ """
209
252
raise NotImplementedError
210
253
211
254
def _fit_java (self , dataset , params = {}):
255
+ """
256
+ Fits a Java model to the input dataset.
257
+ :param dataset: input dataset, which is an instance of
258
+ :py:class:`pyspark.sql.SchemaRDD`
259
+ :param params: additional params (overwriting embedded values)
260
+ :return: fitted Java model
261
+ """
212
262
java_obj = self ._create_java_obj ()
213
263
self ._transfer_params_to_java (params , java_obj )
214
- return java_obj .fit (dataset ._jschema_rdd , _jvm (). org . apache . spark . ml . param . ParamMap ())
264
+ return java_obj .fit (dataset ._jschema_rdd , self . _empty_java_param_map ())
215
265
216
266
def fit (self , dataset , params = {}):
217
267
java_model = self ._fit_java (dataset , params )
@@ -220,6 +270,10 @@ def fit(self, dataset, params={}):
220
270
221
271
@inherit_doc
222
272
class JavaTransformer (Transformer , JavaWrapper ):
273
+ """
274
+ Base class for :py:class:`Transformer`s that wrap Java/Scala
275
+ implementations.
276
+ """
223
277
224
278
__metaclass__ = ABCMeta
225
279
@@ -229,6 +283,5 @@ def __init__(self):
229
283
def transform (self , dataset , params = {}):
230
284
java_obj = self ._create_java_obj ()
231
285
self ._transfer_params_to_java (params , java_obj )
232
- return SchemaRDD (java_obj .transform (dataset ._jschema_rdd ,
233
- _jvm ().org .apache .spark .ml .param .ParamMap ()),
286
+ return SchemaRDD (java_obj .transform (dataset ._jschema_rdd , self ._empty_java_param_map ()),
234
287
dataset .sql_ctx )
0 commit comments