Commit 44c2405

Merge pull request #2 from davies/ml
refactor
2 parents dd1256b + 14ae7e2 commit 44c2405

File tree

14 files changed: +379 −395 lines

python/docs/conf.py

Lines changed: 2 additions & 2 deletions
@@ -55,9 +55,9 @@
 # built documents.
 #
 # The short X.Y version.
-version = '1.2-SNAPSHOT'
+version = '1.3-SNAPSHOT'
 # The full version, including alpha/beta/rc tags.
-release = '1.2-SNAPSHOT'
+release = '1.3-SNAPSHOT'
 
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.

python/docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ Contents:
    pyspark
    pyspark.sql
    pyspark.streaming
+   pyspark.ml
    pyspark.mllib
 
 

python/docs/pyspark.ml.rst

Lines changed: 0 additions & 12 deletions
@@ -10,16 +10,6 @@ pyspark.ml module
 .. automodule:: pyspark.ml
     :members:
     :undoc-members:
-    :show-inheritance:
-    :inherited-members:
-
-pyspark.ml.param module
------------------------
-
-.. automodule:: pyspark.ml.param
-    :members:
-    :undoc-members:
-    :show-inheritance:
     :inherited-members:
 
 pyspark.ml.feature module
@@ -28,7 +18,6 @@ pyspark.ml.feature module
 .. automodule:: pyspark.ml.feature
     :members:
     :undoc-members:
-    :show-inheritance:
     :inherited-members:
 
 pyspark.ml.classification module
@@ -37,5 +26,4 @@ pyspark.ml.classification module
 .. automodule:: pyspark.ml.classification
     :members:
     :undoc-members:
-    :show-inheritance:
     :inherited-members:

python/docs/pyspark.rst

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ Subpackages
 
    pyspark.sql
    pyspark.streaming
+   pyspark.ml
    pyspark.mllib
 
 Contents

python/pyspark/ml/__init__.py

Lines changed: 3 additions & 306 deletions
@@ -15,310 +15,7 @@
 # limitations under the License.
 #
 
-from abc import ABCMeta, abstractmethod, abstractproperty
+from pyspark.ml.param import *
+from pyspark.ml.pipeline import *
 
-from pyspark import SparkContext
-from pyspark.sql import SchemaRDD, inherit_doc  # TODO: move inherit_doc to Spark Core
-from pyspark.ml.param import Param, Params
-from pyspark.ml.util import Identifiable
-
-__all__ = ["Pipeline", "Transformer", "Estimator", "param", "feature", "classification"]
-
-
-def _jvm():
-    """
-    Returns the JVM view associated with SparkContext. Must be called
-    after SparkContext is initialized.
-    """
-    jvm = SparkContext._jvm
-    if jvm:
-        return jvm
-    else:
-        raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
-
-
-@inherit_doc
-class PipelineStage(Params):
-    """
-    A stage in a pipeline, either an :py:class:`Estimator` or a
-    :py:class:`Transformer`.
-    """
-
-    __metaclass__ = ABCMeta
-
-    def __init__(self):
-        super(PipelineStage, self).__init__()
-
-
-@inherit_doc
-class Estimator(PipelineStage):
-    """
-    Abstract class for estimators that fit models to data.
-    """
-
-    __metaclass__ = ABCMeta
-
-    def __init__(self):
-        super(Estimator, self).__init__()
-
-    @abstractmethod
-    def fit(self, dataset, params={}):
-        """
-        Fits a model to the input dataset with optional parameters.
-
-        :param dataset: input dataset, which is an instance of
-                        :py:class:`pyspark.sql.SchemaRDD`
-        :param params: an optional param map that overwrites embedded
-                       params
-        :returns: fitted model
-        """
-        raise NotImplementedError()
-
-
-@inherit_doc
-class Transformer(PipelineStage):
-    """
-    Abstract class for transformers that transform one dataset into
-    another.
-    """
-
-    __metaclass__ = ABCMeta
-
-    def __init__(self):
-        super(Transformer, self).__init__()
-
-    @abstractmethod
-    def transform(self, dataset, params={}):
-        """
-        Transforms the input dataset with optional parameters.
-
-        :param dataset: input dataset, which is an instance of
-                        :py:class:`pyspark.sql.SchemaRDD`
-        :param params: an optional param map that overwrites embedded
-                       params
-        :returns: transformed dataset
-        """
-        raise NotImplementedError()
-
-
-@inherit_doc
-class Model(Transformer):
-    """
-    Abstract class for models fitted by :py:class:`Estimator`s.
-    """
-
-    ___metaclass__ = ABCMeta
-
-    def __init__(self):
-        super(Model, self).__init__()
-
-
-@inherit_doc
-class Pipeline(Estimator):
-    """
-    A simple pipeline, which acts as an estimator. A Pipeline consists
-    of a sequence of stages, each of which is either an
-    :py:class:`Estimator` or a :py:class:`Transformer`. When
-    :py:meth:`Pipeline.fit` is called, the stages are executed in
-    order. If a stage is an :py:class:`Estimator`, its
-    :py:meth:`Estimator.fit` method will be called on the input
-    dataset to fit a model. Then the model, which is a transformer,
-    will be used to transform the dataset as the input to the next
-    stage. If a stage is a :py:class:`Transformer`, its
-    :py:meth:`Transformer.transform` method will be called to produce
-    the dataset for the next stage. The fitted model from a
-    :py:class:`Pipeline` is an :py:class:`PipelineModel`, which
-    consists of fitted models and transformers, corresponding to the
-    pipeline stages. If there are no stages, the pipeline acts as an
-    identity transformer.
-    """
-
-    def __init__(self):
-        super(Pipeline, self).__init__()
-        #: Param for pipeline stages.
-        self.stages = Param(self, "stages", "pipeline stages")
-
-    def setStages(self, value):
-        """
-        Set pipeline stages.
-        :param value: a list of transformers or estimators
-        :return: the pipeline instance
-        """
-        self.paramMap[self.stages] = value
-        return self
-
-    def getStages(self):
-        """
-        Get pipeline stages.
-        """
-        if self.stages in self.paramMap:
-            return self.paramMap[self.stages]
-
-    def fit(self, dataset, params={}):
-        paramMap = self._merge_params(params)
-        stages = paramMap[self.stages]
-        for stage in stages:
-            if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
-                raise ValueError(
-                    "Cannot recognize a pipeline stage of type %s." % type(stage).__name__)
-        indexOfLastEstimator = -1
-        for i, stage in enumerate(stages):
-            if isinstance(stage, Estimator):
-                indexOfLastEstimator = i
-        transformers = []
-        for i, stage in enumerate(stages):
-            if i <= indexOfLastEstimator:
-                if isinstance(stage, Transformer):
-                    transformers.append(stage)
-                    dataset = stage.transform(dataset, paramMap)
-                else:  # must be an Estimator
-                    model = stage.fit(dataset, paramMap)
-                    transformers.append(model)
-                    if i < indexOfLastEstimator:
-                        dataset = model.transform(dataset, paramMap)
-            else:
-                transformers.append(stage)
-        return PipelineModel(transformers)
-
-
-@inherit_doc
-class PipelineModel(Model):
-    """
-    Represents a compiled pipeline with transformers and fitted models.
-    """
-
-    def __init__(self, transformers):
-        super(PipelineModel, self).__init__()
-        self.transformers = transformers
-
-    def transform(self, dataset, params={}):
-        paramMap = self._merge_params(params)
-        for t in self.transformers:
-            dataset = t.transform(dataset, paramMap)
-        return dataset
-
-
-@inherit_doc
-class JavaWrapper(Params):
-    """
-    Utility class to help create wrapper classes from Java/Scala
-    implementations of pipeline components.
-    """
-
-    __metaclass__ = ABCMeta
-
-    def __init__(self):
-        super(JavaWrapper, self).__init__()
-
-    @abstractproperty
-    def _java_class(self):
-        """
-        Fully-qualified class name of the wrapped Java component.
-        """
-        raise NotImplementedError
-
-    def _java_obj(self):
-        """
-        Returns or creates a Java object.
-        """
-        java_obj = _jvm()
-        for name in self._java_class.split("."):
-            java_obj = getattr(java_obj, name)
-        return java_obj()
-
-    def _transfer_params_to_java(self, params, java_obj):
-        """
-        Transforms the embedded params and additional params to the
-        input Java object.
-        :param params: additional params (overwriting embedded values)
-        :param java_obj: Java object to receive the params
-        """
-        paramMap = self._merge_params(params)
-        for param in self.params:
-            if param in paramMap:
-                java_obj.set(param.name, paramMap[param])
-
-    def _empty_java_param_map(self):
-        """
-        Returns an empty Java ParamMap reference.
-        """
-        return _jvm().org.apache.spark.ml.param.ParamMap()
-
-    def _create_java_param_map(self, params, java_obj):
-        paramMap = self._empty_java_param_map()
-        for param, value in params.items():
-            if param.parent is self:
-                paramMap.put(java_obj.getParam(param.name), value)
-        return paramMap
-
-
-@inherit_doc
-class JavaEstimator(Estimator, JavaWrapper):
-    """
-    Base class for :py:class:`Estimator`s that wrap Java/Scala
-    implementations.
-    """
-
-    __metaclass__ = ABCMeta
-
-    def __init__(self):
-        super(JavaEstimator, self).__init__()
-
-    @abstractmethod
-    def _create_model(self, java_model):
-        """
-        Creates a model from the input Java model reference.
-        """
-        raise NotImplementedError
-
-    def _fit_java(self, dataset, params={}):
-        """
-        Fits a Java model to the input dataset.
-        :param dataset: input dataset, which is an instance of
-                        :py:class:`pyspark.sql.SchemaRDD`
-        :param params: additional params (overwriting embedded values)
-        :return: fitted Java model
-        """
-        java_obj = self._java_obj()
-        self._transfer_params_to_java(params, java_obj)
-        return java_obj.fit(dataset._jschema_rdd, self._empty_java_param_map())
-
-    def fit(self, dataset, params={}):
-        java_model = self._fit_java(dataset, params)
-        return self._create_model(java_model)
-
-
-@inherit_doc
-class JavaTransformer(Transformer, JavaWrapper):
-    """
-    Base class for :py:class:`Transformer`s that wrap Java/Scala
-    implementations.
-    """
-
-    __metaclass__ = ABCMeta
-
-    def __init__(self):
-        super(JavaTransformer, self).__init__()
-
-    def transform(self, dataset, params={}):
-        java_obj = self._java_obj()
-        self._transfer_params_to_java({}, java_obj)
-        java_param_map = self._create_java_param_map(params, java_obj)
-        return SchemaRDD(java_obj.transform(dataset._jschema_rdd, java_param_map),
-                         dataset.sql_ctx)
-
-
-@inherit_doc
-class JavaModel(JavaTransformer):
-    """
-    Base class for :py:class:`Model`s that wrap Java/Scala
-    implementations.
-    """
-
-    __metaclass__ = ABCMeta
-
-    def __init__(self):
-        super(JavaTransformer, self).__init__()
-
-    def _java_obj(self):
-        return self._java_model
+
+__all__ = ["Param", "Params", "Transformer", "Estimator", "Pipeline"]
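Note: the refactor only moves the abstractions above into pyspark.ml.param and pyspark.ml.pipeline and re-exports them; the public API is driven the same way. A minimal usage sketch (not part of this commit), assuming a running SparkContext/SQLContext, `training`/`test` SchemaRDDs, and the Tokenizer, HashingTF, and LogisticRegression stages documented in pyspark.ml.feature and pyspark.ml.classification above:

from pyspark.ml import Pipeline
from pyspark.ml.feature import Tokenizer, HashingTF
from pyspark.ml.classification import LogisticRegression

# Stages run in order: a Transformer rewrites the dataset, an Estimator is
# fit on it and its model transforms the data for the next stage, as the
# removed Pipeline docstring describes.
tokenizer = Tokenizer().setInputCol("text").setOutputCol("words")
hashingTF = HashingTF().setInputCol(tokenizer.getOutputCol()).setOutputCol("features")
lr = LogisticRegression().setMaxIter(10).setRegParam(0.01)

pipeline = Pipeline().setStages([tokenizer, hashingTF, lr])
model = pipeline.fit(training)        # training: SchemaRDD with "text" and "label" columns
prediction = model.transform(test)    # applies every fitted stage in order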
