Skip to content

Commit 5fa0863

Browse files
MechCodermengxr
authored andcommitted
[SPARK-8679] [PYSPARK] [MLLIB] Default values in Pipeline API should be immutable
It might be dangerous to have a mutable as value for default param. (http://stackoverflow.com/a/11416002/1170730) e.g def func(example, f={}): f[example] = 1 return f func(2) {2: 1} func(3) {2:1, 3:1} mengxr Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes apache#7058 from MechCoder/pipeline_api_playground and squashes the following commits: 40a5eb2 [MechCoder] copy 95f7ff2 [MechCoder] [SPARK-8679] [PySpark] [MLlib] Default values in Pipeline API should be immutable
1 parent 4528166 commit 5fa0863

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

python/pyspark/ml/pipeline.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _fit(self, dataset):
4242
"""
4343
raise NotImplementedError()
4444

45-
def fit(self, dataset, params={}):
45+
def fit(self, dataset, params=None):
4646
"""
4747
Fits a model to the input dataset with optional parameters.
4848
@@ -54,6 +54,8 @@ def fit(self, dataset, params={}):
5454
list of models.
5555
:returns: fitted model(s)
5656
"""
57+
if params is None:
58+
params = dict()
5759
if isinstance(params, (list, tuple)):
5860
return [self.fit(dataset, paramMap) for paramMap in params]
5961
elif isinstance(params, dict):
@@ -86,7 +88,7 @@ def _transform(self, dataset):
8688
"""
8789
raise NotImplementedError()
8890

89-
def transform(self, dataset, params={}):
91+
def transform(self, dataset, params=None):
9092
"""
9193
Transforms the input dataset with optional parameters.
9294
@@ -96,6 +98,8 @@ def transform(self, dataset, params={}):
9698
params.
9799
:returns: transformed dataset
98100
"""
101+
if params is None:
102+
params = dict()
99103
if isinstance(params, dict):
100104
if params:
101105
return self.copy(params,)._transform(dataset)
@@ -135,10 +139,12 @@ class Pipeline(Estimator):
135139
"""
136140

137141
@keyword_only
138-
def __init__(self, stages=[]):
142+
def __init__(self, stages=None):
139143
"""
140144
__init__(self, stages=[])
141145
"""
146+
if stages is None:
147+
stages = []
142148
super(Pipeline, self).__init__()
143149
#: Param for pipeline stages.
144150
self.stages = Param(self, "stages", "pipeline stages")
@@ -162,11 +168,13 @@ def getStages(self):
162168
return self._paramMap[self.stages]
163169

164170
@keyword_only
165-
def setParams(self, stages=[]):
171+
def setParams(self, stages=None):
166172
"""
167173
setParams(self, stages=[])
168174
Sets params for Pipeline.
169175
"""
176+
if stages is None:
177+
stages = []
170178
kwargs = self.setParams._input_kwargs
171179
return self._set(**kwargs)
172180

@@ -195,7 +203,9 @@ def _fit(self, dataset):
195203
transformers.append(stage)
196204
return PipelineModel(transformers)
197205

198-
def copy(self, extra={}):
206+
def copy(self, extra=None):
207+
if extra is None:
208+
extra = dict()
199209
that = Params.copy(self, extra)
200210
stages = [stage.copy(extra) for stage in that.getStages()]
201211
return that.setStages(stages)
@@ -216,6 +226,8 @@ def _transform(self, dataset):
216226
dataset = t.transform(dataset)
217227
return dataset
218228

219-
def copy(self, extra={}):
229+
def copy(self, extra=None):
230+
if extra is None:
231+
extra = dict()
220232
stages = [stage.copy(extra) for stage in self.stages]
221233
return PipelineModel(stages)

python/pyspark/ml/wrapper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def __init__(self, java_model):
166166
self._java_obj = java_model
167167
self.uid = java_model.uid()
168168

169-
def copy(self, extra={}):
169+
def copy(self, extra=None):
170170
"""
171171
Creates a copy of this instance with the same uid and some
172172
extra params. This implementation first calls Params.copy and
@@ -175,6 +175,8 @@ def copy(self, extra={}):
175175
:param extra: Extra parameters to copy to the new instance
176176
:return: Copy of this instance
177177
"""
178+
if extra is None:
179+
extra = dict()
178180
that = super(JavaModel, self).copy(extra)
179181
that._java_obj = self._java_obj.copy(self._empty_java_param_map())
180182
that._transfer_params_to_java()

0 commit comments

Comments
 (0)