[SPARK-50903][CONNECT] Cache logical plans after analysis #49584

Open · wants to merge 4 commits into master
@@ -27,8 +27,8 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.SparkEnv
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.classic.{DataFrame, Dataset}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
@@ -68,12 +68,9 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
} else {
DoNotCleanup
}
val rel = request.getPlan.getRoot
val dataframe =
Dataset.ofRows(
sessionHolder.session,
planner.transformRelation(request.getPlan.getRoot, cachePlan = true),
tracker,
shuffleCleanupMode)
sessionHolder.createDataFrame(rel, planner, Some((tracker, shuffleCleanupMode)))
responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
processAsArrowBatches(dataframe, responseObserver, executeHolder)
responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
@@ -120,31 +120,27 @@ class SparkConnectPlanner(
sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))

/**
* The root of the query plan is a relation and we apply the transformations to it. The resolved
* logical plan will not get cached. If the result needs to be cached, use
* `transformRelation(rel, cachePlan = true)` instead.
* The root of the query plan is a relation and we apply the transformations to it.
* @param rel
* The relation to transform.
* @return
* The resolved logical plan.
*/
@DeveloperApi
def transformRelation(rel: proto.Relation): LogicalPlan =
transformRelation(rel, cachePlan = false)
def transformRelation(rel: proto.Relation): LogicalPlan = transformRelationWithCache(rel)._1

/**
* The root of the query plan is a relation and we apply the transformations to it.
* The root of the query plan is a relation and we apply the transformations to it. If the
* relation exists in the plan cache, the cached plan is returned; this method does not update
* the plan cache.
* @param rel
* The relation to transform.
* @param cachePlan
* Set to true for a performance optimization, if the plan is likely to be reused, e.g. built
* upon by further dataset transformation. The default is false.
* @return
* The resolved logical plan.
* The resolved logical plan and a flag indicating whether the plan cache was hit.
*/
@DeveloperApi
def transformRelation(rel: proto.Relation, cachePlan: Boolean): LogicalPlan = {
sessionHolder.usePlanCache(rel, cachePlan) { rel =>
def transformRelationWithCache(rel: proto.Relation): (LogicalPlan, Boolean) = {
sessionHolder.usePlanCache(rel) { rel =>
val plan = rel.getRelTypeCase match {
// DataFrame API
case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
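For orientation, a minimal caller-side sketch of the new `transformRelationWithCache` contract: the caller receives the transformed plan together with a cache-hit flag and writes back to the cache only on a miss. This mirrors what `SessionHolder.createDataFrame` (in the next file) does, slightly simplified (the real method also handles lazily analyzed plans); `planner`, `session`, `rel`, and an optional `planCache` are assumed to be in scope.

```scala
import org.apache.spark.sql.classic.Dataset

// Sketch only, not part of this diff. On a cache hit the cached (already analyzed) plan is
// used directly; on a miss the freshly analyzed plan is stored for later requests.
val (plan, cacheHit) = planner.transformRelationWithCache(rel)
val df = Dataset.ofRows(session, plan)
if (!cacheHit) {
  planCache.foreach(_.put(rel, df.queryExecution.analyzed))
}
```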
@@ -33,14 +33,17 @@ import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.connect.proto
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.classic.{Dataset, SparkSession}
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.ml.MLCache
import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC}
import org.apache.spark.sql.execution.ShuffleCleanupMode
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.{SystemClock, Utils}

@@ -440,46 +443,77 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
* `spark.connect.session.planCache.enabled` is true.
* @param rel
* The relation to transform.
* @param cachePlan
* Whether to cache the result logical plan.
* @param transform
* Function to transform the relation into a logical plan.
* @return
* The logical plan.
* The logical plan and a flag indicating whether the plan cache was hit.
*/
private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
transform: proto.Relation => LogicalPlan): LogicalPlan = {
val planCacheEnabled = Option(session)
.forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
// We only cache plans that have a plan ID.
val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId

def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
planCache match {
case Some(cache) if planCacheEnabled && hasPlanId =>
Option(cache.getIfPresent(rel)) match {
case Some(plan) =>
private[connect] def usePlanCache(rel: proto.Relation)(
transform: proto.Relation => LogicalPlan): (LogicalPlan, Boolean) = {
planCache match {
case Some(cache) if canCachePlan(rel) =>
Option(cache.getIfPresent(rel)) match {
case Some(plan) =>
if (isPlanOutdated(plan)) {
// The plan is outdated, therefore remove it from the cache.
cache.invalidate(rel)
} else {
logDebug(s"Using cached plan for relation '$rel': $plan")
Some(plan)
case None => None
}
case _ => None
}
def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
planCache match {
case Some(cache) if planCacheEnabled && hasPlanId =>
cache.put(rel, plan)
case _ =>
return (plan, true)
}
case None => ()
}
case _ => ()
}
(transform(rel), false)
}

/**
* Create a data frame from the supplied relation, updating the plan cache on a cache miss.
*
* @param rel
* The proto.Relation from which to create the data frame.
* @param planner
* The planner used to transform the relation into a logical plan.
* @param options
* An optional QueryPlanningTracker and ShuffleCleanupMode to use when building the data frame.
* @return
* The created data frame.
*/
private[connect] def createDataFrame(
rel: proto.Relation,
planner: SparkConnectPlanner,
options: Option[(QueryPlanningTracker, ShuffleCleanupMode)] = None): DataFrame = {
val (plan, cacheHit) = planner.transformRelationWithCache(rel)
val df = options match {
case Some((tracker, shuffleCleanupMode)) =>
Dataset.ofRows(session, plan, tracker, shuffleCleanupMode)
case _ => Dataset.ofRows(session, plan)
}
if (!cacheHit && planCache.isDefined && canCachePlan(rel)) {
if (df.queryExecution.isLazyAnalysis) {
val plan = df.queryExecution.logical
logDebug(s"Cache a lazyily analyzed logical plan for '$rel': $plan")
planCache.get.put(rel, plan)
} else {
val plan = df.queryExecution.analyzed
Contributor:
We may have to add some invalidation logic here. The problem is that some of the objects (tables/views/UDFs) used in the query can change; in that case we may want to validate that the objects used are the most recent ones.

Contributor (Author):
Agreed. I'll need to think about the interface.

Contributor (Author):
Added LogicalPlan.isOutdated and updated the usePlanCache method. Calling 'refresh' seemed suboptimal, as schema changes often require re-planning.
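One possible shape for the invalidation check discussed in this thread, sketched for illustration only. The per-object version bookkeeping below is hypothetical; Spark does not expose such an API today and nothing like it is part of this PR.

```scala
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Hypothetical sketch: treat a cached plan as outdated when any catalog object it
// references (table/view/UDF) has disappeared or changed version since caching time.
object PlanInvalidationSketch {
  final case class CachedEntry(plan: LogicalPlan, objectVersions: Map[String, Long])

  def isOutdated(entry: CachedEntry, currentVersion: String => Option[Long]): Boolean =
    entry.objectVersions.exists { case (name, cachedVersion) =>
      !currentVersion(name).contains(cachedVersion)
    }
}
```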

logDebug(s"Cache an analyzed logical plan for '$rel': $plan")
planCache.get.put(rel, plan)
}
}
df
}

getPlanCache(rel)
.getOrElse({
val plan = transform(rel)
if (cachePlan) {
putPlanCache(rel, plan)
}
plan
})
// Return true if the plan is outdated and should be removed from the cache.
private def isPlanOutdated(plan: LogicalPlan): Boolean = {
// Currently, nothing is checked.
false
}

// Return true if the plan cache is enabled for the session and the relation.
private def canCachePlan(rel: proto.Relation): Boolean = {
// We only cache plans that have a plan ID.
rel.hasCommon && rel.getCommon.hasPlanId &&
Option(session)
.forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
}

// For testing. Expose the plan cache for testing purposes.
@@ -23,7 +23,6 @@ import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.classic.Dataset
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExtendedMode, FormattedMode, SimpleMode}
@@ -59,23 +58,21 @@ private[connect] class SparkConnectAnalyzeHandler(
val session = sessionHolder.session
val builder = proto.AnalyzePlanResponse.newBuilder()

def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true)

request.getAnalyzeCase match {
case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
val schema = Dataset
.ofRows(session, transformRelation(request.getSchema.getPlan.getRoot))
.schema
val schema =
sessionHolder.createDataFrame(request.getSchema.getPlan.getRoot, planner).schema
builder.setSchema(
proto.AnalyzePlanResponse.Schema
.newBuilder()
.setSchema(DataTypeProtoConverter.toConnectProtoType(schema))
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
val queryExecution = Dataset
.ofRows(session, transformRelation(request.getExplain.getPlan.getRoot))
.queryExecution
val queryExecution =
sessionHolder
.createDataFrame(request.getExplain.getPlan.getRoot, planner)
.queryExecution
val explainString = request.getExplain.getExplainMode match {
case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE =>
queryExecution.explainString(SimpleMode)
@@ -96,9 +93,8 @@
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
val schema = Dataset
.ofRows(session, transformRelation(request.getTreeString.getPlan.getRoot))
.schema
val schema =
sessionHolder.createDataFrame(request.getTreeString.getPlan.getRoot, planner).schema
val treeString = if (request.getTreeString.hasLevel) {
schema.treeString(request.getTreeString.getLevel)
} else {
@@ -111,29 +107,28 @@
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
val isLocal = Dataset
.ofRows(session, transformRelation(request.getIsLocal.getPlan.getRoot))
.isLocal
val isLocal =
sessionHolder.createDataFrame(request.getIsLocal.getPlan.getRoot, planner).isLocal
builder.setIsLocal(
proto.AnalyzePlanResponse.IsLocal
.newBuilder()
.setIsLocal(isLocal)
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
val isStreaming = Dataset
.ofRows(session, transformRelation(request.getIsStreaming.getPlan.getRoot))
.isStreaming
val isStreaming =
sessionHolder
.createDataFrame(request.getIsStreaming.getPlan.getRoot, planner)
.isStreaming
builder.setIsStreaming(
proto.AnalyzePlanResponse.IsStreaming
.newBuilder()
.setIsStreaming(isStreaming)
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
val inputFiles = Dataset
.ofRows(session, transformRelation(request.getInputFiles.getPlan.getRoot))
.inputFiles
val inputFiles =
sessionHolder.createDataFrame(request.getInputFiles.getPlan.getRoot, planner).inputFiles
builder.setInputFiles(
proto.AnalyzePlanResponse.InputFiles
.newBuilder()
@@ -156,29 +151,27 @@
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
val target = Dataset.ofRows(
session,
transformRelation(request.getSameSemantics.getTargetPlan.getRoot))
val other = Dataset.ofRows(
session,
transformRelation(request.getSameSemantics.getOtherPlan.getRoot))
val target =
sessionHolder.createDataFrame(request.getSameSemantics.getTargetPlan.getRoot, planner)
val other =
sessionHolder.createDataFrame(request.getSameSemantics.getOtherPlan.getRoot, planner)
builder.setSameSemantics(
proto.AnalyzePlanResponse.SameSemantics
.newBuilder()
.setResult(target.sameSemantics(other)))

case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
val semanticHash = Dataset
.ofRows(session, transformRelation(request.getSemanticHash.getPlan.getRoot))
val semanticHash = sessionHolder
.createDataFrame(request.getSemanticHash.getPlan.getRoot, planner)
.semanticHash()
builder.setSemanticHash(
proto.AnalyzePlanResponse.SemanticHash
.newBuilder()
.setResult(semanticHash))

case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST =>
val target = Dataset
.ofRows(session, transformRelation(request.getPersist.getRelation))
val target = sessionHolder
.createDataFrame(request.getPersist.getRelation, planner)
if (request.getPersist.hasStorageLevel) {
target.persist(
StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel))
@@ -188,8 +181,8 @@
builder.setPersist(proto.AnalyzePlanResponse.Persist.newBuilder().build())

case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST =>
val target = Dataset
.ofRows(session, transformRelation(request.getUnpersist.getRelation))
val target = sessionHolder
.createDataFrame(request.getUnpersist.getRelation, planner)
if (request.getUnpersist.hasBlocking) {
target.unpersist(request.getUnpersist.getBlocking)
} else {
@@ -198,8 +191,8 @@
builder.setUnpersist(proto.AnalyzePlanResponse.Unpersist.newBuilder().build())

case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL =>
val target = Dataset
.ofRows(session, transformRelation(request.getGetStorageLevel.getRelation))
val target = sessionHolder
.createDataFrame(request.getGetStorageLevel.getRelation, planner)
val storageLevel = target.storageLevel
builder.setGetStorageLevel(
proto.AnalyzePlanResponse.GetStorageLevel
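Taken together, these handler changes mean that repeated requests over the same relation reuse the analyzed plan instead of re-analyzing it. A rough illustration of the intended behavior follows; it is not a test from this PR, and it assumes `sessionHolder`, `planner`, and `rel` are in scope and `spark.connect.session.planCache.enabled` is true.

```scala
// Illustration only: the second call should hit the plan cache populated by the first.
val schema1 = sessionHolder.createDataFrame(rel, planner).schema // miss: transform, analyze, cache
val schema2 = sessionHolder.createDataFrame(rel, planner).schema // hit: cached analyzed plan is reused
assert(schema1 == schema2) // same schema either way; the second call skips re-transforming rel
```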