@@ -1,57 +1,149 @@
-import inspect
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta, abstractmethod

 from pyspark import SparkContext
-from pyspark.ml.param import Param
+from pyspark.sql import inherit_doc
+from pyspark.ml.param import Param, Params
+from pyspark.ml.util import Identifiable

 __all__ = ["Pipeline", "Transformer", "Estimator"]

-# An implementation of PEP3102 for Python 2.
-_keyword_only_secret = 70861589
+
+def _jvm():
+    return SparkContext._jvm


-def _assert_keyword_only_args():
+@inherit_doc
+class PipelineStage(Params):
     """
-    Checks whether the _keyword_only trick is applied and validates input arguments.
+    A stage in a pipeline, either an :py:class:`Estimator` or a
+    :py:class:`Transformer`.
     """
-    # Get the frame of the function that calls this function.
-    frame = inspect.currentframe().f_back
-    info = inspect.getargvalues(frame)
-    if "_keyword_only" not in info.args:
-        raise ValueError("Function does not have argument _keyword_only.")
-    if info.locals["_keyword_only"] != _keyword_only_secret:
-        raise ValueError("Must use keyword arguments instead of positional ones.")

-def _jvm():
-    return SparkContext._jvm
+    def __init__(self):
+        super(PipelineStage, self).__init__()
+
+
+@inherit_doc
+class Estimator(PipelineStage):
+    """
+    Abstract class for estimators that fit models to data.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(Estimator, self).__init__()
+
+    @abstractmethod
+    def fit(self, dataset, params={}):
+        """
+        Fits a model to the input dataset with optional parameters.
+
+        :param dataset: input dataset, which is an instance of
+                        :py:class:`pyspark.sql.SchemaRDD`
+        :param params: an optional param map that overwrites embedded
+                       params
+        :returns: fitted model
+        """
+        raise NotImplementedError()
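To make the fit contract concrete, a minimal sketch of a custom estimator and its fitted model might look as follows. The CenterEstimator/CenterModel names are hypothetical, a plain numeric RDD stands in for a :py:class:`pyspark.sql.SchemaRDD`, and the import assumes this module lands at `pyspark.ml.pipeline`:

from pyspark.ml.pipeline import Estimator, Transformer


class CenterModel(Transformer):
    """Hypothetical fitted model: subtracts the mean learned by fit()."""

    def __init__(self, mean):
        super(CenterModel, self).__init__()
        self.mean = mean

    def transform(self, dataset, params={}):
        mean = self.mean  # bind locally so the closure does not capture self
        return dataset.map(lambda x: x - mean)


class CenterEstimator(Estimator):
    """Hypothetical estimator: learns the mean of a numeric RDD."""

    def fit(self, dataset, params={}):
        return CenterModel(dataset.mean())  # RDD.mean() computes the average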
+
+
+@inherit_doc
+class Transformer(PipelineStage):
+    """
+    Abstract class for transformers that transform one dataset into
+    another.
+    """
+
+    __metaclass__ = ABCMeta

-class Pipeline(object):
+    @abstractmethod
+    def transform(self, dataset, params={}):
+        """
+        Transforms the input dataset with optional parameters.
+
+        :param dataset: input dataset, which is an instance of
+                        :py:class:`pyspark.sql.SchemaRDD`
+        :param params: an optional param map that overwrites embedded
+                       params
+        :returns: transformed dataset
+        """
+        raise NotImplementedError()
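A stateless transformer only needs transform(). Continuing the same hypothetical setup (a numeric RDD in place of a SchemaRDD), a sketch:

class Squarer(Transformer):
    """Hypothetical stateless transformer: squares every value."""

    def transform(self, dataset, params={}):
        return dataset.map(lambda x: x * x)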
+
+
+@inherit_doc
+class Pipeline(Estimator):
+    """
+    A simple pipeline, which acts as an estimator. A Pipeline consists
+    of a sequence of stages, each of which is either an
+    :py:class:`Estimator` or a :py:class:`Transformer`. When
+    :py:meth:`Pipeline.fit` is called, the stages are executed in
+    order. If a stage is an :py:class:`Estimator`, its
+    :py:meth:`Estimator.fit` method will be called on the input
+    dataset to fit a model. Then the model, which is a transformer,
+    will be used to transform the dataset as the input to the next
+    stage. If a stage is a :py:class:`Transformer`, its
+    :py:meth:`Transformer.transform` method will be called to produce
+    the dataset for the next stage. The fitted model from a
+    :py:class:`Pipeline` is a :py:class:`PipelineModel`, which
+    consists of fitted models and transformers, corresponding to the
+    pipeline stages. If there are no stages, the pipeline acts as an
+    identity transformer.
+    """

     def __init__(self):
+        super(Pipeline, self).__init__()
+        #: Param for pipeline stages.
         self.stages = Param(self, "stages", "pipeline stages")
-        self.paramMap = {}

     def setStages(self, value):
+        """
+        Set pipeline stages.
+        :param value: a list of transformers or estimators
+        :return: the pipeline instance
+        """
         self.paramMap[self.stages] = value
         return self

     def getStages(self):
+        """
+        Get pipeline stages.
+        """
         if self.stages in self.paramMap:
             return self.paramMap[self.stages]

     def fit(self, dataset):
         transformers = []
         for stage in self.getStages():
-            if hasattr(stage, "transform"):
+            if isinstance(stage, Transformer):
                 transformers.append(stage)
                 dataset = stage.transform(dataset)
-            elif hasattr(stage, "fit"):
+            elif isinstance(stage, Estimator):
                 model = stage.fit(dataset)
                 transformers.append(model)
                 dataset = model.transform(dataset)
         return PipelineModel(transformers)
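For orientation, an end-to-end sketch of how fit() chains the stages, reusing the hypothetical Squarer and CenterEstimator from the earlier sketches with a local SparkContext (a real workflow would pass a SchemaRDD):

from pyspark import SparkContext

sc = SparkContext("local", "pipeline-demo")
data = sc.parallelize([1.0, 2.0, 3.0])

pipeline = Pipeline().setStages([Squarer(), CenterEstimator()])
model = pipeline.fit(data)      # Squarer transforms, then CenterEstimator fits
result = model.transform(data)  # the PipelineModel replays both stages in order
print(result.collect())         # squared, then mean-centered values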


-class PipelineModel(object):
+@inherit_doc
+class PipelineModel(Transformer):

     def __init__(self, transformers):
         self.transformers = transformers
@@ -60,15 +152,3 @@ def transform(self, dataset):
         for t in self.transformers:
             dataset = t.transform(dataset)
         return dataset
-
-
-class Estimator(object):
-
-    def fit(self, dataset, params={}):
-        raise NotImplementedError()
-
-
-class Transformer(object):
-
-    def transform(self, dataset, paramMap={}):
-        raise NotImplementedError()