
[SPARK-48459][CONNECT][PYTHON] Implement DataFrameQueryContext in Spark Connect #46789


Closed · wants to merge 2 commits
@@ -81,3 +81,18 @@ message ResourceProfile {
// (e.g., cores, memory, CPU) to its specific request.
map<string, TaskResourceRequest> task_resources = 2;
}

message Origin {
// (Required) Indicate the origin type.
oneof function {
PythonOrigin python_origin = 1;
}
}

message PythonOrigin {
// (Required) Name of the origin, for example, the name of the function
string fragment = 1;

// (Required) Callsite to show to end users, for example, stacktrace.
string call_site = 2;
}
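As a rough illustration of what these two fields carry (the helper below is hypothetical and not part of this change): fragment would hold the name of the DataFrame API being invoked, and call_site a short, user-facing location derived from the caller's stack.

```python
import traceback

def capture_python_origin(fragment: str) -> dict:
    # Hypothetical helper: 'fragment' is the API name being invoked
    # (e.g. "divide"), and 'call_site' is a one-line, user-facing location
    # taken from the caller's frame, similar to a stacktrace summary.
    caller = traceback.extract_stack()[-2]
    return {"fragment": fragment, "call_site": f"{caller.filename}:{caller.lineno}"}

print(capture_python_origin("divide"))
```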
@@ -19,6 +19,7 @@ syntax = 'proto3';

import "google/protobuf/any.proto";
import "spark/connect/types.proto";
import "spark/connect/common.proto";

package spark.connect;

@@ -30,6 +31,7 @@ option go_package = "internal/generated";
// expressions in SQL appear.
message Expression {

ExpressionCommon common = 18;
oneof expr_type {
Literal literal = 1;
UnresolvedAttribute unresolved_attribute = 2;
@@ -342,6 +344,11 @@ message Expression {
}
}

message ExpressionCommon {
Contributor: are we over-engineering? what else can be put here?

Member (author): It's the same as Relation.RelationCommon, so it's more for consistency (and so we can reuse Origin for the call site as well). I think it's fine.

// (Required) Keep the information of the origin for this expression such as stacktrace.
Origin origin = 1;
}
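A minimal sketch of how the new field is meant to be populated, assuming the generated protobuf classes are re-exported from pyspark.sql.connect.proto (module layout may differ by version); this is illustrative, not the client's actual wiring:

```python
from pyspark.sql.connect import proto  # assumes the generated *_pb2 classes are re-exported here

expr = proto.Expression()
# Assigning to the leaf fields implicitly creates the nested ExpressionCommon,
# Origin, and PythonOrigin messages.
expr.common.origin.python_origin.fragment = "divide"
expr.common.origin.python_origin.call_site = "example.py:42"
print(expr)
```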

message CommonInlineUserDefinedFunction {
// (Required) Name of the user-defined function.
string function_name = 1;
@@ -106,6 +106,8 @@ message Unknown {}

// Common metadata of all relations.
message RelationCommon {
// TODO(SPARK-48639): Add origin like Expression.ExpressionCommon

// (Required) Shared relation metadata.
string source_info = 1;

@@ -44,7 +44,7 @@ import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -57,6 +57,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, L
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.trees.PySparkCurrentOrigin
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
@@ -1471,7 +1472,21 @@ class SparkConnectPlanner(
* Catalyst expression
*/
@DeveloperApi
def transformExpression(exp: proto.Expression): Expression = {
def transformExpression(exp: proto.Expression): Expression = if (exp.hasCommon) {
try {
val origin = exp.getCommon.getOrigin
PySparkCurrentOrigin.set(
origin.getPythonOrigin.getFragment,
origin.getPythonOrigin.getCallSite)
withOrigin { doTransformExpression(exp) }
} finally {
PySparkCurrentOrigin.clear()
}
} else {
doTransformExpression(exp)
}

private def doTransformExpression(exp: proto.Expression): Expression = {
exp.getExprTypeCase match {
case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral)
case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
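For readers coming from the Python side, here is a rough analogue of the pattern above (illustrative only; PySparkCurrentOrigin and withOrigin are JVM-side APIs): stash the incoming origin in a per-context holder, run the real transformation so newly created expressions can pick the origin up, and always clear it so it cannot leak into unrelated expressions.

```python
import contextvars

# Illustrative stand-in for PySparkCurrentOrigin: holds the origin of the
# expression currently being transformed.
_current_origin = contextvars.ContextVar("current_origin", default=None)

def transform_expression(exp):
    # Mirrors transformExpression: set the origin, delegate, always clear.
    if exp.HasField("common"):
        py_origin = exp.common.origin.python_origin
        token = _current_origin.set((py_origin.fragment, py_origin.call_site))
        try:
            return _do_transform_expression(exp)
        finally:
            _current_origin.reset(token)
    return _do_transform_expression(exp)

def _do_transform_expression(exp):
    # Placeholder for the per-expression dispatch (doTransformExpression).
    return exp
```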
@@ -35,7 +35,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods

import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
import org.apache.spark.{QueryContextType, SparkEnv, SparkException, SparkThrowable}
import org.apache.spark.api.python.PythonException
import org.apache.spark.connect.proto.FetchErrorDetailsResponse
import org.apache.spark.internal.{Logging, MDC}
@@ -118,15 +118,27 @@ private[connect] object ErrorUtils extends Logging {
sparkThrowableBuilder.setErrorClass(sparkThrowable.getErrorClass)
}
for (queryCtx <- sparkThrowable.getQueryContext) {
sparkThrowableBuilder.addQueryContexts(
FetchErrorDetailsResponse.QueryContext
.newBuilder()
val builder = FetchErrorDetailsResponse.QueryContext
.newBuilder()
val context = if (queryCtx.contextType() == QueryContextType.SQL) {
builder
.setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.SQL)
Contributor: did we never set this before?

Member (author): Yeah, so far it has always been SQLQueryContext by default.

.setObjectType(queryCtx.objectType())
.setObjectName(queryCtx.objectName())
.setStartIndex(queryCtx.startIndex())
.setStopIndex(queryCtx.stopIndex())
.setFragment(queryCtx.fragment())
.build())
.setSummary(queryCtx.summary())
Contributor: same?

Member (author): Yes, we did not have the QueryContext.summary() API before this change.

.build()
} else {
Contributor (@grundprinzip, Jun 18, 2024): Is this really an unconditional else?

Member (author): For now, yes, because we only have QueryContextType.SQL and QueryContextType.DataFrame.

builder
.setContextType(FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME)
.setFragment(queryCtx.fragment())
.setCallSite(queryCtx.callSite())
.setSummary(queryCtx.summary())
.build()
}
sparkThrowableBuilder.addQueryContexts(context)
}
if (sparkThrowable.getSqlState != null) {
sparkThrowableBuilder.setSqlState(sparkThrowable.getSqlState)
51 changes: 37 additions & 14 deletions python/pyspark/errors/exceptions/captured.py
@@ -166,7 +166,14 @@ def getQueryContext(self) -> List[BaseQueryContext]:
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
return [QueryContext(q) for q in self._origin.getQueryContext()]
contexts: List[BaseQueryContext] = []
for q in self._origin.getQueryContext():
if q.contextType().toString() == "SQL":
contexts.append(SQLQueryContext(q))
else:
contexts.append(DataFrameQueryContext(q))

return contexts
else:
return []
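A small end-to-end sketch of how this surfaces to users (assuming ANSI mode so the failing expression actually carries a query context; exact output will vary): iterate the contexts attached to a captured exception and inspect them.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", True)  # division by zero raises under ANSI

try:
    spark.range(1).select(col("id") / lit(0)).collect()
except Exception as e:
    # With this change, getQueryContext() yields SQLQueryContext or
    # DataFrameQueryContext instances depending on where the failing
    # expression originated.
    if hasattr(e, "getQueryContext"):
        for ctx in e.getQueryContext():
            print(ctx.contextType(), ctx.fragment(), ctx.summary())
```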

@@ -379,17 +386,12 @@ class UnknownException(CapturedException, BaseUnknownException):
"""


class QueryContext(BaseQueryContext):
class SQLQueryContext(BaseQueryContext):
Contributor: Do we consider this a private / developer API?

Member (author): Only the parent class QueryContext is an API (at pyspark.errors.QueryContext) for now. This is at least consistent with the Scala side.

def __init__(self, q: "JavaObject"):
self._q = q

def contextType(self) -> QueryContextType:
context_type = self._q.contextType().toString()
assert context_type in ("SQL", "DataFrame")
if context_type == "DataFrame":
return QueryContextType.DataFrame
else:
return QueryContextType.SQL
return QueryContextType.SQL

def objectType(self) -> str:
return str(self._q.objectType())
@@ -409,13 +411,34 @@ def fragment(self) -> str:
def callSite(self) -> str:
return str(self._q.callSite())

def pysparkFragment(self) -> Optional[str]: # type: ignore[return]
if self.contextType() == QueryContextType.DataFrame:
return str(self._q.pysparkFragment())
def summary(self) -> str:
return str(self._q.summary())


class DataFrameQueryContext(BaseQueryContext):
Contributor: Isn't the type annotation wrong here?

Contributor: I see, this is for the classic side.

def __init__(self, q: "JavaObject"):
self._q = q

def contextType(self) -> QueryContextType:
return QueryContextType.DataFrame

def objectType(self) -> str:
return str(self._q.objectType())

def objectName(self) -> str:
return str(self._q.objectName())

def pysparkCallSite(self) -> Optional[str]: # type: ignore[return]
if self.contextType() == QueryContextType.DataFrame:
return str(self._q.pysparkCallSite())
def startIndex(self) -> int:
return int(self._q.startIndex())

def stopIndex(self) -> int:
return int(self._q.stopIndex())

def fragment(self) -> str:
return str(self._q.fragment())

def callSite(self) -> str:
return str(self._q.callSite())

def summary(self) -> str:
return str(self._q.summary())
83 changes: 75 additions & 8 deletions python/pyspark/errors/exceptions/connect.py
@@ -91,7 +91,10 @@ def convert_exception(
)
query_contexts = []
for query_context in resp.errors[resp.root_error_idx].spark_throwable.query_contexts:
query_contexts.append(QueryContext(query_context))
if query_context.context_type == pb2.FetchErrorDetailsResponse.QueryContext.SQL:
query_contexts.append(SQLQueryContext(query_context))
else:
query_contexts.append(DataFrameQueryContext(query_context))

if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
return ParseException(
@@ -430,17 +433,12 @@ class SparkNoSuchElementException(SparkConnectGrpcException, BaseNoSuchElementEx
"""


class QueryContext(BaseQueryContext):
class SQLQueryContext(BaseQueryContext):
def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
self._q = q

def contextType(self) -> QueryContextType:
context_type = self._q.context_type

if int(context_type) == QueryContextType.DataFrame.value:
return QueryContextType.DataFrame
else:
return QueryContextType.SQL
return QueryContextType.SQL

def objectType(self) -> str:
return str(self._q.object_type)
@@ -457,6 +455,75 @@ def stopIndex(self) -> int:
def fragment(self) -> str:
return str(self._q.fragment)

def callSite(self) -> str:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "SQLQueryContext", "methodName": "callSite"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def summary(self) -> str:
return str(self._q.summary)


class DataFrameQueryContext(BaseQueryContext):
def __init__(self, q: pb2.FetchErrorDetailsResponse.QueryContext):
self._q = q

def contextType(self) -> QueryContextType:
return QueryContextType.DataFrame

def objectType(self) -> str:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "objectType"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def objectName(self) -> str:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "objectName"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def startIndex(self) -> int:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "startIndex"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def stopIndex(self) -> int:
raise UnsupportedOperationException(
"",
error_class="UNSUPPORTED_CALL.WITHOUT_SUGGESTION",
message_parameters={"className": "DataFrameQueryContext", "methodName": "stopIndex"},
sql_state="0A000",
server_stacktrace=None,
display_server_stacktrace=False,
query_contexts=[],
)

def fragment(self) -> str:
return str(self._q.fragment)

def callSite(self) -> str:
return str(self._q.call_site)
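Since each context class deliberately leaves the other's accessors unsupported (raising UNSUPPORTED_CALL above), callers should branch on contextType() before touching type-specific fields. A small sketch, assuming QueryContextType is importable from pyspark.errors (it is defined alongside the base QueryContext in pyspark.errors.exceptions.base):

```python
from pyspark.errors import QueryContextType

def describe_context(ctx) -> str:
    # Only touch accessors that are valid for the given context type:
    # fragment()/callSite() for DataFrame contexts, summary() otherwise.
    if ctx.contextType() == QueryContextType.DataFrame:
        return f"'{ctx.fragment()}' was called at {ctx.callSite()}"
    return ctx.summary()
```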
