[SPARK-27992][SPARK-28881][PYTHON][2.4] Allow Python to join with connection thread to propagate errors #25593

Closed · wants to merge 3 commits
37 changes: 37 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -440,6 +440,29 @@ private[spark] object PythonRDD extends Logging {
Array(port, secret)
}

/**
* Create a socket server object and background thread to execute the writeFunc
* with the given OutputStream.
*
* This is the same as serveToStream, only it returns a server object that
Member:
Nit: only -> but

Member @gatorsmile, Aug 29, 2019:
Can we update the comment of serveToStream in the master branch? This might be a common mistake if contributors are not aware of the trap.

Member Author:
I think the comment was updated in the PR against the master branch: https://github.com/apache/spark/pull/24834/files#diff-1f54938136d72cd234ae55003c91d565R111-R122

* can be used to sync in Python.
*/
private[spark] def serveToStreamWithSync(
threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {

val handleFunc = (sock: Socket) => {
val out = new BufferedOutputStream(sock.getOutputStream())
Utils.tryWithSafeFinally {
writeFunc(out)
} {
out.close()
}
}

val server = new SocketFuncServer(authHelper, threadName, handleFunc)
Array(server.port, server.secret, server)
}
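
For context, a minimal Python-side sketch of how the returned triple can be consumed; this mirrors the dataframe.py change further down in this PR. The name jdf is a placeholder for any py4j handle whose method returns the array produced by serveToStreamWithSync.

```python
from pyspark.rdd import _load_from_socket
from pyspark.serializers import ArrowStreamSerializer

# serveToStreamWithSync returns Array(port, secret, server); py4j exposes it as a Python list.
port, auth_secret, jsocket_auth_server = jdf.collectAsArrowToPython()
try:
    batches = list(_load_from_socket((port, auth_secret), ArrowStreamSerializer()))
finally:
    # Join the JVM serving thread so an exception raised while writing is re-raised in Python.
    jsocket_auth_server.getResult()
```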

private def getMergedConf(confAsMap: java.util.HashMap[String, String],
baseConf: Configuration): Configuration = {
val conf = PythonHadoopUtil.mapToConf(confAsMap)
@@ -957,3 +980,17 @@ private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
}
}

/**
* Create a socket server class and run user function on the socket in a background thread.
* This is the same as calling SocketAuthServer.setupOneConnectionServer except it creates
Member:
SocketAuthServer.setupOneConnectionServer sets the timeout to 15 seconds. This one does not set it. What was the reason we set it in the past?

Member Author:
I don't think there was a specific reason for the timeout. Cc @vanzin and @squito.

Member Author @HyukjinKwon, Aug 30, 2019:
BTW, there are multiple places with such a hardcoded timeout - e.g.

I suspect it won't be a major issue.

Member:
FYI, @HyukjinKwon and @cloud-fan also found that the timeout value only affects accept().

Member:
Hi, @gatorsmile. Do you mean it affects the ongoing 2.4.4 vote?

Contributor:
Sorry, I think I'm a bit lost in the discussion -- @gatorsmile, do you think something is wrong here or not? It seems OK to me; the timeout is the same as before, just in a different spot.

I don't think the timeout is crucial for correctness. It's more about getting sane errors if there is some bug and nothing connects back: rather than having things block forever, you'll get a timeout exception.

Contributor:
IIUC the timeout is for establishing the socket connection. Given that we build a local socket connection between the JVM and Python, 10 seconds is fine.

And I agree with @squito that it's not about correctness; if something weird happens, users will get an error instead of a wrong result.

Member Author:
Yea, it's fine. I don't think there's any major issue with that.

* a server object that can then be synced from Python.
*/
private[spark] class SocketFuncServer(
authHelper: SocketAuthHelper,
threadName: String,
func: Socket => Unit) extends PythonServer[Unit](authHelper, threadName) {

override def handleConnection(sock: Socket): Unit = {
func(sock)
}
}
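
A side note on the timeout discussion above: the 15-second value in SocketAuthServer.setupOneConnectionServer is a ServerSocket SO_TIMEOUT, which only bounds accept(); once a client has connected, reads and writes on the accepted socket are not governed by it. A minimal standalone sketch (plain java.net usage, not Spark code) illustrating that behavior:

```scala
import java.net.{InetAddress, ServerSocket, SocketTimeoutException}

val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
serverSocket.setSoTimeout(15000)  // applies only to accept() below

try {
  // Throws SocketTimeoutException if nothing connects back within 15 seconds,
  // so a missing client surfaces as an error instead of blocking forever.
  val sock = serverSocket.accept()
  // Reads and writes on `sock` are not bounded by the ServerSocket timeout;
  // the accepted socket has its own soTimeout (0, i.e. infinite, unless set explicitly).
  sock.close()
} catch {
  case _: SocketTimeoutException => // nothing connected back in time
} finally {
  serverSocket.close()
}
```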
10 changes: 7 additions & 3 deletions python/pyspark/sql/dataframe.py
@@ -2175,9 +2175,13 @@ def _collectAsArrow(self):

.. note:: Experimental.
"""
-with SCCallSiteSync(self._sc) as css:
-    sock_info = self._jdf.collectAsArrowToPython()
-    return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
+with SCCallSiteSync(self._sc):
+    from pyspark.rdd import _load_from_socket
+    port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()
+    try:
+        return list(_load_from_socket((port, auth_secret), ArrowStreamSerializer()))
+    finally:
+        jsocket_auth_server.getResult()  # Join serving thread and raise any exceptions

##########################################################################################
# Pandas compatibility
28 changes: 27 additions & 1 deletion python/pyspark/sql/tests.py
@@ -80,7 +80,7 @@
_have_pyarrow = _pyarrow_requirement_message is None
_test_compiled = _test_not_compiled_message is None

-from pyspark import SparkContext
+from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
@@ -4550,6 +4550,32 @@ def test_timestamp_dst(self):
self.assertPandasEqual(pdf, df_from_pandas.toPandas())


@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
_pandas_requirement_message or _pyarrow_requirement_message)
class MaxResultArrowTests(unittest.TestCase):
# These tests are separate as 'spark.driver.maxResultSize' configuration
# is a static configuration to Spark context.

@classmethod
def setUpClass(cls):
cls.spark = SparkSession(SparkContext(
'local[4]', cls.__name__, conf=SparkConf().set("spark.driver.maxResultSize", "10k")))
Member Author @HyukjinKwon, Aug 27, 2019:
Okay, the last test failure looks weird and flaky (#25593 (comment)). This test itself passed, but it seems like the previously set spark.driver.maxResultSize=10k affects the other tests even though I stop the session and context explicitly.

This is fine for now in the master branch because this test is in a separate file and launched in a separate process; however, it is potentially an issue.

Since the issue only happens when SparkSession.builder is used, I am working around it, in branch-2.4 specifically, by using SparkSession(SparkContext(...)) for now, as it's an orthogonal issue.

Contributor:
Are we going to do the same change for the master branch later?

Member Author:
We can, although it's fine in master for now, as I described. Let me make a follow-up later to match.
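
A standalone sketch of the workaround described above; the class and app names are illustrative, and the test's setUpClass below does the same thing:

```python
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

# spark.driver.maxResultSize is a static configuration, so it must be set before the
# SparkContext starts. Per the discussion above, setting it through SparkSession.builder
# can leak into later tests in branch-2.4, so the session is built from an explicit context.
conf = SparkConf().set("spark.driver.maxResultSize", "10k")
spark = SparkSession(SparkContext("local[4]", "MaxResultArrowDemo", conf=conf))
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")
try:
    # Expected to raise because the collected result exceeds spark.driver.maxResultSize.
    spark.range(0, 10000, 1, 100).toPandas()
finally:
    spark.stop()
```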


# Explicitly enable Arrow and disable fallback.
Contributor:
Just to double-check, this test fails even if we do not set these two configs, right?

Member Author:
Only one configuration matters here, spark.sql.execution.arrow.enabled, because the partial results are produced from the Arrow-optimized code path.

spark.sql.execution.arrow.fallback.enabled is just to make sure we only test the Arrow-optimized code path.

Member @dongjoon-hyun, Aug 27, 2019:
Yep. In branch-2.4, spark.sql.execution.arrow.enabled is false by default.

I verified manually that this test doesn't fail if we remove these configurations. This is required in branch-2.4.

cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
cls.spark.conf.set("spark.sql.execution.arrow.fallback.enabled", "false")

@classmethod
def tearDownClass(cls):
if hasattr(cls, "spark"):
cls.spark.stop()

def test_exception_by_max_results(self):
with self.assertRaisesRegexp(Exception, "is bigger than"):
self.spark.range(0, 10000, 1, 100).toPandas()


class EncryptionArrowTests(ArrowTests):

@classmethod
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3284,7 +3284,7 @@ class Dataset[T] private[sql](
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone

withAction("collectAsArrowToPython", queryExecution) { plan =>
-PythonRDD.serveToStream("serve-Arrow") { out =>
+PythonRDD.serveToStreamWithSync("serve-Arrow") { out =>
val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
val arrowBatchRdd = toArrowBatchRdd(plan)
val numPartitions = arrowBatchRdd.partitions.length