[SPARK-33613][PYTHON][TESTS] Replace deprecated APIs in pyspark tests
### What changes were proposed in this pull request?

This replaces deprecated API usage in the PySpark tests with the preferred APIs. These APIs have been deprecated for some time, yet their usage within the tests is still inconsistent.

- https://docs.python.org/3/library/unittest.html#deprecated-aliases
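As a minimal, hypothetical sketch (not taken from the diff itself, and assuming pandas is installed), the substitutions amount to the following; only the API names come from this change, the `ExampleTest` class is illustrative:

```python
import unittest

import pandas as pd
from pandas.testing import assert_frame_equal  # preferred over pandas.util.testing


class ExampleTest(unittest.TestCase):
    def test_preferred_apis(self):
        # deprecated alias: self.assertEquals(...)
        self.assertEqual(1 + 1, 2)

        # deprecated alias: self.assertRaisesRegexp(...)
        with self.assertRaisesRegex(ValueError, "invalid literal"):
            int("not a number")

        # assert_frame_equal now lives in pandas.testing
        assert_frame_equal(pd.DataFrame({"a": [1, 2]}), pd.DataFrame({"a": [1, 2]}))


if __name__ == "__main__":
    unittest.main()
```

The `pandas.testing` import replaces the deprecated `pandas.util.testing` module, as seen in `python/pyspark/sql/tests/test_arrow.py` below.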

### Why are the changes needed?

For consistency, and to prepare for the eventual removal of the deprecated APIs.
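For context, calling one of the deprecated `unittest` aliases emits a `DeprecationWarning`, and the aliases have since been removed entirely in Python 3.12. The check below is a hypothetical sketch (not part of this patch) and assumes a Python version where the aliases still exist:

```python
import unittest
import warnings


class AliasDeprecationDemo(unittest.TestCase):
    def test_alias_emits_deprecation_warning(self):
        # assertEquals is a deprecated alias of assertEqual (gone in Python 3.12+)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            self.assertEquals(1, 1)
        self.assertTrue(
            any(issubclass(w.category, DeprecationWarning) for w in caught))


if __name__ == "__main__":
    unittest.main()
```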

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests

Closes apache#30557 from BryanCutler/replace-deprecated-apis-in-tests.

Authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
BryanCutler authored and HyukjinKwon committed Dec 1, 2020
1 parent 596fbc1 commit aeb3649
Showing 27 changed files with 274 additions and 275 deletions.
2 changes: 1 addition & 1 deletion python/pyspark/ml/tests/test_feature.py
@@ -169,7 +169,7 @@ def test_count_vectorizer_from_vocab(self):

# Test an empty vocabulary
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"):
with self.assertRaisesRegex(Exception, "vocabSize.*invalid.*0"):
CountVectorizerModel.from_vocabulary([], inputCol="words")

# Test model with default settings can transform
6 changes: 3 additions & 3 deletions python/pyspark/ml/tests/test_image.py
@@ -47,19 +47,19 @@ def test_read_images(self):
self.assertEqual(ImageSchema.undefinedImageType, "Undefined")

with QuietTest(self.sc):
self.assertRaisesRegexp(
self.assertRaisesRegex(
TypeError,
"image argument should be pyspark.sql.types.Row; however",
lambda: ImageSchema.toNDArray("a"))

with QuietTest(self.sc):
self.assertRaisesRegexp(
self.assertRaisesRegex(
ValueError,
"image argument should have attributes specified in",
lambda: ImageSchema.toNDArray(Row(a=1)))

with QuietTest(self.sc):
self.assertRaisesRegexp(
self.assertRaisesRegex(
TypeError,
"array argument should be numpy.ndarray; however, it got",
lambda: ImageSchema.toImage("a"))
2 changes: 1 addition & 1 deletion python/pyspark/ml/tests/test_param.py
@@ -308,7 +308,7 @@ def test_logistic_regression_check_thresholds(self):
LogisticRegression
)

self.assertRaisesRegexp(
self.assertRaisesRegex(
ValueError,
"Logistic Regression getThreshold found inconsistent.*$",
LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
2 changes: 1 addition & 1 deletion python/pyspark/ml/tests/test_persistence.py
@@ -442,7 +442,7 @@ def test_default_read_write_default_params(self):
del metadata['defaultParamMap']
metadataStr = json.dumps(metadata, separators=[',', ':'])
loadedMetadata = reader._parseMetaData(metadataStr, )
with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
with self.assertRaisesRegex(AssertionError, "`defaultParamMap` section not found"):
reader.getAndSetParams(lr, loadedMetadata)

# Prior to 2.4.0, metadata doesn't have `defaultParamMap`.
4 changes: 2 additions & 2 deletions python/pyspark/ml/tests/test_tuning.py
@@ -499,15 +499,15 @@ def test_invalid_user_specified_folds(self):
evaluator=evaluator,
numFolds=2,
foldCol="fold")
with self.assertRaisesRegexp(Exception, "Fold number must be in range"):
with self.assertRaisesRegex(Exception, "Fold number must be in range"):
cv.fit(dataset_with_folds)

cv = CrossValidator(estimator=lr,
estimatorParamMaps=grid,
evaluator=evaluator,
numFolds=4,
foldCol="fold")
with self.assertRaisesRegexp(Exception, "The validation data at fold 3 is empty"):
with self.assertRaisesRegex(Exception, "The validation data at fold 3 is empty"):
cv.fit(dataset_with_folds)


6 changes: 3 additions & 3 deletions python/pyspark/ml/tests/test_wrapper.py
@@ -54,7 +54,7 @@ def test_java_object_gets_detached(self):
model.__del__()

def condition():
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
with self.assertRaisesRegex(py4j.protocol.Py4JError, error_no_object):
model._java_obj.toString()
self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
return True
@@ -67,9 +67,9 @@ def condition():
pass

def condition():
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
with self.assertRaisesRegex(py4j.protocol.Py4JError, error_no_object):
model._java_obj.toString()
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
with self.assertRaisesRegex(py4j.protocol.Py4JError, error_no_object):
summary._java_obj.toString()
return True

28 changes: 14 additions & 14 deletions python/pyspark/sql/tests/test_arrow.py
@@ -34,7 +34,7 @@

if have_pandas:
import pandas as pd
from pandas.util.testing import assert_frame_equal
from pandas.testing import assert_frame_equal

if have_pyarrow:
import pyarrow as pa # noqa: F401
@@ -137,7 +137,7 @@ def test_toPandas_fallback_disabled(self):
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
with self.warnings_lock:
with self.assertRaisesRegexp(Exception, 'Unsupported type'):
with self.assertRaisesRegex(Exception, 'Unsupported type'):
df.toPandas()

def test_null_conversion(self):
@@ -214,7 +214,7 @@ def raise_exception():
exception_udf = udf(raise_exception, IntegerType())
df = df.withColumn("error", exception_udf())
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'My error'):
with self.assertRaisesRegex(Exception, 'My error'):
df.toPandas()

def _createDataFrame_toggle(self, pdf, schema=None):
@@ -228,7 +228,7 @@ def _createDataFrame_toggle(self, pdf, schema=None):
def test_createDataFrame_toggle(self):
pdf = self.create_pandas_data_frame()
df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema)
self.assertEquals(df_no_arrow.collect(), df_arrow.collect())
self.assertEqual(df_no_arrow.collect(), df_arrow.collect())

def test_createDataFrame_respect_session_timezone(self):
from datetime import timedelta
@@ -258,7 +258,7 @@ def test_createDataFrame_respect_session_timezone(self):
def test_createDataFrame_with_schema(self):
pdf = self.create_pandas_data_frame()
df = self.spark.createDataFrame(pdf, schema=self.schema)
self.assertEquals(self.schema, df.schema)
self.assertEqual(self.schema, df.schema)
pdf_arrow = df.toPandas()
assert_frame_equal(pdf_arrow, pdf)

@@ -269,31 +269,31 @@ def test_createDataFrame_with_incorrect_schema(self):
wrong_schema = StructType(fields)
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, "[D|d]ecimal.*got.*date"):
with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
self.spark.createDataFrame(pdf, schema=wrong_schema)

def test_createDataFrame_with_names(self):
pdf = self.create_pandas_data_frame()
new_names = list(map(str, range(len(self.schema.fieldNames()))))
# Test that schema as a list of column names gets applied
df = self.spark.createDataFrame(pdf, schema=list(new_names))
self.assertEquals(df.schema.fieldNames(), new_names)
self.assertEqual(df.schema.fieldNames(), new_names)
# Test that schema as tuple of column names gets applied
df = self.spark.createDataFrame(pdf, schema=tuple(new_names))
self.assertEquals(df.schema.fieldNames(), new_names)
self.assertEqual(df.schema.fieldNames(), new_names)

def test_createDataFrame_column_name_encoding(self):
pdf = pd.DataFrame({u'a': [1]})
columns = self.spark.createDataFrame(pdf).columns
self.assertTrue(isinstance(columns[0], str))
self.assertEquals(columns[0], 'a')
self.assertEqual(columns[0], 'a')
columns = self.spark.createDataFrame(pdf, [u'b']).columns
self.assertTrue(isinstance(columns[0], str))
self.assertEquals(columns[0], 'b')
self.assertEqual(columns[0], 'b')

def test_createDataFrame_with_single_data_type(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"):
with self.assertRaisesRegex(ValueError, ".*IntegerType.*not supported.*"):
self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")

def test_createDataFrame_does_not_modify_input(self):
@@ -311,7 +311,7 @@ def test_schema_conversion_roundtrip(self):
from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
arrow_schema = to_arrow_schema(self.schema)
schema_rt = from_arrow_schema(arrow_schema)
self.assertEquals(self.schema, schema_rt)
self.assertEqual(self.schema, schema_rt)

def test_createDataFrame_with_array_type(self):
pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
@@ -420,7 +420,7 @@ def test_createDataFrame_fallback_enabled(self):

def test_createDataFrame_fallback_disabled(self):
with QuietTest(self.sc):
with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
with self.assertRaisesRegex(TypeError, 'Unsupported type'):
self.spark.createDataFrame(
pd.DataFrame({"a": [[datetime.datetime(2015, 11, 1, 0, 30)]]}),
"a: array<timestamp>")
@@ -545,7 +545,7 @@ def tearDownClass(cls):
cls.spark.stop()

def test_exception_by_max_results(self):
with self.assertRaisesRegexp(Exception, "is bigger than"):
with self.assertRaisesRegex(Exception, "is bigger than"):
self.spark.range(0, 10000, 1, 100).toPandas()


56 changes: 28 additions & 28 deletions python/pyspark/sql/tests/test_catalog.py
@@ -25,11 +25,11 @@ class CatalogTests(ReusedSQLTestCase):
def test_current_database(self):
spark = self.spark
with self.database("some_db"):
self.assertEquals(spark.catalog.currentDatabase(), "default")
self.assertEqual(spark.catalog.currentDatabase(), "default")
spark.sql("CREATE DATABASE some_db")
spark.catalog.setCurrentDatabase("some_db")
self.assertEquals(spark.catalog.currentDatabase(), "some_db")
self.assertRaisesRegexp(
self.assertEqual(spark.catalog.currentDatabase(), "some_db")
self.assertRaisesRegex(
AnalysisException,
"does_not_exist",
lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
@@ -38,10 +38,10 @@ def test_list_databases(self):
spark = self.spark
with self.database("some_db"):
databases = [db.name for db in spark.catalog.listDatabases()]
self.assertEquals(databases, ["default"])
self.assertEqual(databases, ["default"])
spark.sql("CREATE DATABASE some_db")
databases = [db.name for db in spark.catalog.listDatabases()]
self.assertEquals(sorted(databases), ["default", "some_db"])
self.assertEqual(sorted(databases), ["default", "some_db"])

def test_list_tables(self):
from pyspark.sql.catalog import Table
@@ -50,8 +50,8 @@ def test_list_tables(self):
spark.sql("CREATE DATABASE some_db")
with self.table("tab1", "some_db.tab2", "tab3_via_catalog"):
with self.tempView("temp_tab"):
self.assertEquals(spark.catalog.listTables(), [])
self.assertEquals(spark.catalog.listTables("some_db"), [])
self.assertEqual(spark.catalog.listTables(), [])
self.assertEqual(spark.catalog.listTables("some_db"), [])
spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
@@ -66,40 +66,40 @@ def test_list_tables(self):
sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
tablesSomeDb = \
sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
self.assertEquals(tables, tablesDefault)
self.assertEquals(len(tables), 3)
self.assertEquals(len(tablesSomeDb), 2)
self.assertEquals(tables[0], Table(
self.assertEqual(tables, tablesDefault)
self.assertEqual(len(tables), 3)
self.assertEqual(len(tablesSomeDb), 2)
self.assertEqual(tables[0], Table(
name="tab1",
database="default",
description=None,
tableType="MANAGED",
isTemporary=False))
self.assertEquals(tables[1], Table(
self.assertEqual(tables[1], Table(
name="tab3_via_catalog",
database="default",
description=description,
tableType="MANAGED",
isTemporary=False))
self.assertEquals(tables[2], Table(
self.assertEqual(tables[2], Table(
name="temp_tab",
database=None,
description=None,
tableType="TEMPORARY",
isTemporary=True))
self.assertEquals(tablesSomeDb[0], Table(
self.assertEqual(tablesSomeDb[0], Table(
name="tab2",
database="some_db",
description=None,
tableType="MANAGED",
isTemporary=False))
self.assertEquals(tablesSomeDb[1], Table(
self.assertEqual(tablesSomeDb[1], Table(
name="temp_tab",
database=None,
description=None,
tableType="TEMPORARY",
isTemporary=True))
self.assertRaisesRegexp(
self.assertRaisesRegex(
AnalysisException,
"does_not_exist",
lambda: spark.catalog.listTables("does_not_exist"))
@@ -119,12 +119,12 @@ def test_list_functions(self):
self.assertTrue("to_timestamp" in functions)
self.assertTrue("to_unix_timestamp" in functions)
self.assertTrue("current_database" in functions)
self.assertEquals(functions["+"], Function(
self.assertEqual(functions["+"], Function(
name="+",
description=None,
className="org.apache.spark.sql.catalyst.expressions.Add",
isTemporary=True))
self.assertEquals(functions, functionsDefault)
self.assertEqual(functions, functionsDefault)

with self.function("func1", "some_db.func2"):
spark.catalog.registerFunction("temp_func", lambda x: str(x))
@@ -141,7 +141,7 @@ def test_list_functions(self):
self.assertTrue("temp_func" in newFunctionsSomeDb)
self.assertTrue("func1" not in newFunctionsSomeDb)
self.assertTrue("func2" in newFunctionsSomeDb)
self.assertRaisesRegexp(
self.assertRaisesRegex(
AnalysisException,
"does_not_exist",
lambda: spark.catalog.listFunctions("does_not_exist"))
@@ -158,16 +158,16 @@ def test_list_columns(self):
columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
columnsDefault = \
sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name)
self.assertEquals(columns, columnsDefault)
self.assertEquals(len(columns), 2)
self.assertEquals(columns[0], Column(
self.assertEqual(columns, columnsDefault)
self.assertEqual(len(columns), 2)
self.assertEqual(columns[0], Column(
name="age",
description=None,
dataType="int",
nullable=True,
isPartition=False,
isBucket=False))
self.assertEquals(columns[1], Column(
self.assertEqual(columns[1], Column(
name="name",
description=None,
dataType="string",
@@ -176,26 +176,26 @@ def test_list_columns(self):
isBucket=False))
columns2 = \
sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name)
self.assertEquals(len(columns2), 2)
self.assertEquals(columns2[0], Column(
self.assertEqual(len(columns2), 2)
self.assertEqual(columns2[0], Column(
name="nickname",
description=None,
dataType="string",
nullable=True,
isPartition=False,
isBucket=False))
self.assertEquals(columns2[1], Column(
self.assertEqual(columns2[1], Column(
name="tolerance",
description=None,
dataType="float",
nullable=True,
isPartition=False,
isBucket=False))
self.assertRaisesRegexp(
self.assertRaisesRegex(
AnalysisException,
"tab2",
lambda: spark.catalog.listColumns("tab2"))
self.assertRaisesRegexp(
self.assertRaisesRegex(
AnalysisException,
"does_not_exist",
lambda: spark.catalog.listColumns("does_not_exist"))
10 changes: 5 additions & 5 deletions python/pyspark/sql/tests/test_column.py
@@ -47,7 +47,7 @@ def test_validate_column_types(self):
self.assertTrue("Column" in _to_java_column(u"a").getClass().toString())
self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString())

self.assertRaisesRegexp(
self.assertRaisesRegex(
TypeError,
"Invalid argument, not a string or column",
lambda: _to_java_column(1))
@@ -58,7 +58,7 @@ class A():
self.assertRaises(TypeError, lambda: _to_java_column(A()))
self.assertRaises(TypeError, lambda: _to_java_column([]))

self.assertRaisesRegexp(
self.assertRaisesRegex(
TypeError,
"Invalid argument, not a string or column",
lambda: udf(lambda x: x)(None))
@@ -79,9 +79,9 @@ def test_column_operators(self):
cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs)
self.assertTrue(all(isinstance(c, Column) for c in css))
self.assertTrue(isinstance(ci.cast(LongType()), Column))
self.assertRaisesRegexp(ValueError,
"Cannot apply 'in' operator against a column",
lambda: 1 in cs)
self.assertRaisesRegex(ValueError,
"Cannot apply 'in' operator against a column",
lambda: 1 in cs)

def test_column_accessor(self):
from pyspark.sql.functions import col