@@ -101,6 +101,18 @@ def transform(self, dataset, params={}):
101
101
raise NotImplementedError ()
102
102
103
103
104
+ @inherit_doc
105
+ class Model (Transformer ):
106
+ """
107
+ Abstract class for models fitted by :py:class:`Estimator`s.
108
+ """
109
+
110
+ ___metaclass__ = ABCMeta
111
+
112
+ def __init__ (self ):
113
+ super (Model , self ).__init__ ()
114
+
115
+
104
116
@inherit_doc
105
117
class Pipeline (Estimator ):
106
118
"""
@@ -169,7 +181,7 @@ def fit(self, dataset, params={}):
169
181
170
182
171
183
@inherit_doc
172
- class PipelineModel (Transformer ):
184
+ class PipelineModel (Model ):
173
185
"""
174
186
Represents a compiled pipeline with transformers and fitted models.
175
187
"""
@@ -204,9 +216,9 @@ def _java_class(self):
204
216
"""
205
217
raise NotImplementedError
206
218
207
- def _create_java_obj (self ):
219
+ def _java_obj (self ):
208
220
"""
209
- Creates a new Java object and returns its reference .
221
+ Returns or creates a Java object.
210
222
"""
211
223
java_obj = _jvm ()
212
224
for name in self ._java_class .split ("." ):
@@ -231,6 +243,13 @@ def _empty_java_param_map(self):
231
243
"""
232
244
return _jvm ().org .apache .spark .ml .param .ParamMap ()
233
245
246
+ def _create_java_param_map (self , params , java_obj ):
247
+ paramMap = self ._empty_java_param_map ()
248
+ for param , value in params .items ():
249
+ if param .parent is self :
250
+ paramMap .put (java_obj .getParam (param .name ), value )
251
+ return paramMap
252
+
234
253
235
254
@inherit_doc
236
255
class JavaEstimator (Estimator , JavaWrapper ):
@@ -259,7 +278,7 @@ def _fit_java(self, dataset, params={}):
259
278
:param params: additional params (overwriting embedded values)
260
279
:return: fitted Java model
261
280
"""
262
- java_obj = self ._create_java_obj ()
281
+ java_obj = self ._java_obj ()
263
282
self ._transfer_params_to_java (params , java_obj )
264
283
return java_obj .fit (dataset ._jschema_rdd , self ._empty_java_param_map ())
265
284
@@ -281,7 +300,24 @@ def __init__(self):
281
300
super (JavaTransformer , self ).__init__ ()
282
301
283
302
def transform (self , dataset , params = {}):
284
- java_obj = self ._create_java_obj ()
285
- self ._transfer_params_to_java (params , java_obj )
286
- return SchemaRDD (java_obj .transform (dataset ._jschema_rdd , self ._empty_java_param_map ()),
303
+ java_obj = self ._java_obj ()
304
+ self ._transfer_params_to_java ({}, java_obj )
305
+ java_param_map = self ._create_java_param_map (params , java_obj )
306
+ return SchemaRDD (java_obj .transform (dataset ._jschema_rdd , java_param_map ),
287
307
dataset .sql_ctx )
308
+
309
+
310
+ @inherit_doc
311
+ class JavaModel (JavaTransformer ):
312
+ """
313
+ Base class for :py:class:`Model`s that wrap Java/Scala
314
+ implementations.
315
+ """
316
+
317
+ __metaclass__ = ABCMeta
318
+
319
+ def __init__ (self ):
320
+ super (JavaTransformer , self ).__init__ ()
321
+
322
+ def _java_obj (self ):
323
+ return self ._java_model
0 commit comments