Commit 4c0d755

Implement DataFrameQueryContext in Spark Connect

1 parent 47c55f4

14 files changed (+422 -202 lines)

connector/connect/common/src/main/protobuf/spark/connect/expressions.proto

Lines changed: 17 additions & 0 deletions
@@ -54,6 +54,8 @@ message Expression {
     google.protobuf.Any extension = 999;
   }
 
+  // (Optional) Keep the information of the origin for this expression such as stacktrace.
+  Origin origin = 18;
 
   // Expression for the OVER clause or WINDOW clause.
   message Window {
@@ -405,3 +407,18 @@ message NamedArgumentExpression {
   // (Required) The value expression of the named argument.
   Expression value = 2;
 }
+
+message Origin {
+  // (Required) Indicate the origin type.
+  oneof function {
+    PythonOrigin python_origin = 1;
+  }
+}
+
+message PythonOrigin {
+  // (Required) Name of the origin, for example, the name of the function
+  string fragment = 1;
+
+  // (Required) Callsite to show to end users, for example, stacktrace.
+  string call_site = 2;
+}
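The new messages give clients a way to attach Python call-site information to any expression they send. A minimal population sketch in Python, assuming the generated bindings are importable as pyspark.sql.connect.proto (the import alias and the expression payload here are illustrative, not part of this commit):

import pyspark.sql.connect.proto as proto

# Build an (illustrative) expression and tag it with its Python origin.
expr = proto.Expression()
expr.unresolved_function.function_name = "divide"  # hypothetical payload for illustration
expr.origin.python_origin.fragment = "divide"  # name of the originating API call
expr.origin.python_origin.call_site = "example.py:10 in <module>"  # formatted stack frame (assumed format)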

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 16 additions & 2 deletions
@@ -44,7 +44,7 @@ import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
 import org.apache.spark.ml.{functions => MLFunctions}
 import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
-import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
+import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
 import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -57,6 +57,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, L
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
@@ -1471,7 +1472,20 @@ class SparkConnectPlanner(
    *   Catalyst expression
    */
   @DeveloperApi
-  def transformExpression(exp: proto.Expression): Expression = {
+  def transformExpression(exp: proto.Expression): Expression = if (exp.hasOrigin) {
+    try {
+      PySparkCurrentOrigin.set(
+        exp.getOrigin.getPythonOrigin.getFragment,
+        exp.getOrigin.getPythonOrigin.getCallSite)
+      withOrigin { doTransformExpression(exp) }
+    } finally {
+      PySparkCurrentOrigin.clear()
+    }
+  } else {
+    doTransformExpression(exp)
+  }
+
+  private def doTransformExpression(exp: proto.Expression): Expression = {
     exp.getExprTypeCase match {
       case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral)
       case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
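Server side, transformExpression now stages the client-provided fragment and call site in the PySparkCurrentOrigin thread-local, runs the actual transformation inside withOrigin so the resulting Catalyst expressions pick that origin up, and clears the thread-local in the finally block even when transformation fails. The user-visible payoff is that runtime errors can point back at the DataFrame call site. A hedged sketch of what this enables from PySpark (assumes an active Spark Connect session named spark, ANSI mode so the division raises, and the QueryContextType import path shown):

from pyspark.errors.exceptions.base import QueryContextType  # assumed import path
from pyspark.sql.functions import col

df = spark.range(1).select(col("id") / 0)  # the origin is recorded at this line
try:
    df.collect()
except Exception as e:
    # Errors whose expressions carried a Python origin now expose a
    # DataFrame query context (see the client-side changes below).
    for ctx in e.getQueryContext():  # available on PySpark exceptions
        if ctx.contextType() == QueryContextType.DataFrame:
            print(ctx.fragment(), ctx.callSite())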

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala

Lines changed: 17 additions & 5 deletions
@@ -35,7 +35,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods
 
-import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
+import org.apache.spark.{QueryContextType, SparkEnv, SparkException, SparkThrowable}
 import org.apache.spark.api.python.PythonException
 import org.apache.spark.connect.proto.FetchErrorDetailsResponse
 import org.apache.spark.internal.{Logging, MDC}
@@ -118,15 +118,27 @@ private[connect] object ErrorUtils extends Logging {
       sparkThrowableBuilder.setErrorClass(sparkThrowable.getErrorClass)
     }
     for (queryCtx <- sparkThrowable.getQueryContext) {
-      sparkThrowableBuilder.addQueryContexts(
-        FetchErrorDetailsResponse.QueryContext
-          .newBuilder()
+      val builder = FetchErrorDetailsResponse.QueryContext
+        .newBuilder()
+      val context = if (queryCtx.contextType() == QueryContextType.SQL) {
+        builder
+          .setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.SQL)
           .setObjectType(queryCtx.objectType())
           .setObjectName(queryCtx.objectName())
           .setStartIndex(queryCtx.startIndex())
           .setStopIndex(queryCtx.stopIndex())
           .setFragment(queryCtx.fragment())
-          .build())
+          .setSummary(queryCtx.summary())
+          .build()
+      } else {
+        builder
+          .setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME)
+          .setFragment(queryCtx.fragment())
+          .setCallSite(queryCtx.callSite())
+          .setSummary(queryCtx.summary())
+          .build()
+      }
+      sparkThrowableBuilder.addQueryContexts(context)
     }
     if (sparkThrowable.getSqlState != null) {
       sparkThrowableBuilder.setSqlState(sparkThrowable.getSqlState)
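On the wire, each query context now carries an explicit context_type discriminator, and only the fields that make sense for that type are populated: SQL contexts get the object and index fields, DataFrame contexts get the fragment and the Python call site, and both get a summary. A client-side decoding sketch that mirrors this branch (the helper name is illustrative; the pb2 alias matches the one connect.py uses below):

import pyspark.sql.connect.proto as pb2

def describe(ctx: "pb2.FetchErrorDetailsResponse.QueryContext") -> str:
    # Branch exactly the way the server populates the message.
    if ctx.context_type == pb2.FetchErrorDetailsResponse.QueryContext.SQL:
        return (
            f"SQL {ctx.object_type} {ctx.object_name} "
            f"[{ctx.start_index}, {ctx.stop_index}]: {ctx.fragment}"
        )
    return f"DataFrame {ctx.fragment!r} at {ctx.call_site}"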

python/pyspark/errors/exceptions/captured.py

Lines changed: 37 additions & 14 deletions
@@ -166,7 +166,14 @@ def getQueryContext(self) -> List[BaseQueryContext]:
         if self._origin is not None and is_instance_of(
             gw, self._origin, "org.apache.spark.SparkThrowable"
         ):
-            return [QueryContext(q) for q in self._origin.getQueryContext()]
+            contexts: List[BaseQueryContext] = []
+            for q in self._origin.getQueryContext():
+                if q.contextType().toString() == "SQL":
+                    contexts.append(SQLQueryContext(q))
+                else:
+                    contexts.append(DataFrameQueryContext(q))
+
+            return contexts
         else:
             return []
@@ -379,17 +386,12 @@ class UnknownException(CapturedException, BaseUnknownException):
     """
 
 
-class QueryContext(BaseQueryContext):
+class SQLQueryContext(BaseQueryContext):
     def __init__(self, q: "JavaObject"):
         self._q = q
 
     def contextType(self) -> QueryContextType:
-        context_type = self._q.contextType().toString()
-        assert context_type in ("SQL", "DataFrame")
-        if context_type == "DataFrame":
-            return QueryContextType.DataFrame
-        else:
-            return QueryContextType.SQL
+        return QueryContextType.SQL
 
     def objectType(self) -> str:
         return str(self._q.objectType())
@@ -409,13 +411,34 @@ def fragment(self) -> str:
     def callSite(self) -> str:
         return str(self._q.callSite())
 
-    def pysparkFragment(self) -> Optional[str]:  # type: ignore[return]
-        if self.contextType() == QueryContextType.DataFrame:
-            return str(self._q.pysparkFragment())
+    def summary(self) -> str:
+        return str(self._q.summary())
+
+
+class DataFrameQueryContext(BaseQueryContext):
+    def __init__(self, q: "JavaObject"):
+        self._q = q
+
+    def contextType(self) -> QueryContextType:
+        return QueryContextType.DataFrame
+
+    def objectType(self) -> str:
+        return str(self._q.objectType())
+
+    def objectName(self) -> str:
+        return str(self._q.objectName())
 
-    def pysparkCallSite(self) -> Optional[str]:  # type: ignore[return]
-        if self.contextType() == QueryContextType.DataFrame:
-            return str(self._q.pysparkCallSite())
+    def startIndex(self) -> int:
+        return int(self._q.startIndex())
+
+    def stopIndex(self) -> int:
+        return int(self._q.stopIndex())
+
+    def fragment(self) -> str:
+        return str(self._q.fragment())
+
+    def callSite(self) -> str:
+        return str(self._q.callSite())
 
     def summary(self) -> str:
         return str(self._q.summary())
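With the split, both classes expose the same BaseQueryContext surface on the captured (JVM) path, so callers can branch on contextType() instead of the removed pyspark-prefixed accessors. An illustrative helper, not part of this commit (the QueryContextType import path is assumed):

from pyspark.errors.exceptions.base import QueryContextType  # assumed import path

def format_context(ctx) -> str:
    # SQLQueryContext and DataFrameQueryContext both forward to the same
    # underlying Java object here, so the shared methods are safe to call.
    if ctx.contextType() == QueryContextType.SQL:
        return f"{ctx.objectType()} {ctx.objectName()}: {ctx.fragment()}"
    return f"{ctx.fragment()} called at {ctx.callSite()}"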

python/pyspark/errors/exceptions/connect.py

Lines changed: 75 additions & 8 deletions
@@ -91,7 +91,10 @@ def convert_exception(
     )
     query_contexts = []
     for query_context in resp.errors[resp.root_error_idx].spark_throwable.query_contexts:
-        query_contexts.append(QueryContext(query_context))
+        if query_context.context_type == pb2.FetchErrorDetailsResponse.QueryContext.SQL:
+            query_contexts.append(SQLQueryContext(query_context))
+        else:
+            query_contexts.append(DataFrameQueryContext(query_context))
 
     if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
         return ParseException(
@@ -430,17 +433,12 @@ class SparkNoSuchElementException(SparkConnectGrpcException, BaseNoSuchElementEx
     """
 
 
-class QueryContext(BaseQueryContext):
+class SQLQueryContext(BaseQueryContext):
     def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
         self._q = q
 
     def contextType(self) -> QueryContextType:
-        context_type = self._q.context_type
-
-        if int(context_type) == QueryContextType.DataFrame.value:
-            return QueryContextType.DataFrame
-        else:
-            return QueryContextType.SQL
+        return QueryContextType.SQL
 
     def objectType(self) -> str:
         return str(self._q.object_type)
@@ -457,6 +455,75 @@ def stopIndex(self) -> int:
     def fragment(self) -> str:
         return str(self._q.fragment)
 
+    def callSite(self) -> str:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "SQLQueryContext", "methodName": "callSite"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def summary(self) -> str:
+        return str(self._q.summary)
+
+
+class DataFrameQueryContext(BaseQueryContext):
+    def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
+        self._q = q
+
+    def contextType(self) -> QueryContextType:
+        return QueryContextType.DataFrame
+
+    def objectType(self) -> str:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "DataFrameQueryContext", "methodName": "objectType"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def objectName(self) -> str:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "DataFrameQueryContext", "methodName": "objectName"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def startIndex(self) -> int:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "DataFrameQueryContext", "methodName": "startIndex"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def stopIndex(self) -> int:
+        raise UnsupportedOperationException(
+            "",
+            error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
+            message_parameters={"className": "DataFrameQueryContext", "methodName": "stopIndex"},
+            sql_state="0A000",
+            server_stacktrace=None,
+            display_server_stacktrace=False,
+            query_contexts=[],
+        )
+
+    def fragment(self) -> str:
+        return str(self._q.fragment)
+
     def callSite(self) -> str:
         return str(self._q.call_site)
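Unlike the captured path, the Connect-side classes raise UNSUPPORTED_CALL for accessors the other context type owns (callSite() on SQLQueryContext; objectType(), objectName(), startIndex(), stopIndex() on DataFrameQueryContext), so client code should gate on contextType() before touching type-specific fields. Raising instead of returning an empty string makes misuse explicit. A small defensive sketch (helper name illustrative; import path assumed):

from pyspark.errors.exceptions.base import QueryContextType  # assumed import path

def call_site_or_none(ctx):
    # Only DataFrame contexts carry a call site on the Connect path;
    # asking a SQLQueryContext for one raises UnsupportedOperationException.
    if ctx.contextType() == QueryContextType.DataFrame:
        return ctx.callSite()
    return None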
