Skip to content

Commit 5153cff

Browse files
committed
simplify java models
1 parent 036ca04 commit 5153cff

File tree

3 files changed

+49
-16
lines changed

3 files changed

+49
-16
lines changed

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,6 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
293293
new ParamMap(this.map ++ other.map)
294294
}
295295

296-
297296
/**
298297
* Adds all parameters from the input param map into this param map.
299298
*/

python/pyspark/ml/__init__.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,18 @@ def transform(self, dataset, params={}):
101101
raise NotImplementedError()
102102

103103

104+
@inherit_doc
105+
class Model(Transformer):
106+
"""
107+
Abstract class for models fitted by :py:class:`Estimator`s.
108+
"""
109+
110+
___metaclass__ = ABCMeta
111+
112+
def __init__(self):
113+
super(Model, self).__init__()
114+
115+
104116
@inherit_doc
105117
class Pipeline(Estimator):
106118
"""
@@ -169,7 +181,7 @@ def fit(self, dataset, params={}):
169181

170182

171183
@inherit_doc
172-
class PipelineModel(Transformer):
184+
class PipelineModel(Model):
173185
"""
174186
Represents a compiled pipeline with transformers and fitted models.
175187
"""
@@ -204,9 +216,9 @@ def _java_class(self):
204216
"""
205217
raise NotImplementedError
206218

207-
def _create_java_obj(self):
219+
def _java_obj(self):
208220
"""
209-
Creates a new Java object and returns its reference.
221+
Returns or creates a Java object.
210222
"""
211223
java_obj = _jvm()
212224
for name in self._java_class.split("."):
@@ -231,6 +243,13 @@ def _empty_java_param_map(self):
231243
"""
232244
return _jvm().org.apache.spark.ml.param.ParamMap()
233245

246+
def _create_java_param_map(self, params, java_obj):
247+
paramMap = self._empty_java_param_map()
248+
for param, value in params.items():
249+
if param.parent is self:
250+
paramMap.put(java_obj.getParam(param.name), value)
251+
return paramMap
252+
234253

235254
@inherit_doc
236255
class JavaEstimator(Estimator, JavaWrapper):
@@ -259,7 +278,7 @@ def _fit_java(self, dataset, params={}):
259278
:param params: additional params (overwriting embedded values)
260279
:return: fitted Java model
261280
"""
262-
java_obj = self._create_java_obj()
281+
java_obj = self._java_obj()
263282
self._transfer_params_to_java(params, java_obj)
264283
return java_obj.fit(dataset._jschema_rdd, self._empty_java_param_map())
265284

@@ -281,7 +300,24 @@ def __init__(self):
281300
super(JavaTransformer, self).__init__()
282301

283302
def transform(self, dataset, params={}):
284-
java_obj = self._create_java_obj()
285-
self._transfer_params_to_java(params, java_obj)
286-
return SchemaRDD(java_obj.transform(dataset._jschema_rdd, self._empty_java_param_map()),
303+
java_obj = self._java_obj()
304+
self._transfer_params_to_java({}, java_obj)
305+
java_param_map = self._create_java_param_map(params, java_obj)
306+
return SchemaRDD(java_obj.transform(dataset._jschema_rdd, java_param_map),
287307
dataset.sql_ctx)
308+
309+
310+
@inherit_doc
311+
class JavaModel(JavaTransformer):
312+
"""
313+
Base class for :py:class:`Model`s that wrap Java/Scala
314+
implementations.
315+
"""
316+
317+
__metaclass__ = ABCMeta
318+
319+
def __init__(self):
320+
super(JavaTransformer, self).__init__()
321+
322+
def _java_obj(self):
323+
return self._java_model

python/pyspark/ml/classification.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
# limitations under the License.
1616
#
1717

18-
from pyspark.sql import SchemaRDD, inherit_doc
19-
from pyspark.ml import JavaEstimator, Transformer, _jvm
18+
from pyspark.sql import inherit_doc
19+
from pyspark.ml import JavaEstimator, JavaModel
2020
from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
2121
HasRegParam
2222

@@ -40,7 +40,7 @@ def _create_model(self, java_model):
4040

4141

4242
@inherit_doc
43-
class LogisticRegressionModel(Transformer):
43+
class LogisticRegressionModel(JavaModel):
4444
"""
4545
Model fitted by LogisticRegression.
4646
"""
@@ -49,8 +49,6 @@ def __init__(self, java_model):
4949
super(LogisticRegressionModel, self).__init__()
5050
self._java_model = java_model
5151

52-
def transform(self, dataset, params={}):
53-
# TODO: handle params here.
54-
return SchemaRDD(self._java_model.transform(
55-
dataset._jschema_rdd,
56-
_jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx)
52+
@property
53+
def _java_class(self):
54+
return "org.apache.spark.ml.classification.LogisticRegressionModel"

0 commit comments

Comments
 (0)