Commit 80bba44

[SPARK-48459][CONNECT][PYTHON] Implement DataFrameQueryContext in Spark Connect
### What changes were proposed in this pull request?

This PR proposes to implement `DataFrameQueryContext` in Spark Connect:

1. Add two new protobuf messages packed together with `Expression`:

```proto
message Origin {
  // (Required) Indicate the origin type.
  oneof function {
    PythonOrigin python_origin = 1;
  }
}

message PythonOrigin {
  // (Required) Name of the origin, for example, the name of the function
  string fragment = 1;

  // (Required) Callsite to show to end users, for example, stacktrace.
  string call_site = 2;
}
```

2. Merge `DataFrameQueryContext.pysparkFragment` and `DataFrameQueryContext.pysparkCallSite` into the existing `DataFrameQueryContext.fragment` and `DataFrameQueryContext.callSite`.

3. Separate `QueryContext` into `SQLQueryContext` and `DataFrameQueryContext` for consistency with the Scala side.

4. Implement the origin logic: the `current_origin` thread-local holds the current call site and function name, and each `Expression` message picks them up. They are attached to individual expression messages and consumed when analysis happens, resembling the Spark SQL implementation. See also #45377.

### Why are the changes needed?

See #45377.

### Does this PR introduce _any_ user-facing change?

Yes, same as #45377 but in Spark Connect.

### How was this patch tested?

Same unit tests reused in Spark Connect.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46789 from HyukjinKwon/connect-context.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
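To make item 4 concrete, here is a minimal sketch of the client-side thread-local pattern, with hypothetical helper names (`set_current_origin`, `get_current_origin`, `attach_origin`); the actual PySpark helpers may be named and structured differently:

```python
import threading
from typing import Optional, Tuple

# Hypothetical thread-local mirroring the PR's `current_origin`: while a
# DataFrame API call runs, it records the fragment (function name) and the
# user-facing call site.
_current_origin = threading.local()

def set_current_origin(fragment: Optional[str], call_site: Optional[str]) -> None:
    _current_origin.value = (fragment, call_site)

def get_current_origin() -> Tuple[Optional[str], Optional[str]]:
    return getattr(_current_origin, "value", (None, None))

def attach_origin(expr) -> None:
    # Stamp the recorded origin onto an Expression proto message; the field
    # names (common.origin.python_origin.{fragment,call_site}) come from the
    # protobuf definitions added in this commit.
    fragment, call_site = get_current_origin()
    if fragment is not None:
        expr.common.origin.python_origin.fragment = fragment
        expr.common.origin.python_origin.call_site = call_site or ""
```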
1 parent 58701d8 · commit 80bba44

File tree: 19 files changed, +463 −205 lines


connector/connect/common/src/main/protobuf/spark/connect/common.proto

Lines changed: 15 additions & 0 deletions

```diff
@@ -81,3 +81,18 @@ message ResourceProfile {
   // (e.g., cores, memory, CPU) to its specific request.
   map<string, TaskResourceRequest> task_resources = 2;
 }
+
+message Origin {
+  // (Required) Indicate the origin type.
+  oneof function {
+    PythonOrigin python_origin = 1;
+  }
+}
+
+message PythonOrigin {
+  // (Required) Name of the origin, for example, the name of the function
+  string fragment = 1;
+
+  // (Required) Callsite to show to end users, for example, stacktrace.
+  string call_site = 2;
+}
```
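For illustration, the generated Python classes for these messages could be used as below. This is a sketch; it assumes the usual `<file>_pb2` module naming under `pyspark.sql.connect.proto`, and the values are made up:

```python
from pyspark.sql.connect.proto import common_pb2

# Build a PythonOrigin wrapped in an Origin, as the Python client would when
# tracing a DataFrame API call.
origin = common_pb2.Origin(
    python_origin=common_pb2.PythonOrigin(
        fragment="divide",           # name of the API function being traced
        call_site="/app/job.py:42",  # user call site, e.g. file and line
    )
)
assert origin.WhichOneof("function") == "python_origin"
```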

connector/connect/common/src/main/protobuf/spark/connect/expressions.proto

Lines changed: 7 additions & 0 deletions

```diff
@@ -19,6 +19,7 @@ syntax = 'proto3';
 
 import "google/protobuf/any.proto";
 import "spark/connect/types.proto";
+import "spark/connect/common.proto";
 
 package spark.connect;
 
@@ -30,6 +31,7 @@ option go_package = "internal/generated";
 // expressions in SQL appear.
 message Expression {
 
+  ExpressionCommon common = 18;
   oneof expr_type {
     Literal literal = 1;
     UnresolvedAttribute unresolved_attribute = 2;
@@ -342,6 +344,11 @@
   }
 }
 
+message ExpressionCommon {
+  // (Required) Keep the information of the origin for this expression such as stacktrace.
+  Origin origin = 1;
+}
+
 message CommonInlineUserDefinedFunction {
   // (Required) Name of the user-defined function.
   string function_name = 1;
```
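Continuing the sketch above, the new `common` field lets the client attach an origin to any expression. Again an illustrative use of the generated classes, with the same module-naming assumption:

```python
from pyspark.sql.connect.proto import expressions_pb2

expr = expressions_pb2.Expression()
# Nested message fields can be written in place; this populates the new
# `common` field (number 18) with an Origin for the expression.
expr.common.origin.python_origin.fragment = "divide"
expr.common.origin.python_origin.call_site = "/app/job.py:42"
assert expr.HasField("common")
```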

connector/connect/common/src/main/protobuf/spark/connect/relations.proto

Lines changed: 2 additions & 0 deletions

```diff
@@ -106,6 +106,8 @@ message Unknown {}
 
 // Common metadata of all relations.
 message RelationCommon {
+  // TODO(SPARK-48639): Add origin like Expression.ExpressionCommon
+
   // (Required) Shared relation metadata.
   string source_info = 1;
```

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 17 additions & 2 deletions

```diff
@@ -44,7 +44,7 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
 import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
-import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
+import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
 import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -57,6 +57,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, L
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
@@ -1471,7 +1472,21 @@
    * Catalyst expression
    */
   @DeveloperApi
-  def transformExpression(exp: proto.Expression): Expression = {
+  def transformExpression(exp: proto.Expression): Expression = if (exp.hasCommon) {
+    try {
+      val origin = exp.getCommon.getOrigin
+      PySparkCurrentOrigin.set(
+        origin.getPythonOrigin.getFragment,
+        origin.getPythonOrigin.getCallSite)
+      withOrigin { doTransformExpression(exp) }
+    } finally {
+      PySparkCurrentOrigin.clear()
+    }
+  } else {
+    doTransformExpression(exp)
+  }
+
+  private def doTransformExpression(exp: proto.Expression): Expression = {
     exp.getExprTypeCase match {
       case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral)
       case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
```
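Two details are worth noting here. The origin is installed through `PySparkCurrentOrigin` inside a try/finally, so a failing transformation cannot leak one expression's call site into the next request handled on the same thread. And the public `transformExpression` keeps its signature while the actual dispatch moves to the new private `doTransformExpression`; per the commit message, `withOrigin` then makes the Python fragment and call site visible when analysis errors are raised, so server-side errors carry a `DataFrameQueryContext` matching the non-Connect path in #45377.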

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala

Lines changed: 17 additions & 5 deletions

```diff
@@ -35,7 +35,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods
 
-import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
+import org.apache.spark.{QueryContextType, SparkEnv, SparkException, SparkThrowable}
 import org.apache.spark.api.python.PythonException
 import org.apache.spark.connect.proto.FetchErrorDetailsResponse
 import org.apache.spark.internal.{Logging, MDC}
@@ -118,15 +118,27 @@ private[connect] object ErrorUtils extends Logging {
       sparkThrowableBuilder.setErrorClass(sparkThrowable.getErrorClass)
     }
     for (queryCtx <- sparkThrowable.getQueryContext) {
-      sparkThrowableBuilder.addQueryContexts(
-        FetchErrorDetailsResponse.QueryContext
-          .newBuilder()
+      val builder = FetchErrorDetailsResponse.QueryContext
+        .newBuilder()
+      val context = if (queryCtx.contextType() == QueryContextType.SQL) {
+        builder
+          .setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.SQL)
           .setObjectType(queryCtx.objectType())
           .setObjectName(queryCtx.objectName())
           .setStartIndex(queryCtx.startIndex())
           .setStopIndex(queryCtx.stopIndex())
           .setFragment(queryCtx.fragment())
-          .build())
+          .setSummary(queryCtx.summary())
+          .build()
+      } else {
+        builder
+          .setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME)
+          .setFragment(queryCtx.fragment())
+          .setCallSite(queryCtx.callSite())
+          .setSummary(queryCtx.summary())
+          .build()
+      }
+      sparkThrowableBuilder.addQueryContexts(context)
     }
     if (sparkThrowable.getSqlState != null) {
       sparkThrowableBuilder.setSqlState(sparkThrowable.getSqlState)
```
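As a rough client-side counterpart, this is how the new `context_type` discriminator can be consumed from the response proto (a sketch; the `base_pb2` module path is an assumption based on PySpark's generated-proto layout, while the message and field names come from the diff):

```python
import pyspark.sql.connect.proto.base_pb2 as pb2

def describe(ctx: "pb2.FetchErrorDetailsResponse.QueryContext") -> str:
    # Branch on the discriminator the server now sets for each query context.
    if ctx.context_type == pb2.FetchErrorDetailsResponse.QueryContext.SQL:
        return f"SQL [{ctx.start_index}, {ctx.stop_index}]: {ctx.fragment!r}"
    return f"DataFrame {ctx.fragment!r} called at {ctx.call_site}"
```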

python/pyspark/errors/exceptions/captured.py

Lines changed: 37 additions & 14 deletions

```diff
@@ -166,7 +166,14 @@ def getQueryContext(self) -> List[BaseQueryContext]:
         if self._origin is not None and is_instance_of(
             gw, self._origin, "org.apache.spark.SparkThrowable"
         ):
-            return [QueryContext(q) for q in self._origin.getQueryContext()]
+            contexts: List[BaseQueryContext] = []
+            for q in self._origin.getQueryContext():
+                if q.contextType().toString() == "SQL":
+                    contexts.append(SQLQueryContext(q))
+                else:
+                    contexts.append(DataFrameQueryContext(q))
+
+            return contexts
         else:
             return []
 
@@ -379,17 +386,12 @@ class UnknownException(CapturedException, BaseUnknownException):
     """
 
 
-class QueryContext(BaseQueryContext):
+class SQLQueryContext(BaseQueryContext):
     def __init__(self, q: "JavaObject"):
         self._q = q
 
     def contextType(self) -> QueryContextType:
-        context_type = self._q.contextType().toString()
-        assert context_type in ("SQL", "DataFrame")
-        if context_type == "DataFrame":
-            return QueryContextType.DataFrame
-        else:
-            return QueryContextType.SQL
+        return QueryContextType.SQL
 
     def objectType(self) -> str:
         return str(self._q.objectType())
@@ -409,13 +411,34 @@ def fragment(self) -> str:
     def callSite(self) -> str:
         return str(self._q.callSite())
 
-    def pysparkFragment(self) -> Optional[str]:  # type: ignore[return]
-        if self.contextType() == QueryContextType.DataFrame:
-            return str(self._q.pysparkFragment())
+    def summary(self) -> str:
+        return str(self._q.summary())
+
+
+class DataFrameQueryContext(BaseQueryContext):
+    def __init__(self, q: "JavaObject"):
+        self._q = q
+
+    def contextType(self) -> QueryContextType:
+        return QueryContextType.DataFrame
+
+    def objectType(self) -> str:
+        return str(self._q.objectType())
+
+    def objectName(self) -> str:
+        return str(self._q.objectName())
 
-    def pysparkCallSite(self) -> Optional[str]:  # type: ignore[return]
-        if self.contextType() == QueryContextType.DataFrame:
-            return str(self._q.pysparkCallSite())
+    def startIndex(self) -> int:
+        return int(self._q.startIndex())
+
+    def stopIndex(self) -> int:
+        return int(self._q.stopIndex())
+
+    def fragment(self) -> str:
+        return str(self._q.fragment())
+
+    def callSite(self) -> str:
+        return str(self._q.callSite())
 
     def summary(self) -> str:
         return str(self._q.summary())
```
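A short usage sketch of the split on the classic, py4j-backed path, assuming `exc` is a `CapturedException` caught from a failing query and that `QueryContextType` lives in `pyspark.errors.exceptions.base`, as this file's references suggest:

```python
from pyspark.errors.exceptions.base import QueryContextType

def print_contexts(exc) -> None:
    # SQL contexts keep the positional/object accessors; DataFrame contexts
    # are the ones that carry a meaningful fragment/callSite pair.
    for ctx in exc.getQueryContext():
        if ctx.contextType() == QueryContextType.SQL:
            print("SQL:", ctx.objectName(), ctx.fragment())
        else:
            print("DataFrame:", ctx.fragment(), "at", ctx.callSite())
```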

python/pyspark/errors/exceptions/connect.py

Lines changed: 75 additions & 8 deletions

```diff
@@ -91,7 +91,10 @@ def convert_exception(
     )
     query_contexts = []
     for query_context in resp.errors[resp.root_error_idx].spark_throwable.query_contexts:
-        query_contexts.append(QueryContext(query_context))
+        if query_context.context_type == pb2.FetchErrorDetailsResponse.QueryContext.SQL:
+            query_contexts.append(SQLQueryContext(query_context))
+        else:
+            query_contexts.append(DataFrameQueryContext(query_context))
 
     if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
         return ParseException(
@@ -430,17 +433,12 @@ class SparkNoSuchElementException(SparkConnectGrpcException, BaseNoSuchElementEx
     """
 
 
-class QueryContext(BaseQueryContext):
+class SQLQueryContext(BaseQueryContext):
    def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
        self._q = q
 
     def contextType(self) -> QueryContextType:
-        context_type = self._q.context_type
-
-        if int(context_type) == QueryContextType.DataFrame.value:
-            return QueryContextType.DataFrame
-        else:
-            return QueryContextType.SQL
+        return QueryContextType.SQL
 
     def objectType(self) -> str:
         return str(self._q.object_type)
@@ -457,6 +455,75 @@ def stopIndex(self) -> int:
     def fragment(self) -> str:
         return str(self._q.fragment)
 
+    def callSite(self) -> str:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "SQLQueryContext", "methodName": "callSite"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def summary(self) -> str:
+        return str(self._q.summary)
+
+
+class DataFrameQueryContext(BaseQueryContext):
+    def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
+        self._q = q
+
+    def contextType(self) -> QueryContextType:
+        return QueryContextType.DataFrame
+
+    def objectType(self) -> str:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "DataFrameQueryContext", "methodName": "objectType"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def objectName(self) -> str:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "DataFrameQueryContext", "methodName": "objectName"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def startIndex(self) -> int:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "DataFrameQueryContext", "methodName": "startIndex"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def stopIndex(self) -> int:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "DataFrameQueryContext", "methodName": "stopIndex"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def fragment(self) -> str:
+        return str(self._q.fragment)
+
     def callSite(self) -> str:
         return str(self._q.call_site)
```
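Putting it together, a hedged end-to-end illustration of the user-facing effect (mirroring the reused tests from #45377): `spark` is assumed to be an active Spark Connect session, and the exact fragment/call-site strings depend on the client-side tracing.

```python
from pyspark.errors import ArithmeticException

spark.conf.set("spark.sql.ansi.enabled", True)
df = spark.range(1)

try:
    df.select(df.id / 0).collect()  # DIVIDE_BY_ZERO under ANSI mode
except ArithmeticException as e:
    ctx = e.getQueryContext()[0]
    # With this PR, the context arriving over Spark Connect is a
    # DataFrameQueryContext carrying the PySpark fragment and the user's
    # call site, instead of a bare SQL context.
    print(ctx.fragment())   # e.g. "divide"
    print(ctx.callSite())   # e.g. "/app/job.py:7"
```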
