[SPARK-44839][SS][CONNECT] Better Error Logging when user tries to serialize spark session #42594

Closed · wants to merge 7 commits · Changes from 1 commit
5 changes: 5 additions & 0 deletions python/pyspark/errors/error_classes.py
@@ -708,6 +708,11 @@
"State is either not defined or has already been removed."
]
},
"STREAMING_CONNECT_SERIALIZATION_ERROR" : {
"message" : [
"Cannot serialize the function `<name>`. If you accessed the spark session, or a dataframe defined outside of the function, please be aware that they are not allowed in Spark Connect. For foreachBatch, please access the spark session using `df.sparkSession`, where `df` is the first parameter in your foreachBatch function. For StreamingQueryListener, please access the spark session using `self.spark`. For details please check out the PySpark doc for foreachBatch and StreamingQueryListener."
]
},
"STOP_ITERATION_OCCURRED" : {
"message" : [
"Caught StopIteration thrown from user's code; failing the task: <exc>"
10 changes: 9 additions & 1 deletion python/pyspark/sql/connect/streaming/query.py
@@ -17,6 +17,7 @@

import json
import sys
import pickle
from typing import TYPE_CHECKING, Any, cast, Dict, List, Optional

from pyspark.errors import StreamingQueryException, PySparkValueError
@@ -32,6 +33,7 @@
from pyspark.errors.exceptions.connect import (
    StreamingQueryException as CapturedStreamingQueryException,
)
from pyspark.errors import PySparkRuntimeError

__all__ = ["StreamingQuery", "StreamingQueryManager"]

@@ -237,7 +239,13 @@ def addListener(self, listener: StreamingQueryListener) -> None:
        listener._init_listener_id()
        cmd = pb2.StreamingQueryManagerCommand()
        expr = proto.PythonUDF()
        expr.command = CloudPickleSerializer().dumps(listener)
        try:
            expr.command = CloudPickleSerializer().dumps(listener)
        except pickle.PicklingError:
            raise PySparkRuntimeError(
Member:
@itholic do we need a dedicated error class for PicklingError? e.g., PySparkPicklingError?

Contributor:
Yes, I believe we need a new error class for this new type of user-facing error. Could you add a new PySpark error class representing pickle.PicklingError? See https://github.com/apache/spark/pull/40938/files as an example. I think we can also do it as a follow-up.

Contributor (Author):
Yes, sure, I can do that. Just to confirm: the ask is to define a new PySparkPicklingError and replace this PySparkRuntimeError with it, right?

Contributor:
> define a new PySparkPicklingError and replace this PySparkRuntimeError with that right?

Correct :-)

                error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
                message_parameters={"name": "addListener"},
            )
        expr.python_ver = get_python_ver()
        cmd.add_listener.python_listener_payload.CopyFrom(expr)
        cmd.add_listener.id = listener._id
27 changes: 20 additions & 7 deletions python/pyspark/sql/connect/streaming/readwriter.py
@@ -20,6 +20,7 @@
check_dependencies(__name__)

import sys
import pickle
from typing import cast, overload, Callable, Dict, List, Optional, TYPE_CHECKING, Union

from pyspark.serializers import CloudPickleSerializer
@@ -33,7 +34,7 @@
)
from pyspark.sql.connect.utils import get_python_ver
from pyspark.sql.types import Row, StructType
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkRuntimeError

if TYPE_CHECKING:
    from pyspark.sql.connect.session import SparkSession
@@ -488,18 +489,30 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataStreamWriter":
        serializer = AutoBatchedSerializer(CPickleSerializer())
        command = (func, None, serializer, serializer)
        # Python ForeachWriter isn't really a PythonUDF. But we reuse it for simplicity.
        self._write_proto.foreach_writer.python_function.command = CloudPickleSerializer().dumps(
            command
        )
        try:
            self._write_proto.foreach_writer.python_function.command = (
                CloudPickleSerializer().dumps(command)
            )
        except pickle.PicklingError:
            raise PySparkRuntimeError(
                error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
                message_parameters={"name": "foreach"},
            )
        self._write_proto.foreach_writer.python_function.python_ver = "%d.%d" % sys.version_info[:2]
        return self

    foreach.__doc__ = PySparkDataStreamWriter.foreach.__doc__

    def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamWriter":
        self._write_proto.foreach_batch.python_function.command = CloudPickleSerializer().dumps(
            func
        )
        try:
            self._write_proto.foreach_batch.python_function.command = CloudPickleSerializer().dumps(
                func
            )
        except pickle.PicklingError:
            raise PySparkRuntimeError(
                error_class="STREAMING_CONNECT_SERIALIZATION_ERROR",
                message_parameters={"name": "foreachBatch"},
            )
        self._write_proto.foreach_batch.python_function.python_ver = get_python_ver()
        return self

python/pyspark/sql/tests/connect/streaming/test_parity_foreachBatch.py
@@ -19,6 +19,7 @@

from pyspark.sql.tests.streaming.test_streaming_foreachBatch import StreamingTestsForeachBatchMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.errors import PySparkRuntimeError


class StreamingForeachBatchParityTests(StreamingTestsForeachBatchMixin, ReusedConnectTestCase):
@@ -30,6 +31,35 @@ def test_streaming_foreachBatch_propagates_python_errors(self):
    def test_streaming_foreachBatch_graceful_stop(self):
        super().test_streaming_foreachBatch_graceful_stop()

    def test_accessing_spark_session(self):
        spark = self.spark

        def func(df, _):
            spark.createDataFrame([("do", "not"), ("serialize", "spark")]).collect()

        error_thrown = False
        try:
            self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
        except PySparkRuntimeError as e:
            self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
            error_thrown = True
        self.assertTrue(error_thrown)

    def test_accessing_spark_session_through_df(self):
        dataframe = self.spark.createDataFrame([("do", "not"), ("serialize", "dataframe")])

        def func(df, _):
            dataframe.collect()

        error_thrown = False
        try:
            self.spark.readStream.format("rate").load().writeStream.foreachBatch(func).start()
        except PySparkRuntimeError as e:
            self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
            error_thrown = True
        self.assertTrue(error_thrown)


if __name__ == "__main__":
    import unittest
49 changes: 49 additions & 0 deletions python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -18,6 +18,7 @@
import unittest
import time

from pyspark.errors import PySparkRuntimeError
from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin
from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent
from pyspark.sql.types import StructType, StructField, StringType
@@ -83,6 +84,54 @@ def test_listener_events(self):
        # Remove again to verify this won't throw any error
        self.spark.streams.removeListener(test_listener)

    def test_accessing_spark_session(self):
        spark = self.spark

        class TestListener(StreamingQueryListener):
            def onQueryStarted(self, event):
                spark.createDataFrame([("do", "not"), ("serialize", "spark")]).collect()

            def onQueryProgress(self, event):
                pass

            def onQueryIdle(self, event):
                pass

            def onQueryTerminated(self, event):
                pass

        error_thrown = False
        try:
            self.spark.streams.addListener(TestListener())
        except PySparkRuntimeError as e:
            self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
            error_thrown = True
        self.assertTrue(error_thrown)

    def test_accessing_spark_session_through_df(self):
        dataframe = self.spark.createDataFrame([("do", "not"), ("serialize", "dataframe")])

        class TestListener(StreamingQueryListener):
            def onQueryStarted(self, event):
                dataframe.collect()

            def onQueryProgress(self, event):
                pass

            def onQueryIdle(self, event):
                pass

            def onQueryTerminated(self, event):
                pass

        error_thrown = False
        try:
            self.spark.streams.addListener(TestListener())
        except PySparkRuntimeError as e:
            self.assertEqual(e.getErrorClass(), "STREAMING_CONNECT_SERIALIZATION_ERROR")
            error_thrown = True
        self.assertTrue(error_thrown)


if __name__ == "__main__":
    import unittest