Commit 883dad2
[SPARK-50909][PYTHON] Setup faulthandler in PythonPlannerRunners
### What changes were proposed in this pull request?

Sets up `faulthandler` in the `PythonPlannerRunner`s. It can be enabled with the same configs as for UDFs:

- SQL conf: `spark.sql.execution.pyspark.udf.faulthandler.enabled`
  - falls back to the Spark conf `spark.python.worker.faulthandler.enabled`
  - `False` by default

### Why are the changes needed?

The `faulthandler` is not set up in `PythonPlannerRunner`s.

### Does this PR introduce _any_ user-facing change?

When enabled, if a Python worker crashes, the error message may include a thread dump from the Python process, generated on a best-effort basis.

### How was this patch tested?

Added the related tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#49592 from ueshin/issues/SPARK-50909/faulthandler.

Authored-by: Takuya Ueshin <ueshin@databricks.com>
Signed-off-by: Takuya Ueshin <ueshin@databricks.com>
1 parent 1bf4720 commit 883dad2
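
As a rough usage sketch (assuming a running `spark` session; `CrashingSource` is a made-up example that mirrors the tests added in this commit, using `ctypes.string_at(0)` to force a segfault in the planner worker):

```python
from pyspark.sql.datasource import DataSource

# Enable faulthandler for Python UDF / planner workers. This SQL conf falls back to
# the Spark conf spark.python.worker.faulthandler.enabled and is False by default.
spark.conf.set("spark.sql.execution.pyspark.udf.faulthandler.enabled", "true")


class CrashingSource(DataSource):
    @classmethod
    def name(cls):
        return "crashing"

    def schema(self):
        import ctypes

        # Deliberately crash the Python planner worker while resolving the schema.
        return ctypes.string_at(0)


spark.dataSource.register(CrashingSource)

# With the conf enabled, the resulting SparkException should carry the faulthandler
# output (e.g. "Segmentation fault" plus a Python thread dump) on a best-effort basis.
spark.read.format("crashing").load().show()
```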

File tree

11 files changed: +275 -14 lines

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 9 additions & 9 deletions
@@ -33,7 +33,7 @@ import org.apache.spark._
 import org.apache.spark.api.python.PythonFunction.PythonAccumulator
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
 import org.apache.spark.internal.LogKeys.TASK_NAME
-import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES, Python}
+import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
 import org.apache.spark.internal.config.Python._
 import org.apache.spark.rdd.InputFileBlockHolder
 import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY}
@@ -90,11 +90,11 @@ private[spark] object PythonEvalType {
   }
 }
 
-private object BasePythonRunner {
+private[spark] object BasePythonRunner {
 
-  private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
+  private[spark] lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
 
-  private def faultHandlerLogPath(pid: Int): Path = {
+  private[spark] def faultHandlerLogPath(pid: Int): Path = {
     new File(faultHandlerLogDir, pid.toString).toPath
   }
 }
@@ -574,15 +574,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         JavaFiles.deleteIfExists(path)
         throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", e)
 
-      case eof: EOFException if !faultHandlerEnabled =>
+      case e: IOException if !faultHandlerEnabled =>
         throw new SparkException(
           s"Python worker exited unexpectedly (crashed). " +
             "Consider setting 'spark.sql.execution.pyspark.udf.faulthandler.enabled' or" +
-            s"'${Python.PYTHON_WORKER_FAULTHANLDER_ENABLED.key}' configuration to 'true' for" +
-            "the better Python traceback.", eof)
+            s"'${PYTHON_WORKER_FAULTHANLDER_ENABLED.key}' configuration to 'true' for " +
+            "the better Python traceback.", e)
 
-      case eof: EOFException =>
-        throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
+      case e: IOException =>
+        throw new SparkException("Python worker exited unexpectedly (crashed)", e)
       }
     }
 
python/pyspark/sql/tests/test_python_datasource.py

Lines changed: 120 additions & 0 deletions
@@ -508,6 +508,126 @@ def write(self, iterator):
         ):
             df.write.format("test").mode("append").saveAsTable("test_table")
 
+    def test_data_source_segfault(self):
+        import ctypes
+
+        for enabled, expected in [
+            (True, "Segmentation fault"),
+            (False, "Consider setting .* for the better Python traceback."),
+        ]:
+            with self.subTest(enabled=enabled), self.sql_conf(
+                {"spark.sql.execution.pyspark.udf.faulthandler.enabled": enabled}
+            ):
+                with self.subTest(worker="pyspark.sql.worker.create_data_source"):
+
+                    class TestDataSource(DataSource):
+                        @classmethod
+                        def name(cls):
+                            return "test"
+
+                        def schema(self):
+                            return ctypes.string_at(0)
+
+                    self.spark.dataSource.register(TestDataSource)
+
+                    with self.assertRaisesRegex(Exception, expected):
+                        self.spark.read.format("test").load().show()
+
+                with self.subTest(worker="pyspark.sql.worker.plan_data_source_read"):
+
+                    class TestDataSource(DataSource):
+                        @classmethod
+                        def name(cls):
+                            return "test"
+
+                        def schema(self):
+                            return "x string"
+
+                        def reader(self, schema):
+                            return TestReader()
+
+                    class TestReader(DataSourceReader):
+                        def partitions(self):
+                            ctypes.string_at(0)
+                            return []
+
+                        def read(self, partition):
+                            return []
+
+                    self.spark.dataSource.register(TestDataSource)
+
+                    with self.assertRaisesRegex(Exception, expected):
+                        self.spark.read.format("test").load().show()
+
+                with self.subTest(worker="pyspark.worker"):
+
+                    class TestDataSource(DataSource):
+                        @classmethod
+                        def name(cls):
+                            return "test"
+
+                        def schema(self):
+                            return "x string"
+
+                        def reader(self, schema):
+                            return TestReader()
+
+                    class TestReader(DataSourceReader):
+                        def read(self, partition):
+                            ctypes.string_at(0)
+                            yield "x",
+
+                    self.spark.dataSource.register(TestDataSource)
+
+                    with self.assertRaisesRegex(Exception, expected):
+                        self.spark.read.format("test").load().show()
+
+                with self.subTest(worker="pyspark.sql.worker.write_into_data_source"):
+
+                    class TestDataSource(DataSource):
+                        @classmethod
+                        def name(cls):
+                            return "test"
+
+                        def writer(self, schema, overwrite):
+                            return TestWriter()
+
+                    class TestWriter(DataSourceWriter):
+                        def write(self, iterator):
+                            ctypes.string_at(0)
+                            return WriterCommitMessage()
+
+                    self.spark.dataSource.register(TestDataSource)
+
+                    with self.assertRaisesRegex(Exception, expected):
+                        self.spark.range(10).write.format("test").mode("append").saveAsTable(
+                            "test_table"
+                        )
+
+                with self.subTest(worker="pyspark.sql.worker.commit_data_source_write"):
+
+                    class TestDataSource(DataSource):
+                        @classmethod
+                        def name(cls):
+                            return "test"
+
+                        def writer(self, schema, overwrite):
+                            return TestWriter()
+
+                    class TestWriter(DataSourceWriter):
+                        def write(self, iterator):
+                            return WriterCommitMessage()
+
+                        def commit(self, messages):
+                            ctypes.string_at(0)
+
+                    self.spark.dataSource.register(TestDataSource)
+
+                    with self.assertRaisesRegex(Exception, expected):
+                        self.spark.range(10).write.format("test").mode("append").saveAsTable(
+                            "test_table"
+                        )
+
 
 class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
     ...

python/pyspark/sql/tests/test_udtf.py

Lines changed: 37 additions & 0 deletions
@@ -2761,6 +2761,43 @@ def eval(self, n):
         res = self.spark.sql("select i, to_json(v['v1']) from test_udtf_struct(8)")
         assertDataFrameEqual(res, [Row(i=n, s=f'{{"a":"{chr(99 + n)}"}}') for n in range(8)])
 
+    def test_udtf_segfault(self):
+        for enabled, expected in [
+            (True, "Segmentation fault"),
+            (False, "Consider setting .* for the better Python traceback."),
+        ]:
+            with self.subTest(enabled=enabled), self.sql_conf(
+                {"spark.sql.execution.pyspark.udf.faulthandler.enabled": enabled}
+            ):
+                with self.subTest(method="eval"):
+
+                    class TestUDTF:
+                        def eval(self):
+                            import ctypes
+
+                            yield ctypes.string_at(0),
+
+                    self._check_result_or_exception(
+                        TestUDTF, "x: string", expected, err_type=Exception
+                    )
+
+                with self.subTest(method="analyze"):
+
+                    class TestUDTFWithAnalyze:
+                        @staticmethod
+                        def analyze():
+                            import ctypes
+
+                            ctypes.string_at(0)
+                            return AnalyzeResult(StructType().add("x", StringType()))
+
+                        def eval(self):
+                            yield "x",
+
+                    self._check_result_or_exception(
+                        TestUDTFWithAnalyze, None, expected, err_type=Exception
+                    )
+
 
 class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
     @classmethod
python/pyspark/sql/worker/analyze_udtf.py

Lines changed: 12 additions & 0 deletions
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import faulthandler
 import inspect
 import os
 import sys
@@ -106,7 +107,13 @@ def main(infile: IO, outfile: IO) -> None:
     in JVM and receive the Python UDTF and its arguments for the `analyze` static method,
     and call the `analyze` static method, and send back a AnalyzeResult as a result of the method.
     """
+    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
     try:
+        if faulthandler_log_path:
+            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
+            faulthandler_log_file = open(faulthandler_log_path, "w")
+            faulthandler.enable(file=faulthandler_log_file)
+
         check_python_version(infile)
 
         memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -247,6 +254,11 @@ def invalid_analyze_result_field(field_name: str, expected_field: str) -> PySpar
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
+    finally:
+        if faulthandler_log_path:
+            faulthandler.disable()
+            faulthandler_log_file.close()
+            os.remove(faulthandler_log_path)
 
     send_accumulator_updates(outfile)
 
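The same setup/teardown pattern is repeated in each planner worker above and below. Distilled into a minimal standalone sketch for readability (the `run_with_faulthandler` wrapper name is made up; `PYTHON_FAULTHANDLER_DIR` and the per-PID log file come from the diffs):

```python
import faulthandler
import os


def run_with_faulthandler(body):
    # Hypothetical wrapper illustrating the pattern added to each planner worker:
    # if the JVM set PYTHON_FAULTHANDLER_DIR, dump Python tracebacks on a crash
    # into a per-PID file that the JVM reads back into the error message.
    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
    faulthandler_log_file = None
    try:
        if faulthandler_log_path:
            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
            faulthandler_log_file = open(faulthandler_log_path, "w")
            faulthandler.enable(file=faulthandler_log_file)

        body()
    finally:
        # Clean up the handler and the log file whether or not body() raised.
        if faulthandler_log_path:
            faulthandler.disable()
            faulthandler_log_file.close()
            os.remove(faulthandler_log_path)
```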
python/pyspark/sql/worker/commit_data_source_write.py

Lines changed: 12 additions & 0 deletions
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import faulthandler
 import os
 import sys
 from typing import IO
@@ -47,7 +48,13 @@ def main(infile: IO, outfile: IO) -> None:
     responsible for invoking either the `commit` or the `abort` method on a data source
     writer instance, given a list of commit messages.
     """
+    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
     try:
+        if faulthandler_log_path:
+            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
+            faulthandler_log_file = open(faulthandler_log_path, "w")
+            faulthandler.enable(file=faulthandler_log_file)
+
         check_python_version(infile)
 
         memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -93,6 +100,11 @@ def main(infile: IO, outfile: IO) -> None:
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
+    finally:
+        if faulthandler_log_path:
+            faulthandler.disable()
+            faulthandler_log_file.close()
+            os.remove(faulthandler_log_path)
 
     send_accumulator_updates(outfile)
 
python/pyspark/sql/worker/create_data_source.py

Lines changed: 12 additions & 0 deletions
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import faulthandler
 import inspect
 import os
 import sys
@@ -60,7 +61,13 @@ def main(infile: IO, outfile: IO) -> None:
     This process then creates a `DataSource` instance using the above information and
     sends the pickled instance as well as the schema back to the JVM.
     """
+    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
     try:
+        if faulthandler_log_path:
+            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
+            faulthandler_log_file = open(faulthandler_log_path, "w")
+            faulthandler.enable(file=faulthandler_log_file)
+
         check_python_version(infile)
 
         memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -158,6 +165,11 @@ def main(infile: IO, outfile: IO) -> None:
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
+    finally:
+        if faulthandler_log_path:
+            faulthandler.disable()
+            faulthandler_log_file.close()
+            os.remove(faulthandler_log_path)
 
     send_accumulator_updates(outfile)
 
python/pyspark/sql/worker/lookup_data_sources.py

Lines changed: 12 additions & 0 deletions
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import faulthandler
 from importlib import import_module
 from pkgutil import iter_modules
 import os
@@ -50,7 +51,13 @@ def main(infile: IO, outfile: IO) -> None:
     This is responsible for searching the available Python Data Sources so they can be
     statically registered automatically.
     """
+    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
     try:
+        if faulthandler_log_path:
+            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
+            faulthandler_log_file = open(faulthandler_log_path, "w")
+            faulthandler.enable(file=faulthandler_log_file)
+
         check_python_version(infile)
 
         memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -78,6 +85,11 @@ def main(infile: IO, outfile: IO) -> None:
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
+    finally:
+        if faulthandler_log_path:
+            faulthandler.disable()
+            faulthandler_log_file.close()
+            os.remove(faulthandler_log_path)
 
     send_accumulator_updates(outfile)
 
python/pyspark/sql/worker/plan_data_source_read.py

Lines changed: 12 additions & 0 deletions
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import faulthandler
 import os
 import sys
 import functools
@@ -187,7 +188,13 @@ def main(infile: IO, outfile: IO) -> None:
     The partition values and the Arrow Batch are then serialized and sent back to the JVM
     via the socket.
     """
+    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
     try:
+        if faulthandler_log_path:
+            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
+            faulthandler_log_file = open(faulthandler_log_path, "w")
+            faulthandler.enable(file=faulthandler_log_file)
+
         check_python_version(infile)
 
         memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
@@ -351,6 +358,11 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
+    finally:
+        if faulthandler_log_path:
+            faulthandler.disable()
+            faulthandler_log_file.close()
+            os.remove(faulthandler_log_path)
 
     send_accumulator_updates(outfile)
 