[SPARK-23319][TESTS][BRANCH-2.3] Explicitly specify Pandas and PyArrow versions in PySpark tests (to skip or test) #20534

Closed
4 changes: 4 additions & 0 deletions pom.xml
@@ -185,6 +185,10 @@
     <paranamer.version>2.8</paranamer.version>
     <maven-antrun.version>1.8</maven-antrun.version>
     <commons-crypto.version>1.0.0</commons-crypto.version>
+    <!--
+    If you are changing Arrow version specification, please check ./python/pyspark/sql/utils.py,
+    ./python/run-tests.py and ./python/setup.py too.
+    -->
     <arrow.version>0.8.0</arrow.version>

     <test.java.home>${java.home}</test.java.home>
3 changes: 3 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -1913,6 +1913,9 @@ def toPandas(self):
         0    2  Alice
         1    5    Bob
         """
+        from pyspark.sql.utils import require_minimum_pandas_version
+        require_minimum_pandas_version()
+
         import pandas as pd

         if self.sql_ctx.getConf("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
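For context, the guard above makes `toPandas()` fail fast with an explanatory message instead of a bare import failure deep in the conversion path. A minimal sketch of the user-visible behavior, assuming only a local PySpark installation (not part of this diff):

```python
# Illustrative sketch; assumes PySpark is installed and runnable locally.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.range(3)
try:
    print(df.toPandas())  # the version check runs before any Pandas conversion work
except ImportError as e:
    # With Pandas missing or older than 0.19.2, the error now names the requirement.
    print(e)
```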
3 changes: 3 additions & 0 deletions python/pyspark/sql/session.py
@@ -640,6 +640,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
         except Exception:
             has_pandas = False
         if has_pandas and isinstance(data, pandas.DataFrame):
+            from pyspark.sql.utils import require_minimum_pandas_version
+            require_minimum_pandas_version()
+
             if self.conf.get("spark.sql.execution.pandas.respectSessionTimeZone").lower() \
                     == "true":
                 timezone = self.conf.get("spark.sql.session.timeZone")
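The same check guards `createDataFrame` when it receives a Pandas DataFrame. A sketch of the happy path, assuming a compatible Pandas install (illustrative, not part of this diff):

```python
# Illustrative sketch; assumes pandas >= 0.19.2 and a local PySpark installation.
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
pdf = pd.DataFrame({"age": [2, 5], "name": ["Alice", "Bob"]})
df = spark.createDataFrame(pdf)  # the version check runs before conversion starts
df.show()
```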
83 changes: 45 additions & 38 deletions python/pyspark/sql/tests.py
@@ -48,19 +48,26 @@
 else:
     import unittest

-_have_pandas = False
-_have_old_pandas = False
-try:
-    import pandas
-    try:
-        from pyspark.sql.utils import require_minimum_pandas_version
-        require_minimum_pandas_version()
-        _have_pandas = True
-    except:
-        _have_old_pandas = True
-except:
-    # No Pandas, but that's okay, we'll skip those tests
-    pass
+_pandas_requirement_message = None
+try:
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+except ImportError as e:
+    from pyspark.util import _exception_message
+    # If Pandas version requirement is not satisfied, skip related tests.
+    _pandas_requirement_message = _exception_message(e)
+
+_pyarrow_requirement_message = None
+try:
+    from pyspark.sql.utils import require_minimum_pyarrow_version
+    require_minimum_pyarrow_version()
+except ImportError as e:
+    from pyspark.util import _exception_message
+    # If Arrow version requirement is not satisfied, skip related tests.
+    _pyarrow_requirement_message = _exception_message(e)
+
+_have_pandas = _pandas_requirement_message is None
+_have_pyarrow = _pyarrow_requirement_message is None

 from pyspark import SparkContext
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
@@ -75,15 +82,6 @@
 from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException


-_have_arrow = False
-try:
-    import pyarrow
-    _have_arrow = True
-except:
-    # No Arrow, but that's okay, we'll skip those tests
-    pass
-
-
 class UTCOffsetTimezone(datetime.tzinfo):
     """
     Specifies timezone in UTC offset
@@ -2788,7 +2786,6 @@ def count_bucketed_cols(names, table="pyspark_bucket"):

     def _to_pandas(self):
         from datetime import datetime, date
-        import numpy as np
         schema = StructType().add("a", IntegerType()).add("b", StringType())\
             .add("c", BooleanType()).add("d", FloatType())\
             .add("dt", DateType()).add("ts", TimestampType())
@@ -2801,7 +2798,7 @@ def _to_pandas(self):
         df = self.spark.createDataFrame(data, schema)
         return df.toPandas()

-    @unittest.skipIf(not _have_pandas, "Pandas not installed")
+    @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
     def test_to_pandas(self):
         import numpy as np
         pdf = self._to_pandas()
@@ -2813,13 +2810,13 @@ def test_to_pandas(self):
         self.assertEquals(types[4], np.object)  # datetime.date
         self.assertEquals(types[5], 'datetime64[ns]')

-    @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
-    def test_to_pandas_old(self):
+    @unittest.skipIf(_have_pandas, "Required Pandas was found.")
+    def test_to_pandas_required_pandas_not_found(self):
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
                 self._to_pandas()

-    @unittest.skipIf(not _have_pandas, "Pandas not installed")
+    @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
     def test_to_pandas_avoid_astype(self):
         import numpy as np
         schema = StructType().add("a", IntegerType()).add("b", StringType())\
@@ -2837,7 +2834,7 @@ def test_create_dataframe_from_array_of_long(self):
         df = self.spark.createDataFrame(data)
         self.assertEqual(df.first(), Row(longarray=[-9223372036854775808, 0, 9223372036854775807]))

-    @unittest.skipIf(not _have_pandas, "Pandas not installed")
+    @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
     def test_create_dataframe_from_pandas_with_timestamp(self):
         import pandas as pd
         from datetime import datetime
@@ -2852,14 +2849,16 @@ def test_create_dataframe_from_pandas_with_timestamp(self):
         self.assertTrue(isinstance(df.schema['ts'].dataType, TimestampType))
         self.assertTrue(isinstance(df.schema['d'].dataType, DateType))

-    @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
-    def test_create_dataframe_from_old_pandas(self):
-        import pandas as pd
-        from datetime import datetime
-        pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
-                            "d": [pd.Timestamp.now().date()]})
+    @unittest.skipIf(_have_pandas, "Required Pandas was found.")
+    def test_create_dataframe_required_pandas_not_found(self):
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'):
+            with self.assertRaisesRegexp(
+                    ImportError,
+                    "(Pandas >= .* must be installed|No module named '?pandas'?)"):
+                import pandas as pd
+                from datetime import datetime
+                pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)],
+                                    "d": [pd.Timestamp.now().date()]})
                 self.spark.createDataFrame(pdf)


@@ -3351,7 +3350,9 @@ def __init__(self, **kwargs):
                 _make_type_verifier(data_type, nullable=False)(obj)


-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
 class ArrowTests(ReusedSQLTestCase):

     @classmethod
@@ -3615,7 +3616,9 @@ def test_createDataFrame_with_int_col_names(self):
         self.assertEqual(pdf_col_names, df_arrow.columns)


-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
 class PandasUDFTests(ReusedSQLTestCase):
     def test_pandas_udf_basic(self):
         from pyspark.rdd import PythonEvalType
@@ -3739,7 +3742,9 @@ def foo(k, v):
             return k


-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
 class ScalarPandasUDFTests(ReusedSQLTestCase):

     @classmethod
@@ -4252,7 +4257,9 @@ def test_register_vectorized_udf_basic(self):
         self.assertEquals(expected.collect(), res2.collect())


-@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
 class GroupedMapPandasUDFTests(ReusedSQLTestCase):

     def assertFramesEqual(self, expected, result):
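Taken together, the test-side pattern is: run each version check once at import time, remember the resulting error message, and reuse that message as the `unittest` skip reason. A condensed, self-contained sketch of the pattern, where `str(e)` stands in for Spark's Python 2 compatible `_exception_message` helper:

```python
# Condensed sketch of the skip pattern used in tests.py (illustrative).
import unittest

_pandas_requirement_message = None
try:
    from pyspark.sql.utils import require_minimum_pandas_version
    require_minimum_pandas_version()
except ImportError as e:
    _pandas_requirement_message = str(e)  # the real code uses _exception_message(e)

_have_pandas = _pandas_requirement_message is None


@unittest.skipIf(not _have_pandas, _pandas_requirement_message)
class SomePandasDependentTests(unittest.TestCase):
    def test_pandas_is_usable(self):
        import pandas  # safe: the whole class is skipped when the requirement is unmet
        self.assertTrue(hasattr(pandas, "DataFrame"))
```

The payoff is that the test runner's skip output names the exact version problem rather than a generic "Pandas not installed".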
30 changes: 22 additions & 8 deletions python/pyspark/sql/utils.py
@@ -115,18 +115,32 @@ def toJArray(gateway, jtype, arr):
 def require_minimum_pandas_version():
     """ Raise ImportError if minimum version of Pandas is not installed
     """
+    # TODO(HyukjinKwon): Relocate and deduplicate the version specification.
+    minimum_pandas_version = "0.19.2"
+
     from distutils.version import LooseVersion
-    import pandas
-    if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'):
-        raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process; "
-                          "however, your version was %s." % pandas.__version__)
+    try:
+        import pandas
+    except ImportError:
+        raise ImportError("Pandas >= %s must be installed; however, "
+                          "it was not found." % minimum_pandas_version)
+    if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
+        raise ImportError("Pandas >= %s must be installed; however, "
+                          "your version was %s." % (minimum_pandas_version, pandas.__version__))


 def require_minimum_pyarrow_version():
     """ Raise ImportError if minimum version of pyarrow is not installed
     """
+    # TODO(HyukjinKwon): Relocate and deduplicate the version specification.
+    minimum_pyarrow_version = "0.8.0"
+
     from distutils.version import LooseVersion
-    import pyarrow
-    if LooseVersion(pyarrow.__version__) < LooseVersion('0.8.0'):
-        raise ImportError("pyarrow >= 0.8.0 must be installed on calling Python process; "
-                          "however, your version was %s." % pyarrow.__version__)
+    try:
+        import pyarrow
+    except ImportError:
+        raise ImportError("PyArrow >= %s must be installed; however, "
+                          "it was not found." % minimum_pyarrow_version)
+    if LooseVersion(pyarrow.__version__) < LooseVersion(minimum_pyarrow_version):
+        raise ImportError("PyArrow >= %s must be installed; however, "
+                          "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))
10 changes: 9 additions & 1 deletion python/setup.py
@@ -100,6 +100,11 @@ def _supports_symlinks():
           file=sys.stderr)
     exit(-1)

+# If you are changing the versions here, please also change ./python/pyspark/sql/utils.py and
+# ./python/run-tests.py. In case of Arrow, you should also check ./pom.xml.
+_minimum_pandas_version = "0.19.2"
+_minimum_pyarrow_version = "0.8.0"
+
 try:
     # We copy the shell script to be under pyspark/python/pyspark so that the launcher scripts
     # find it where expected. The rest of the files aren't copied because they are accessed
@@ -201,7 +206,10 @@ def _supports_symlinks():
     extras_require={
         'ml': ['numpy>=1.7'],
         'mllib': ['numpy>=1.7'],
-        'sql': ['pandas>=0.19.2', 'pyarrow>=0.8.0']
+        'sql': [
+            'pandas>=%s' % _minimum_pandas_version,
+            'pyarrow>=%s' % _minimum_pyarrow_version,
+        ]
     },
     classifiers=[
         'Development Status :: 5 - Production/Stable',
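With the extras wired to shared constants, `pip install pyspark[sql]` resolves to `pandas>=0.19.2` and `pyarrow>=0.8.0`. One hedged way to confirm what an installed distribution declares, using setuptools' `pkg_resources` (an assumption of this sketch, not part of the diff):

```python
# Print the requirements pulled in by the 'sql' extra of an installed pyspark.
import pkg_resources

dist = pkg_resources.get_distribution("pyspark")
for req in dist.requires(extras=("sql",)):
    print(req)  # expect pandas>=0.19.2 and pyarrow>=0.8.0 among the entries
```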