[SPARK-19607] Finding QueryExecution that matches provided executionId #16940

Closed
wants to merge 1 commit
Changes from all commits
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.SparkContext
@@ -32,6 +33,12 @@ object SQLExecution {

private def nextExecutionId: Long = _nextExecutionId.getAndIncrement

private val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]()

def getQueryExecution(executionId: Long): QueryExecution = {
executionIdToQueryExecution.get(executionId)
}

/**
* Wrap an action that will execute "queryExecution" to track all Spark jobs in the body so that
* we can connect them with an execution.
@@ -44,6 +51,7 @@
if (oldExecutionId == null) {
val executionId = SQLExecution.nextExecutionId
sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
executionIdToQueryExecution.put(executionId, queryExecution)
val r = try {
// sparkContext.getCallSite() would first try to pick up any call site that was previously
// set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
@@ -60,6 +68,7 @@
executionId, System.currentTimeMillis()))
}
} finally {
executionIdToQueryExecution.remove(executionId)
sc.setLocalProperty(EXECUTION_ID_KEY, null)
}
r
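For context on how the new lookup can be consumed, here is a minimal sketch of a SparkListener that resolves the in-flight QueryExecution from the job's execution-id property. The listener name PlanLoggingListener and the println are illustrative assumptions, not part of this PR, and SQLExecution lives in a Spark-internal package, so this is only a sketch of how a monitoring tool might call the API.

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.execution.SQLExecution

// Hypothetical listener (not part of this PR): when a started job carries the
// SQL execution-id property, resolve the running QueryExecution and print its
// physical plan.
class PlanLoggingListener extends SparkListener {
  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
    val idStr = jobStart.properties.getProperty(SQLExecution.EXECUTION_ID_KEY)
    if (idStr != null) {
      val qe = SQLExecution.getQueryExecution(idStr.toLong)
      // getQueryExecution returns null once the execution has finished and its
      // entry has been removed from the map in the finally block above.
      if (qe != null) {
        println(s"Execution $idStr physical plan:\n${qe.executedPlan}")
      }
    }
  }
}
```

Registering such a listener via spark.sparkContext.addSparkListener before running a query mirrors the pattern exercised by the regression test below.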
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import java.util.Properties

import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.SparkSession

class SQLExecutionSuite extends SparkFunSuite {
@@ -102,6 +103,33 @@ class SQLExecutionSuite extends SparkFunSuite {
}
}


test("Finding QueryExecution for given executionId") {
val spark = SparkSession.builder.master("local[*]").appName("test").getOrCreate()
import spark.implicits._

var queryExecution: QueryExecution = null

spark.sparkContext.addSparkListener(new SparkListener {
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
val executionIdStr = jobStart.properties.getProperty(SQLExecution.EXECUTION_ID_KEY)
if (executionIdStr != null) {
queryExecution = SQLExecution.getQueryExecution(executionIdStr.toLong)
}
SQLExecutionSuite.canProgress = true
}
})

val df = spark.range(1).map { x =>
while (!SQLExecutionSuite.canProgress) {
Thread.sleep(1)
}
x
}
df.collect()

assert(df.queryExecution === queryExecution)
}
}

/**
@@ -114,3 +142,7 @@ private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) {
override protected def initialValue(): Properties = new Properties()
}
}

object SQLExecutionSuite {
@volatile var canProgress = false
}