
Commit 2c1d5d8

Prototype
1 parent 32e73dd commit 2c1d5d8

4 files changed: +94 -3 lines changed

python/pyspark/errors/utils.py

Lines changed: 65 additions & 1 deletion

@@ -16,11 +16,16 @@
 #
 
 import re
-from typing import Dict, Match
+import functools
+import inspect
+from typing import Any, Callable, Dict, Match, TypeVar, Type
 
 from pyspark.errors.error_classes import ERROR_CLASSES_MAP
 
 
+T = TypeVar("T")
+
+
 class ErrorClassesReader:
     """
     A reader to load error information from error_classes.py.
@@ -119,3 +124,62 @@ def get_message_template(self, error_class: str) -> str:
         message_template = main_message_template + " " + sub_message_template
 
         return message_template
+
+
+def _capture_call_site(func_name: str) -> None:
+    """
+    Capture the call site information, including file name, line number, and function name.
+
+    This function updates the server-side thread-local storage (PySparkCurrentOrigin)
+    with the current call site information when a PySpark API function is called.
+
+    Parameters
+    ----------
+    func_name : str
+        The name of the PySpark API function being captured.
+
+    Notes
+    -----
+    The call site information is used to enhance error messages with the exact location
+    in the user code that led to the error.
+    """
+    from pyspark.sql.session import SparkSession
+
+    spark = SparkSession._getActiveSessionOrCreate()
+    assert spark._jvm is not None
+
+    stack = inspect.stack()
+    frame_info = stack[-1]
+    function = func_name
+    filename = frame_info.filename
+    lineno = frame_info.lineno
+    call_site = f'"{function}" was called from\n{filename}:{lineno}'
+
+    pyspark_origin = spark._jvm.org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
+    pyspark_origin.set(call_site)
+
+
+def with_origin(func: Callable[..., Any]) -> Callable[..., Any]:
+    """
+    A decorator to capture and provide the call site information to the server side
+    when PySpark API functions are invoked.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
+        # Update the call site when the function is called.
+        _capture_call_site(func.__name__)
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+def with_origin_to_class(cls: Type[T]) -> Type[T]:
+    """
+    Decorate all methods of a class with `with_origin` to capture call site information.
+    """
+    for name, method in cls.__dict__.items():
+        if callable(method) and name != "__init__":
+            setattr(cls, name, with_origin(method))
+    return cls
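
To see the mechanism in isolation, here is a minimal, runnable sketch of the same decorator pattern with the Py4J round-trip to PySparkCurrentOrigin replaced by a module-level variable, so it works without Spark; the names divide and _last_call_site are illustrative only:

    import functools
    import inspect
    from typing import Any, Callable

    _last_call_site = ""  # stand-in for the JVM-side, thread-local PySparkCurrentOrigin


    def _capture_call_site(func_name: str) -> None:
        global _last_call_site
        # stack()[-1] is the outermost frame, i.e. the user's top-level call site.
        frame_info = inspect.stack()[-1]
        _last_call_site = f'"{func_name}" was called from\n{frame_info.filename}:{frame_info.lineno}'


    def with_origin(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            _capture_call_site(func.__name__)  # record where the user called from
            return func(*args, **kwargs)

        return wrapper


    @with_origin
    def divide(a: float, b: float) -> float:
        return a / b


    divide(6, 3)
    print(_last_call_site)  # e.g. "divide" was called from ... /path/to/script.py:29

Note that stack()[-1] picks the outermost frame: a frame nearer to wrapper would point into PySpark internals rather than into the user's code.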

python/pyspark/sql/column.py

Lines changed: 3 additions & 0 deletions

@@ -35,6 +35,7 @@
 
 from pyspark.context import SparkContext
 from pyspark.errors import PySparkAttributeError, PySparkTypeError, PySparkValueError
+from pyspark.errors.utils import with_origin_to_class
 from pyspark.sql.types import DataType
 from pyspark.sql.utils import get_active_spark_context
 
@@ -177,6 +178,7 @@ def _(
         return Column(njc)
 
     _.__doc__ = doc
+    _.__name__ = name
     return _
 
 
@@ -195,6 +197,7 @@ def _(self: "Column", other: Union["LiteralType", "DecimalLiteral"]) -> "Column"
     return _
 
 
+@with_origin_to_class
 class Column:
 
     """

sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala

Lines changed: 8 additions & 1 deletion

@@ -134,7 +134,8 @@ case class SQLQueryContext(
   override def callSite: String = throw SparkUnsupportedOperationException()
 }
 
-case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement]) extends QueryContext {
+case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement], pysparkCallSite: String)
+  extends QueryContext {
   override val contextType = QueryContextType.DataFrame
 
   override def objectType: String = throw SparkUnsupportedOperationException()
@@ -165,6 +166,12 @@ case class DataFrameQueryContext(stackTrace: Seq[StackTraceElement]) extends Que
     builder ++= " was called from\n"
     builder ++= callSite
     builder += '\n'
+
+    if (pysparkCallSite.nonEmpty) {
+      builder ++= "\n== PySpark call site ==\n"
+      builder ++= pysparkCallSite
+      builder += '\n'
+    }
     builder.result()
   }
 }
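
When pysparkCallSite is non-empty, the summary gains a trailing section. Combined with the call-site string built in python/pyspark/errors/utils.py above, it would render roughly as follows (function name and path are hypothetical):

    == PySpark call site ==
    "divide" was called from
    /path/to/user_script.py:10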

sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala

Lines changed: 18 additions & 1 deletion

@@ -35,7 +35,8 @@ case class Origin(
     stackTrace: Option[Array[StackTraceElement]] = None) {
 
   lazy val context: QueryContext = if (stackTrace.isDefined) {
-    DataFrameQueryContext(stackTrace.get.toImmutableArraySeq)
+    val pysparkCallSite = PySparkCurrentOrigin.get()
+    DataFrameQueryContext(stackTrace.get.toImmutableArraySeq, pysparkCallSite)
   } else {
     SQLQueryContext(
       line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName)
@@ -84,3 +85,19 @@ object CurrentOrigin {
     ret
   }
 }
+
+/**
+ * Provides detailed call site information on PySpark.
+ * This information is generated in PySpark in the form of a String.
+ */
+object PySparkCurrentOrigin {
+  private val pysparkCallSite = new ThreadLocal[String]() {
+    override def initialValue(): String = ""
+  }
+
+  def set(value: String): Unit = pysparkCallSite.set(value)
+
+  def get(): String = pysparkCallSite.get()
+
+  def clear(): Unit = pysparkCallSite.remove()
+}
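
For comparison with the Python side, here is a rough Python analog of the thread-local pattern used by PySparkCurrentOrigin, assuming the same contract of a per-thread string slot that defaults to the empty string; this is an illustrative sketch, not a PySpark API:

    import threading


    class _CallSiteLocal(threading.local):
        # __init__ runs once per thread, so each thread gets its own slot.
        def __init__(self) -> None:
            self.call_site = ""


    _origin = _CallSiteLocal()


    def set_call_site(value: str) -> None:
        _origin.call_site = value


    def get_call_site() -> str:
        return _origin.call_site


    def clear_call_site() -> None:
        _origin.call_site = ""

Thread locality keeps concurrent callers from interleaving: with a plain shared variable, one thread's call site could leak into another thread's error message.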
