[SPARK-32320][PYSPARK] Remove mutable default arguments
This is bad practice, and might lead to unexpected behaviour:
https://florimond.dev/blog/articles/2018/08/python-mutable-defaults-are-the-source-of-all-evil/

Add flake8-bugbear's B006 check to the CI to catch this.
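
The failure mode, and the fix applied throughout this commit, look roughly like this (a minimal sketch, not code from the repository):

    def append_bad(value, items=[]):
        # The default list is created once, at function definition time,
        # and shared by every call that omits the argument.
        items.append(value)
        return items

    append_bad(1)  # [1]
    append_bad(2)  # [1, 2]  <- the same list object was reused

    def append_good(value, items=None):
        # Default to None and create a fresh object inside the body.
        items = items if items is not None else []
        items.append(value)
        return items

    append_good(1)  # [1]
    append_good(2)  # [2]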
Fokko committed Oct 20, 2020
1 parent eb9966b commit 6b0f39f
Showing 7 changed files with 30 additions and 25 deletions.
9 changes: 5 additions & 4 deletions dev/sparktestsupport/modules.py
@@ -31,9 +31,10 @@ class Module(object):
files have changed.
"""

def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={},
sbt_test_goals=(), python_test_goals=(), excluded_python_implementations=(),
test_tags=(), should_run_r_tests=False, should_run_build_tests=False):
def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(),
environ=None, sbt_test_goals=(), python_test_goals=(),
excluded_python_implementations=(), test_tags=(), should_run_r_tests=False,
should_run_build_tests=False):
"""
Define a new module.
@@ -62,7 +63,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=
self.source_file_prefixes = source_file_regexes
self.sbt_test_goals = sbt_test_goals
self.build_profile_flags = build_profile_flags
self.environ = environ
self.environ = environ or {}
self.python_test_goals = python_test_goals
self.excluded_python_implementations = excluded_python_implementations
self.test_tags = test_tags
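
One property of the "environ or {}" idiom used above: it substitutes a fresh dict for any falsy argument, not only for None, so an explicitly passed empty dict is also replaced. That is harmless here, because an empty dict is exactly the default, but where an explicit empty value must be preserved the stricter sentinel check is the usual alternative (a sketch, not part of this change):

    def example(environ=None):
        # Only substitute when the argument was actually omitted;
        # an explicitly passed empty dict is kept as-is.
        if environ is None:
            environ = {}
        return environ
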
2 changes: 1 addition & 1 deletion dev/tox.ini
@@ -19,6 +19,6 @@ max-line-length=100
exclude=python/pyspark/cloudpickle/*.py,shared.py,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*

[flake8]
select = E901,E999,F821,F822,F823,F401,F405
select = E901,E999,F821,F822,F823,F401,F405,B006
exclude = python/pyspark/cloudpickle/*.py,shared.py*,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*,python/out,python/pyspark/sql/pandas/functions.pyi,python/pyspark/sql/column.pyi,python/pyspark/worker.pyi,python/pyspark/java_gateway.pyi
max-line-length = 100
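
B006 is flake8-bugbear's check against mutable argument defaults, so this select entry only has an effect when the flake8-bugbear plugin is installed in the lint environment. A minimal illustration of what it flags (hypothetical file name):

    # example.py
    def merge(options={}):
        # flake8-bugbear reports:
        # B006 Do not use mutable data structures for argument defaults
        return dict(options)

    # Run (assuming flake8 and flake8-bugbear are installed):
    #   flake8 --select=B006 example.py
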
8 changes: 5 additions & 3 deletions python/pyspark/ml/regression.py
@@ -47,6 +47,8 @@
'RandomForestRegressor', 'RandomForestRegressionModel',
'FMRegressor', 'FMRegressionModel']

DEFAULT_QUANTILE_PROBABILITIES = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]


class Regressor(Predictor, _PredictorParams, metaclass=ABCMeta):
"""
@@ -1654,7 +1656,7 @@ class _AFTSurvivalRegressionParams(_PredictorParams, HasMaxIter, HasTol, HasFitI
def __init__(self, *args):
super(_AFTSurvivalRegressionParams, self).__init__(*args)
self._setDefault(censorCol="censor",
quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
quantileProbabilities=DEFAULT_QUANTILE_PROBABILITIES,
maxIter=100, tol=1E-6, blockSize=1)

@since("1.6.0")
@@ -1740,7 +1742,7 @@ class AFTSurvivalRegression(_JavaRegressor, _AFTSurvivalRegressionParams,
@keyword_only
def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
quantileProbabilities=DEFAULT_QUANTILE_PROBABILITIES,
quantilesCol=None, aggregationDepth=2, blockSize=1):
"""
__init__(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
@@ -1758,7 +1760,7 @@ def __init__(self, *, featuresCol="features", labelCol="label", predictionCol="p
@since("1.6.0")
def setParams(self, *, featuresCol="features", labelCol="label", predictionCol="prediction",
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
quantileProbabilities=DEFAULT_QUANTILE_PROBABILITIES,
quantilesCol=None, aggregationDepth=2, blockSize=1):
"""
setParams(self, \\*, featuresCol="features", labelCol="label", predictionCol="prediction", \
8 changes: 4 additions & 4 deletions python/pyspark/ml/tuning.py
@@ -508,13 +508,13 @@ class CrossValidatorModel(Model, _CrossValidatorParams, MLReadable, MLWritable):
.. versionadded:: 1.4.0
"""

def __init__(self, bestModel, avgMetrics=[], subModels=None):
def __init__(self, bestModel, avgMetrics=None, subModels=None):
super(CrossValidatorModel, self).__init__()
#: best model from cross validation
self.bestModel = bestModel
#: Average cross-validation metrics for each paramMap in
#: CrossValidator.estimatorParamMaps, in the corresponding order.
self.avgMetrics = avgMetrics
self.avgMetrics = avgMetrics or []
#: sub model list from cross validation
self.subModels = subModels

@@ -868,12 +868,12 @@ class TrainValidationSplitModel(Model, _TrainValidationSplitParams, MLReadable,
.. versionadded:: 2.0.0
"""

def __init__(self, bestModel, validationMetrics=[], subModels=None):
def __init__(self, bestModel, validationMetrics=None, subModels=None):
super(TrainValidationSplitModel, self).__init__()
#: best model from train validation split
self.bestModel = bestModel
#: evaluated validation metrics
self.validationMetrics = validationMetrics
self.validationMetrics = validationMetrics or []
#: sub models from train validation split
self.subModels = subModels

6 changes: 3 additions & 3 deletions python/pyspark/resource/profile.py
@@ -32,13 +32,13 @@ class ResourceProfile(object):
.. versionadded:: 3.1.0
"""

def __init__(self, _java_resource_profile=None, _exec_req={}, _task_req={}):
def __init__(self, _java_resource_profile=None, _exec_req=None, _task_req=None):
if _java_resource_profile is not None:
self._java_resource_profile = _java_resource_profile
else:
self._java_resource_profile = None
self._executor_resource_requests = _exec_req
self._task_resource_requests = _task_req
self._executor_resource_requests = _exec_req or {}
self._task_resource_requests = _task_req or {}

@property
def id(self):
4 changes: 2 additions & 2 deletions python/pyspark/sql/avro/functions.py
@@ -26,7 +26,7 @@


@since(3.0)
def from_avro(data, jsonFormatSchema, options={}):
def from_avro(data, jsonFormatSchema, options=None):
"""
Converts a binary column of Avro format into its corresponding catalyst value.
The specified schema must match the read data, otherwise the behavior is undefined:
@@ -59,7 +59,7 @@ def from_avro(data, jsonFormatSchema, options={}):
sc = SparkContext._active_spark_context
try:
jc = sc._jvm.org.apache.spark.sql.avro.functions.from_avro(
_to_java_column(data), jsonFormatSchema, options)
_to_java_column(data), jsonFormatSchema, options or {})
except TypeError as e:
if str(e) == "'JavaPackage' object is not callable":
_print_missing_jar("Avro", "avro", "avro", sc.version)
18 changes: 10 additions & 8 deletions python/pyspark/sql/functions.py
@@ -113,8 +113,10 @@ def _():
return _


def _options_to_str(options):
return {key: to_str(value) for (key, value) in options.items()}
def _options_to_str(options=None):
if options:
return {key: to_str(value) for (key, value) in options.items()}
return {}

_lit_doc = """
Creates a :class:`Column` of literal value.
@@ -2476,7 +2478,7 @@ def json_tuple(col, *fields):


@since(2.1)
def from_json(col, schema, options={}):
def from_json(col, schema, options=None):
"""
Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType`
as keys type, :class:`StructType` or :class:`ArrayType` with
@@ -2524,7 +2526,7 @@ def from_json(col, schema, options={}):


@since(2.1)
def to_json(col, options={}):
def to_json(col, options=None):
"""
Converts a column containing a :class:`StructType`, :class:`ArrayType` or a :class:`MapType`
into a JSON string. Throws an exception, in the case of an unsupported type.
@@ -2564,7 +2566,7 @@ def to_json(col, options={}):


@since(2.4)
def schema_of_json(json, options={}):
def schema_of_json(json, options=None):
"""
Parses a JSON string and infers its schema in DDL format.
@@ -2594,7 +2596,7 @@ def schema_of_json(json, options={}):


@since(3.0)
def schema_of_csv(csv, options={}):
def schema_of_csv(csv, options=None):
"""
Parses a CSV string and infers its schema in DDL format.
@@ -2620,7 +2622,7 @@ def schema_of_csv(csv, options={}):


@since(3.0)
def to_csv(col, options={}):
def to_csv(col, options=None):
"""
Converts a column containing a :class:`StructType` into a CSV string.
Throws an exception, in the case of an unsupported type.
@@ -2931,7 +2933,7 @@ def sequence(start, stop, step=None):


@since(3.0)
def from_csv(col, schema, options={}):
def from_csv(col, schema, options=None):
"""
Parses a column containing a CSV string to a row with the specified schema.
Returns `null`, in the case of an unparseable string.
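
With these changes the options mapping of the JSON/CSV helpers is optional; omitting it and passing an empty dict behave identically, since _options_to_str normalizes None to an empty dict. A small usage sketch (assuming an active SparkSession bound to spark):

    from pyspark.sql import functions as F

    df = spark.createDataFrame([('{"a": 1}',)], ["value"])

    # Both calls are equivalent; no shared mutable default is involved.
    df.select(F.from_json("value", "a INT")).show()
    df.select(F.from_json("value", "a INT", {"mode": "PERMISSIVE"})).show()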