Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions python/pyspark/ml/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _fit(self, dataset):
"""
raise NotImplementedError()

def fit(self, dataset, params={}):
def fit(self, dataset, params=None):
"""
Fits a model to the input dataset with optional parameters.

Expand All @@ -54,6 +54,8 @@ def fit(self, dataset, params={}):
list of models.
:returns: fitted model(s)
"""
if params is None:
params = dict()
if isinstance(params, (list, tuple)):
return [self.fit(dataset, paramMap) for paramMap in params]
elif isinstance(params, dict):
Expand Down Expand Up @@ -86,7 +88,7 @@ def _transform(self, dataset):
"""
raise NotImplementedError()

def transform(self, dataset, params={}):
def transform(self, dataset, params=None):
"""
Transforms the input dataset with optional parameters.

Expand All @@ -96,6 +98,8 @@ def transform(self, dataset, params={}):
params.
:returns: transformed dataset
"""
if params is None:
params = dict()
if isinstance(params, dict):
if params:
return self.copy(params,)._transform(dataset)
Expand Down Expand Up @@ -135,10 +139,12 @@ class Pipeline(Estimator):
"""

@keyword_only
def __init__(self, stages=[]):
def __init__(self, stages=None):
"""
__init__(self, stages=[])
"""
if stages is None:
stages = []
super(Pipeline, self).__init__()
#: Param for pipeline stages.
self.stages = Param(self, "stages", "pipeline stages")
Expand All @@ -162,11 +168,13 @@ def getStages(self):
return self._paramMap[self.stages]

@keyword_only
def setParams(self, stages=None):
    """
    setParams(self, stages=None)
    Sets params for Pipeline.

    :param stages: list of pipeline stages; defaults to an empty list.
    """
    if stages is None:
        stages = []
    kwargs = self.setParams._input_kwargs
    # NOTE(review): `_input_kwargs` is captured by @keyword_only before the
    # None-normalization above, so an explicitly passed stages=None would
    # otherwise be forwarded to _set as None. Patch it in the kwargs too.
    if 'stages' in kwargs and kwargs['stages'] is None:
        kwargs['stages'] = stages
    return self._set(**kwargs)

Expand Down Expand Up @@ -195,7 +203,9 @@ def _fit(self, dataset):
transformers.append(stage)
return PipelineModel(transformers)

def copy(self, extra=None):
    """
    Creates a copy of this instance.

    :param extra: extra parameters to copy to the new instance (a dict);
                  defaults to an empty dict.
    :returns: new Pipeline instance with copied params and stages
    """
    # Default is None rather than a mutable {} so the same dict is not
    # shared (and potentially mutated) across calls.
    if extra is None:
        extra = dict()
    that = Params.copy(self, extra)
    stages = [stage.copy(extra) for stage in that.getStages()]
    return that.setStages(stages)
Expand All @@ -216,6 +226,8 @@ def _transform(self, dataset):
dataset = t.transform(dataset)
return dataset

def copy(self, extra=None):
    """
    Creates a copy of this instance.

    :param extra: extra parameters to copy to each stage (a dict);
                  defaults to an empty dict.
    :returns: new PipelineModel wrapping copies of the stages
    """
    # Default is None rather than a mutable {} so the same dict is not
    # shared (and potentially mutated) across calls.
    if extra is None:
        extra = dict()
    stages = [stage.copy(extra) for stage in self.stages]
    return PipelineModel(stages)
4 changes: 3 additions & 1 deletion python/pyspark/ml/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __init__(self, java_model):
self._java_obj = java_model
self.uid = java_model.uid()

def copy(self, extra={}):
def copy(self, extra=None):
"""
Creates a copy of this instance with the same uid and some
extra params. This implementation first calls Params.copy and
Expand All @@ -175,6 +175,8 @@ def copy(self, extra={}):
:param extra: Extra parameters to copy to the new instance
:return: Copy of this instance
"""
if extra is None:
extra = dict()
that = super(JavaModel, self).copy(extra)
that._java_obj = self._java_obj.copy(self._empty_java_param_map())
that._transfer_params_to_java()
Expand Down