Commit dadd84e

add base classes and docs

1 parent a3015cf

File tree: 7 files changed, +240 −51 lines

python/docs/pyspark.ml.rst
python/pyspark/ml/__init__.py
python/pyspark/ml/classification.py
python/pyspark/ml/feature.py
python/pyspark/ml/param.py
python/pyspark/ml/test.py (deleted)
python/pyspark/ml/util.py

python/docs/pyspark.ml.rst

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+pyspark.ml package
+=====================
+
+Submodules
+----------
+
+pyspark.ml module
+-------------------------
+
+.. automodule:: pyspark.ml
+    :members:
+    :undoc-members:
+    :show-inheritance:

python/pyspark/ml/__init__.py

Lines changed: 112 additions & 32 deletions
@@ -1,57 +1,149 @@
-import inspect
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta, abstractmethod

 from pyspark import SparkContext
-from pyspark.ml.param import Param
+from pyspark.sql import inherit_doc
+from pyspark.ml.param import Param, Params
+from pyspark.ml.util import Identifiable

 __all__ = ["Pipeline", "Transformer", "Estimator"]

-# An implementation of PEP3102 for Python 2.
-_keyword_only_secret = 70861589
+
+def _jvm():
+    return SparkContext._jvm


-def _assert_keyword_only_args():
+@inherit_doc
+class PipelineStage(Params):
     """
-    Checks whether the _keyword_only trick is applied and validates input arguments.
+    A stage in a pipeline, either an :py:class:`Estimator` or a
+    :py:class:`Transformer`.
     """
-    # Get the frame of the function that calls this function.
-    frame = inspect.currentframe().f_back
-    info = inspect.getargvalues(frame)
-    if "_keyword_only" not in info.args:
-        raise ValueError("Function does not have argument _keyword_only.")
-    if info.locals["_keyword_only"] != _keyword_only_secret:
-        raise ValueError("Must use keyword arguments instead of positional ones.")

-def _jvm():
-    return SparkContext._jvm
+    def __init__(self):
+        super(PipelineStage, self).__init__()
+
+
+@inherit_doc
+class Estimator(PipelineStage):
+    """
+    Abstract class for estimators that fit models to data.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(Estimator, self).__init__()
+
+    @abstractmethod
+    def fit(self, dataset, params={}):
+        """
+        Fits a model to the input dataset with optional parameters.
+
+        :param dataset: input dataset, which is an instance of
+                        :py:class:`pyspark.sql.SchemaRDD`
+        :param params: an optional param map that overwrites embedded
+                       params
+        :returns: fitted model
+        """
+        raise NotImplementedError()
+
+
+@inherit_doc
+class Transformer(PipelineStage):
+    """
+    Abstract class for transformers that transform one dataset into
+    another.
+    """
+
+    __metaclass__ = ABCMeta

-class Pipeline(object):
+    @abstractmethod
+    def transform(self, dataset, params={}):
+        """
+        Transforms the input dataset with optional parameters.
+
+        :param dataset: input dataset, which is an instance of
+                        :py:class:`pyspark.sql.SchemaRDD`
+        :param params: an optional param map that overwrites embedded
+                       params
+        :returns: transformed dataset
+        """
+        raise NotImplementedError()
+
+
+@inherit_doc
+class Pipeline(Estimator):
+    """
+    A simple pipeline, which acts as an estimator. A Pipeline consists
+    of a sequence of stages, each of which is either an
+    :py:class:`Estimator` or a :py:class:`Transformer`. When
+    :py:meth:`Pipeline.fit` is called, the stages are executed in
+    order. If a stage is an :py:class:`Estimator`, its
+    :py:meth:`Estimator.fit` method will be called on the input
+    dataset to fit a model. Then the model, which is a transformer,
+    will be used to transform the dataset as the input to the next
+    stage. If a stage is a :py:class:`Transformer`, its
+    :py:meth:`Transformer.transform` method will be called to produce
+    the dataset for the next stage. The fitted model from a
+    :py:class:`Pipeline` is a :py:class:`PipelineModel`, which
+    consists of fitted models and transformers, corresponding to the
+    pipeline stages. If there are no stages, the pipeline acts as an
+    identity transformer.
+    """

     def __init__(self):
+        super(Pipeline, self).__init__()
+        #: Param for pipeline stages.
         self.stages = Param(self, "stages", "pipeline stages")
-        self.paramMap = {}

     def setStages(self, value):
+        """
+        Set pipeline stages.
+        :param value: a list of transformers or estimators
+        :return: the pipeline instance
+        """
         self.paramMap[self.stages] = value
         return self

     def getStages(self):
+        """
+        Get pipeline stages.
+        """
         if self.stages in self.paramMap:
             return self.paramMap[self.stages]

     def fit(self, dataset):
         transformers = []
         for stage in self.getStages():
-            if hasattr(stage, "transform"):
+            if isinstance(stage, Transformer):
                 transformers.append(stage)
                 dataset = stage.transform(dataset)
-            elif hasattr(stage, "fit"):
+            elif isinstance(stage, Estimator):
                 model = stage.fit(dataset)
                 transformers.append(model)
                 dataset = model.transform(dataset)
         return PipelineModel(transformers)


-class PipelineModel(object):
+@inherit_doc
+class PipelineModel(Transformer):

     def __init__(self, transformers):
         self.transformers = transformers
@@ -60,15 +152,3 @@ def transform(self, dataset):
         for t in self.transformers:
             dataset = t.transform(dataset)
         return dataset
-
-
-class Estimator(object):
-
-    def fit(self, dataset, params={}):
-        raise NotImplementedError()
-
-
-class Transformer(object):
-
-    def transform(self, dataset, paramMap={}):
-        raise NotImplementedError()
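
For illustration, a minimal sketch of how these base classes are meant to compose (not part of the commit; `NoOpTransformer` and the `dataset` variable are hypothetical, and `dataset` is assumed to be a pyspark.sql.SchemaRDD):

from pyspark.ml import Pipeline, Transformer

class NoOpTransformer(Transformer):
    # A toy stage: implements the abstract transform() and returns the
    # dataset unchanged. A real stage would produce a new SchemaRDD.
    def transform(self, dataset, params={}):
        return dataset

pipeline = Pipeline().setStages([NoOpTransformer()])
model = pipeline.fit(dataset)            # runs stages in order, returns a PipelineModel
transformed = model.transform(dataset)   # applies each fitted stage's transform()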

python/pyspark/ml/classification.py

Lines changed: 22 additions & 3 deletions
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 from pyspark.sql import SchemaRDD
 from pyspark.ml import Estimator, Transformer, _jvm
 from pyspark.ml.param import Param
@@ -41,7 +58,8 @@ def fit(self, dataset, params=None):
         """
         Fits a dataset with optional parameters.
         """
-        java_model = self._java_obj.fit(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap())
+        java_model = self._java_obj.fit(dataset._jschema_rdd,
+                                        _jvm().org.apache.spark.ml.param.ParamMap())
         return LogisticRegressionModel(java_model)


@@ -54,5 +72,6 @@ def __init__(self, _java_model):
         self._java_model = _java_model

     def transform(self, dataset):
-        return SchemaRDD(self._java_model.transform(dataset._jschema_rdd, _jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx)
-
+        return SchemaRDD(self._java_model.transform(
+            dataset._jschema_rdd,
+            _jvm().org.apache.spark.ml.param.ParamMap()), dataset.sql_ctx)
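
Both hunks follow the same wrapper pattern: the Python object forwards to its JVM peer through Py4J, passing the underlying _jschema_rdd plus an empty Scala ParamMap, and re-wraps the JVM result on the Python side. A condensed sketch of that round trip (hypothetical usage, assuming a SchemaRDD `dataset` and that this file's estimator class is named LogisticRegression with a no-argument constructor):

from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression()               # Python wrapper around a JVM estimator
model = lr.fit(dataset)                 # JVM fit(...) wrapped as a LogisticRegressionModel
predictions = model.transform(dataset)  # JVM transform(...) re-wrapped as a SchemaRDD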

python/pyspark/ml/feature.py

Lines changed: 17 additions & 0 deletions
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 from pyspark.sql import SchemaRDD, ArrayType, StringType
 from pyspark.ml import _jvm
 from pyspark.ml.param import Param

python/pyspark/ml/param.py

Lines changed: 49 additions & 1 deletion
@@ -1,3 +1,28 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta, abstractmethod
+
+from pyspark.ml.util import Identifiable
+
+
+__all__ = ["Param"]
+
+
 class Param(object):
     """
     A param with self-contained documentation and optionally a default value.
@@ -12,5 +37,28 @@ def __init__(self, parent, name, doc, defaultValue=None):
     def __str__(self):
         return self.parent + "_" + self.name

-    def __repr_(self):
+    def __repr__(self):
         return self.parent + "_" + self.name
+
+
+class Params(Identifiable):
+    """
+    Components that take parameters. This also provides an internal
+    param map to store parameter values attached to the instance.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(Params, self).__init__()
+        #: Internal param map.
+        self.paramMap = {}
+
+    @abstractmethod
+    def params(self):
+        """
+        Returns all params. The default implementation uses
+        :py:func:`dir` to get all attributes of type
+        :py:class:`Param`.
+        """
+        return [getattr(self, name) for name in dir(self)
+                if isinstance(getattr(self, name), Param)]
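
A short sketch of the params() discovery mechanism (hypothetical subclass, assuming the classes above; an abstract method's default body can still be reused via super()):

from pyspark.ml.param import Param, Params

class MyStage(Params):
    def __init__(self):
        super(MyStage, self).__init__()
        # Declaring a Param as an attribute is enough for dir()-based discovery.
        self.maxIter = Param(self, "maxIter", "max number of iterations")

    def params(self):
        return super(MyStage, self).params()

stage = MyStage()
assert any(p.name == "maxIter" for p in stage.params())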

python/pyspark/ml/test.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

python/pyspark/ml/util.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import uuid
+
+class Identifiable(object):
+    """
+    Object with a unique ID.
+    """
+
+    def __init__(self):
+        #: A unique id for the object. The default implementation
+        #: concatenates the class name, "-", and 8 random hex chars.
+        self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8]
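
A quick illustration of the uid scheme (assuming the class above; `Widget` is a hypothetical subclass, and the hex suffix varies per run):

from pyspark.ml.util import Identifiable

class Widget(Identifiable):
    pass

w = Widget()
print(w.uid)  # e.g. "Widget-1f3a9c0e": class name, "-", 8 random hex chars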
