Fix test failure and few minor clean up for tests - 20404 #11

Merged: 1 commit merged on Jan 29, 2018

python/pyspark/context.py (2 additions, 0 deletions)

@@ -400,6 +400,8 @@ def stop(self):
         """
         if getattr(self, "_jsc", None):
             try:
+                # We should clean the default session up. See SPARK-23228.
+                self._jvm.SparkSession.clearDefaultSession()
                 self._jsc.stop()
             except Py4JError:
                 # Case: SPARK-18523
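
The hunk above moves the JVM-side cleanup into SparkContext.stop(). A minimal sketch of the behavior this guarantees, assuming a local PySpark installation; _jvm and _jsc are PySpark internals, used here only to observe the JVM state, exactly as the tests in this PR do:

# Sketch: stopping the SparkContext now clears the JVM default session.
from pyspark import SparkContext
from pyspark.sql import SparkSession

sc = SparkContext("local[2]", "clear-default-session-demo")
spark = SparkSession(sc)
assert spark._jvm.SparkSession.getDefaultSession().isDefined()

sc.stop()  # clearDefaultSession() now runs here, just before _jsc.stop()
assert spark._jvm.SparkSession.getDefaultSession().isEmpty()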
python/pyspark/sql/session.py (0 additions, 1 deletion)

@@ -764,7 +764,6 @@ def stop(self):
         """Stop the underlying :class:`SparkContext`.
         """
         self._sc.stop()
-        self._jvm.SparkSession.clearDefaultSession()
         SparkSession._instantiatedSession = None

     @since(2.0)
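
With this hunk, SparkSession.stop() no longer clears the default session itself; it relies on self._sc.stop() to do it. One reading of the gap this closes (a sketch of the shutdown path, not the exact reported reproduction): stopping the SparkContext directly used to bypass SparkSession.stop(), so the JVM default session stayed set over a dead context.

# Sketch: this shutdown path never calls SparkSession.stop(). With the
# cleanup moved into SparkContext.stop(), it is covered as well.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
spark.sparkContext.stop()  # bypasses SparkSession.stop()

# After this change, the JVM default session is cleared on this path too.
assert spark._jvm.SparkSession.getDefaultSession().isEmpty()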
python/pyspark/sql/tests.py (27 additions, 43 deletions)

@@ -69,7 +69,7 @@
 from pyspark.sql.types import _array_signed_int_typecode_ctype_mappings, _array_type_mappings
 from pyspark.sql.types import _array_unsigned_int_typecode_ctype_mappings
 from pyspark.sql.types import _merge_type
-from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
+from pyspark.tests import QuietTest, ReusedPySparkTestCase, PySparkTestCase, SparkSubmitTests
 from pyspark.sql.functions import UserDefinedFunction, sha2, lit
 from pyspark.sql.window import Window
 from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException

@@ -204,48 +204,6 @@ def assertPandasEqual(self, expected, result):
         self.assertTrue(expected.equals(result), msg=msg)


-class PySparkSessionTests(unittest.TestCase):
-
-    def test_set_jvm_default_session(self):
-        spark = None
-        sc = None
-        try:
-            sc = SparkContext('local[4]', "test_spark_session")
-            spark = SparkSession(sc)
-            self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
-        finally:
-            if spark is not None:
-                spark.stop()
-                self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty())
-                spark = None
-                sc = None
-
-            if sc is not None:
-                sc.stop()
-                sc = None
-
-    def test_jvm_default_session_already_set(self):
-        spark = None
-        sc = None
-        try:
-            sc = SparkContext('local[4]', "test_spark_session")
-            jsession = sc._jvm.SparkSession(sc._jsc.sc())
-            sc._jvm.SparkSession.setDefaultSession(jsession)
-
-            spark = SparkSession(sc, jsession)
-            self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
-            self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get()))
-        finally:
-            if spark is not None:
-                spark.stop()
-                spark = None
-                sc = None
-
-            if sc is not None:
-                sc.stop()
-                sc = None
-
-
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
     def test_data_type_eq(self):

@@ -2954,6 +2912,32 @@ def test_sparksession_with_stopped_sparkcontext(self):
             sc.stop()


+class SparkSessionTests(PySparkTestCase):
+
+    # This test class is separate because it is closely related to the session's
+    # start and stop. See SPARK-23228.
+    def test_set_jvm_default_session(self):
+        spark = SparkSession.builder.getOrCreate()
+        try:
+            self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
+        finally:
+            spark.stop()
+            self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isEmpty())
+
+    def test_jvm_default_session_already_set(self):
+        # Here, we assume the default session is already set in the JVM.
+        jsession = self.sc._jvm.SparkSession(self.sc._jsc.sc())
+        self.sc._jvm.SparkSession.setDefaultSession(jsession)
+
+        spark = SparkSession.builder.getOrCreate()
+        try:
+            self.assertTrue(spark._jvm.SparkSession.getDefaultSession().isDefined())
+            # The session should be the same as the existing one.
+            self.assertTrue(jsession.equals(spark._jvm.SparkSession.getDefaultSession().get()))
+        finally:
+            spark.stop()
+
+
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
         if SparkSession._instantiatedSession is not None:
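
The rewritten tests drop the manual try/finally bookkeeping by subclassing PySparkTestCase, whose setUp/tearDown own the SparkContext. A simplified sketch of that base class (the class name here is hypothetical; the real pyspark.tests.PySparkTestCase behaves roughly like this):

# Sketch of the PySparkTestCase pattern: each test gets a SparkContext as
# self.sc, created in setUp and stopped in tearDown, so tests no longer
# need their own SparkContext cleanup.
import sys
import unittest

from pyspark import SparkContext


class PySparkTestCaseSketch(unittest.TestCase):

    def setUp(self):
        self._old_sys_path = list(sys.path)
        self.sc = SparkContext('local[4]', self.__class__.__name__)

    def tearDown(self):
        self.sc.stop()
        sys.path = self._old_sys_path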