Commit d3e8dbe

more docs
optimize pipeline.fit impl
1 parent 56de571

3 files changed: 86 additions & 38 deletions

python/pyspark/ml/__init__.py

Lines changed: 85 additions & 32 deletions
@@ -26,21 +26,15 @@
 
 
 def _jvm():
-    return SparkContext._jvm
-
-
-def _inherit_doc(cls):
-    for name, func in vars(cls).items():
-        # only inherit docstring for public functions
-        if name.startswith("_"):
-            continue
-        if not func.__doc__:
-            for parent in cls.__bases__:
-                parent_func = getattr(parent, name, None)
-                if parent_func and getattr(parent_func, "__doc__", None):
-                    func.__doc__ = parent_func.__doc__
-                    break
-    return cls
+    """
+    Returns the JVM view associated with SparkContext. Must be called
+    after SparkContext is initialized.
+    """
+    jvm = SparkContext._jvm
+    if jvm:
+        return jvm
+    else:
+        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
 
 
 @inherit_doc
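
The rewritten _jvm() fails fast with an explicit error instead of handing back None when the gateway is not up. A minimal sketch of the intended calling order (the `sc` name is illustrative):

    from pyspark import SparkContext

    sc = SparkContext(appName="ml-demo")   # starts the gateway and populates SparkContext._jvm
    jvm = _jvm()                           # safe now; before this point it raises AttributeError
    pm = jvm.org.apache.spark.ml.param.ParamMap()  # walk the JVM namespace, as the wrappers below do
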
@@ -50,6 +44,8 @@ class PipelineStage(Params):
     :py:class:`Transformer`.
     """
 
+    __metaclass__ = ABCMeta
+
     def __init__(self):
         super(PipelineStage, self).__init__()
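
The new class-level attribute is the Python 2 spelling of an abstract base class. By itself ABCMeta does not block instantiation; it only refuses classes that still carry an unimplemented @abstractmethod or @abstractproperty, which is how _java_class and _create_model are enforced further down. A standalone illustration (Python 2 syntax, matching this codebase):

    from abc import ABCMeta, abstractproperty

    class Base(object):
        __metaclass__ = ABCMeta            # Python 3 spells this: class Base(metaclass=ABCMeta)

        @abstractproperty
        def _java_class(self):
            raise NotImplementedError

    # Base() -> TypeError: can't instantiate abstract class Base with abstract methods _java_class

    class Sub(Base):
        _java_class = "org.example.Thing"  # any concrete definition clears the abstract flag

    Sub()  # instantiates fine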

@@ -147,38 +143,54 @@ def getStages(self):
         return self.paramMap[self.stages]
 
     def fit(self, dataset, params={}):
-        map = self._merge_params(params)
-        transformers = []
-        for stage in self.getStages():
-            if isinstance(stage, Transformer):
-                transformers.append(stage)
-                dataset = stage.transform(dataset, map)
-            elif isinstance(stage, Estimator):
-                model = stage.fit(dataset, map)
-                transformers.append(model)
-                dataset = model.transform(dataset, map)
-            else:
+        paramMap = self._merge_params(params)
+        stages = paramMap[self.stages]
+        for stage in stages:
+            if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
                 raise ValueError(
                     "Cannot recognize a pipeline stage of type %s." % type(stage).__name__)
+        indexOfLastEstimator = -1
+        for i, stage in enumerate(stages):
+            if isinstance(stage, Estimator):
+                indexOfLastEstimator = i
+        transformers = []
+        for i, stage in enumerate(stages):
+            if i <= indexOfLastEstimator:
+                if isinstance(stage, Transformer):
+                    transformers.append(stage)
+                    dataset = stage.transform(dataset, paramMap)
+                else: # must be an Estimator
+                    model = stage.fit(dataset, paramMap)
+                    transformers.append(model)
+                    dataset = model.transform(dataset, paramMap)
+            else:
+                transformers.append(stage)
         return PipelineModel(transformers)
 
 
 @inherit_doc
 class PipelineModel(Transformer):
+    """
+    Represents a compiled pipeline with transformers and fitted models.
+    """
 
     def __init__(self, transformers):
         super(PipelineModel, self).__init__()
         self.transformers = transformers
 
     def transform(self, dataset, params={}):
-        map = self._merge_params(params)
+        paramMap = self._merge_params(params)
         for t in self.transformers:
-            dataset = t.transform(dataset, map)
+            dataset = t.transform(dataset, paramMap)
         return dataset
 
 
 @inherit_doc
-class JavaWrapper(object):
+class JavaWrapper(Params):
+    """
+    Utility class to help create wrapper classes from Java/Scala
+    implementations of pipeline components.
+    """
 
     __metaclass__ = ABCMeta
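
This is the "optimize pipeline.fit impl" named in the commit message: fit() locates the last Estimator up front, and every stage after it is stored in the PipelineModel without being executed, so the training data never flows through stages nothing learns from. Traced on a toy stage list (setStages and all stage names are illustrative):

    # stages = [tokenizer, hashingTF, lr, threshold]  ->  indexOfLastEstimator == 2 (lr)
    #
    #   tokenizer, hashingTF : transform the training data, stored as-is
    #   lr                   : fit; the resulting model transforms the data and is stored
    #   threshold            : stored directly, never run during fit
    model = Pipeline().setStages([tokenizer, hashingTF, lr, threshold]).fit(training)
    # model.transformers == [tokenizer, hashingTF, lrModel, threshold]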

@@ -187,17 +199,45 @@ def __init__(self):
 
     @abstractproperty
     def _java_class(self):
+        """
+        Fully-qualified class name of the wrapped Java component.
+        """
         raise NotImplementedError
 
     def _create_java_obj(self):
+        """
+        Creates a new Java object and returns its reference.
+        """
         java_obj = _jvm()
         for name in self._java_class.split("."):
             java_obj = getattr(java_obj, name)
         return java_obj()
 
+    def _transfer_params_to_java(self, params, java_obj):
+        """
+        Transforms the embedded params and additional params to the
+        input Java object.
+        :param params: additional params (overwriting embedded values)
+        :param java_obj: Java object to receive the params
+        """
+        paramMap = self._merge_params(params)
+        for param in self.params():
+            if param in paramMap:
+                java_obj.set(param.name, paramMap[param])
+
+    def _empty_java_param_map(self):
+        """
+        Returns an empty Java ParamMap reference.
+        """
+        return _jvm().org.apache.spark.ml.param.ParamMap()
+
 
 @inherit_doc
 class JavaEstimator(Estimator, JavaWrapper):
+    """
+    Base class for :py:class:`Estimator`s that wrap Java/Scala
+    implementations.
+    """
 
     __metaclass__ = ABCMeta
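
_create_java_obj resolves the dotted _java_class one attribute at a time on the JVM view and then invokes the constructor, so a concrete wrapper only needs to name its Scala peer. A hypothetical subclass (the class below is fictional, not part of this commit):

    class JavaScaler(JavaTransformer):
        """Illustrative wrapper around a fictional Scala Transformer."""

        @property
        def _java_class(self):
            # resolved as getattr(jvm, "org"), then .apache, ..., then called: Scaler()
            return "org.apache.spark.ml.feature.Scaler"  # fictional class name

Note that _transfer_params_to_java copies a param across only when it appears in the merged map, so Java-side defaults stay untouched unless the user set a value.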

@@ -206,12 +246,22 @@ def __init__(self):
 
     @abstractmethod
     def _create_model(self, java_model):
+        """
+        Creates a model from the input Java model reference.
+        """
        raise NotImplementedError
 
     def _fit_java(self, dataset, params={}):
+        """
+        Fits a Java model to the input dataset.
+        :param dataset: input dataset, which is an instance of
+                        :py:class:`pyspark.sql.SchemaRDD`
+        :param params: additional params (overwriting embedded values)
+        :return: fitted Java model
+        """
         java_obj = self._create_java_obj()
         self._transfer_params_to_java(params, java_obj)
-        return java_obj.fit(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap())
+        return java_obj.fit(dataset._jschema_rdd, self._empty_java_param_map())
 
     def fit(self, dataset, params={}):
         java_model = self._fit_java(dataset, params)
@@ -220,6 +270,10 @@ def fit(self, dataset, params={}):
 
 @inherit_doc
 class JavaTransformer(Transformer, JavaWrapper):
+    """
+    Base class for :py:class:`Transformer`s that wrap Java/Scala
+    implementations.
+    """
 
     __metaclass__ = ABCMeta

@@ -229,6 +283,5 @@ def __init__(self):
     def transform(self, dataset, params={}):
         java_obj = self._create_java_obj()
         self._transfer_params_to_java(params, java_obj)
-        return SchemaRDD(java_obj.transform(dataset._jschema_rdd,
-                                            _jvm().org.apache.spark.ml.param.ParamMap()),
+        return SchemaRDD(java_obj.transform(dataset._jschema_rdd, self._empty_java_param_map()),
                          dataset.sql_ctx)
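
With JavaEstimator and JavaTransformer in place, the user-facing flow composes through Pipeline. A sketch of the intended end-to-end usage, assuming Tokenizer and HashingTF wrappers exist on this branch (they are not part of this commit, and the setter-chaining style mirrors the Scala API of this era):

    from pyspark.ml import Pipeline
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.feature import HashingTF, Tokenizer  # assumed available

    tokenizer = Tokenizer().setInputCol("text").setOutputCol("words")
    hashingTF = HashingTF().setInputCol("words").setOutputCol("features")
    lr = LogisticRegression().setMaxIter(10)

    model = Pipeline().setStages([tokenizer, hashingTF, lr]).fit(training)  # training: a SchemaRDD
    predictions = model.transform(test)  # replays the stored transformers in order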

python/pyspark/ml/classification.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ class LogisticRegressionModel(Transformer):
     """
 
     def __init__(self, java_model):
+        super(LogisticRegressionModel, self).__init__()
         self._java_model = java_model
 
     def transform(self, dataset, params={}):
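
The one-line fix matters because Transformer ultimately inherits from Params, and skipping its __init__ leaves the instance without the state that _merge_params reads (paramMap, presumably created in Params.__init__, which this diff does not show). Roughly:

    model = LogisticRegressionModel(java_model)
    # before this commit: Params.__init__ never ran, so self.paramMap is missing and
    #     model.transform(dataset, {some_param: value}) would die inside _merge_params
    # after: the merge succeeds and embedded params behave like any other stage's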

python/pyspark/ml/param/__init__.py

Lines changed: 0 additions & 6 deletions
@@ -68,9 +68,3 @@ def _merge_params(self, params):
         map = self.paramMap.copy()
         map.update(params)
         return map
-
-    def _transfer_params_to_java(self, params, java_obj):
-        map = self._merge_params(params)
-        for param in self.params():
-            if param in map:
-                java_obj.set(param.name, map[param])
