
Commit 618c0f6

Merge branch 'apache:master' into 5438_add_common_method_to_support_session_config
davidyuan1223 authored Oct 25, 2023
2 parents c1024bd + ed0d997 commit 618c0f6
Showing 4 changed files with 66 additions and 10 deletions.
@@ -20,6 +20,8 @@ package org.apache.spark.kyuubi
 import java.util.Properties
 import java.util.concurrent.ConcurrentHashMap
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.scheduler._
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
@@ -43,14 +45,15 @@ class SQLOperationListener(
     spark: SparkSession) extends StatsReportListener with Logging {
 
   private val operationId: String = operation.getHandle.identifier.toString
-  private lazy val activeJobs = new java.util.HashSet[Int]()
+  private lazy val activeJobs = new ConcurrentHashMap[Int, SparkJobInfo]()
   private lazy val activeStages = new ConcurrentHashMap[SparkStageAttempt, SparkStageInfo]()
   private var executionId: Option[Long] = None
 
   private lazy val consoleProgressBar =
     if (getSessionConf(ENGINE_SPARK_SHOW_PROGRESS, spark)) {
       Some(new SparkConsoleProgressBar(
         operation,
+        activeJobs,
         activeStages,
         getSessionConf(ENGINE_SPARK_SHOW_PROGRESS_UPDATE_INTERVAL, spark),
         getSessionConf(ENGINE_SPARK_SHOW_PROGRESS_TIME_FORMAT, spark)))
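Note: this hunk turns activeJobs from a HashSet[Int] of job ids into a ConcurrentHashMap[Int, SparkJobInfo] keyed by job id, and hands the same map to SparkConsoleProgressBar, so the listener thread can mutate it while the progress-bar timer thread iterates it without external locking. A minimal sketch of that sharing pattern (illustrative names and values, not Kyuubi code):

```scala
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

object SharedMapSketch extends App {
  val activeJobs = new ConcurrentHashMap[Int, String]()
  // The writer stands in for the SparkListener thread, the reader for the
  // progress-bar timer; ConcurrentHashMap iteration is weakly consistent and
  // never throws ConcurrentModificationException.
  val writer = new Thread(() => (1 to 100).foreach(i => activeJobs.put(i, s"job-$i")))
  val reader = new Thread(() => (1 to 100).foreach(_ => activeJobs.asScala.foreach(_ => ())))
  writer.start(); reader.start()
  writer.join(); reader.join()
  println(s"${activeJobs.size()} jobs recorded") // prints: 100 jobs recorded
}
```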
@@ -77,9 +80,10 @@ class SQLOperationListener(
     }
   }
 
-  override def onJobStart(jobStart: SparkListenerJobStart): Unit = activeJobs.synchronized {
+  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
     if (sameGroupId(jobStart.properties)) {
       val jobId = jobStart.jobId
+      val stageIds = jobStart.stageInfos.map(_.stageId).toSet
       val stageSize = jobStart.stageInfos.size
       if (executionId.isEmpty) {
         executionId = Option(jobStart.properties.getProperty(SPARK_SQL_EXECUTION_ID_KEY))
@@ -91,17 +95,19 @@
           case _ =>
         }
       }
+      activeJobs.put(
+        jobId,
+        new SparkJobInfo(stageSize, stageIds))
       withOperationLog {
-        activeJobs.add(jobId)
         info(s"Query [$operationId]: Job $jobId started with $stageSize stages," +
           s" ${activeJobs.size()} active jobs running")
       }
     }
   }
 
-  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = activeJobs.synchronized {
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
     val jobId = jobEnd.jobId
-    if (activeJobs.remove(jobId)) {
+    if (activeJobs.remove(jobId) != null) {
       val hint = jobEnd.jobResult match {
         case JobSucceeded => "succeeded"
         case _ => "failed" // TODO: Handle JobFailed(exception: Exception)
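Note: the guard changes from a Boolean test to `!= null` because java.util.Set#remove returns a Boolean while ConcurrentHashMap#remove returns the previous value (null when the key was absent); the map's own thread safety is also why the `activeJobs.synchronized` wrappers are dropped. A standalone sketch of the remove semantics (illustrative values):

```scala
import java.util.concurrent.ConcurrentHashMap

object RemoveSemanticsSketch extends App {
  val jobs = new ConcurrentHashMap[Int, String]()
  jobs.put(7, "running")              // what onJobStart now does, in miniature
  assert(jobs.remove(7) == "running") // first removal returns the mapped value
  assert(jobs.remove(7) == null)      // absent key: null, not false
}
```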
@@ -132,9 +138,18 @@
 
   override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
     val stageInfo = stageCompleted.stageInfo
+    val stageId = stageInfo.stageId
     val stageAttempt = SparkStageAttempt(stageInfo.stageId, stageInfo.attemptNumber())
     activeStages.synchronized {
       if (activeStages.remove(stageAttempt) != null) {
+        stageInfo.getStatusString match {
+          case "succeeded" =>
+            activeJobs.asScala.foreach { case (_, jobInfo) =>
+              if (jobInfo.stageIds.contains(stageId)) {
+                jobInfo.numCompleteStages.getAndIncrement()
+              }
+            }
+        }
         withOperationLog(super.onStageCompleted(stageCompleted))
       }
     }
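Note: this is the bookkeeping that feeds the new progress display: when a stage attempt is removed from activeStages with status "succeeded", every active job whose stageIds owns that stage increments its numCompleteStages counter. A self-contained sketch of the same loop (JobInfo stands in for SparkJobInfo; values illustrative):

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.JavaConverters._

object StageCountingSketch extends App {
  class JobInfo(val numStages: Int, val stageIds: Set[Int]) {
    val numCompleteStages = new AtomicInteger(0)
  }

  val activeJobs = new ConcurrentHashMap[Int, JobInfo]()
  activeJobs.put(0, new JobInfo(2, Set(10, 11))) // job 0 owns stages 10 and 11

  val completedStageId = 10                      // stage 10 just succeeded
  activeJobs.asScala.foreach { case (_, jobInfo) =>
    if (jobInfo.stageIds.contains(completedStageId)) {
      jobInfo.numCompleteStages.getAndIncrement()
    }
  }
  assert(activeJobs.get(0).numCompleteStages.get() == 1) // bar now shows (1 / 2)
}
```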
@@ -29,6 +29,7 @@ import org.apache.kyuubi.operation.Operation
 
 class SparkConsoleProgressBar(
     operation: Operation,
+    liveJobs: ConcurrentHashMap[Int, SparkJobInfo],
    liveStages: ConcurrentHashMap[SparkStageAttempt, SparkStageInfo],
     updatePeriodMSec: Long,
     timeFormat: String)
@@ -72,6 +73,17 @@ class SparkConsoleProgressBar(
     }
   }
 
+  /**
+   * Use stageId to find stage's jobId
+   * @param stageId
+   * @return jobId (Optional)
+   */
+  private def findJobId(stageId: Int): Option[Int] = {
+    liveJobs.asScala.collectFirst {
+      case (jobId, jobInfo) if jobInfo.stageIds.contains(stageId) => jobId
+    }
+  }
+
   /**
    * Show progress bar in console. The progress bar is displayed in the next line
    * after your last output, keeps overwriting itself to hold in one line. The logging will follow
@@ -81,9 +93,13 @@
     val width = TerminalWidth / stages.size
     val bar = stages.map { s =>
       val total = s.numTasks
-      val header = s"[Stage ${s.stageId}:"
+      val jobHeader = findJobId(s.stageId).map(jobId =>
+        s"[Job $jobId (${liveJobs.get(jobId).numCompleteStages} " +
+          s"/ ${liveJobs.get(jobId).numStages}) Stages] ").getOrElse(
+        "[There is no job about this stage] ")
+      val header = jobHeader + s"[Stage ${s.stageId}:"
       val tailer = s"(${s.numCompleteTasks} + ${s.numActiveTasks}) / $total]"
-      val w = width - header.length - tailer.length
+      val w = width + jobHeader.length - header.length - tailer.length
       val bar =
         if (w > 0) {
           val percent = w * s.numCompleteTasks.get / total
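Note: each stage cell now carries a job header, and the width arithmetic compensates for it: since header already contains jobHeader, the `+ jobHeader.length` term cancels, leaving the bar itself the width it had before this change. A standalone sketch of one rendered cell, assuming a TerminalWidth of 80, a single active stage, and the usual Spark-style `===>` bar tail (all values illustrative):

```scala
object ProgressCellSketch extends App {
  val width = 80                        // assumed TerminalWidth / stages.size
  val jobHeader = "[Job 3 (1 / 2) Stages] "
  val header = jobHeader + "[Stage 5:"
  val tailer = "(10 + 4) / 20]"
  // jobHeader.length cancels against header.length, so w is the same as the
  // pre-change width - "[Stage 5:".length - tailer.length:
  val w = width + jobHeader.length - header.length - tailer.length
  val percent = w * 10 / 20             // 10 of 20 tasks complete
  println(header + ("=" * percent) + ">" + (" " * (w - percent - 1)) + tailer)
  // prints something like:
  // [Job 3 (1 / 2) Stages] [Stage 5:====================>      (10 + 4) / 20]
}
```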
@@ -24,6 +24,10 @@ case class SparkStageAttempt(stageId: Int, stageAttemptId: Int) {
 }
 
 class SparkStageInfo(val stageId: Int, val numTasks: Int) {
-  var numActiveTasks = new AtomicInteger(0)
-  var numCompleteTasks = new AtomicInteger(0)
+  val numActiveTasks = new AtomicInteger(0)
+  val numCompleteTasks = new AtomicInteger(0)
 }
+
+class SparkJobInfo(val numStages: Int, val stageIds: Set[Int]) {
+  val numCompleteStages = new AtomicInteger(0)
+}
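Note: the counters switch from var to val because the AtomicInteger reference itself is never reassigned; only the value it wraps changes. A one-class illustration (stand-in names, not the Kyuubi classes):

```scala
import java.util.concurrent.atomic.AtomicInteger

object ValCounterSketch extends App {
  class StageCounters {
    val numCompleteTasks = new AtomicInteger(0) // reference is immutable
  }
  val info = new StageCounters
  info.numCompleteTasks.getAndIncrement()  // mutate the value it holds
  assert(info.numCompleteTasks.get() == 1) // reference untouched, value moved
}
```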
@@ -22,13 +22,16 @@ import scala.collection.JavaConverters.asScalaBufferConverter
 import org.apache.hive.service.rpc.thrift.{TExecuteStatementReq, TFetchOrientation, TFetchResultsReq, TOperationHandle}
 import org.scalatest.time.SpanSugar._
 
+import org.apache.kyuubi.config.KyuubiConf
 import org.apache.kyuubi.config.KyuubiConf.OPERATION_SPARK_LISTENER_ENABLED
 import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
 import org.apache.kyuubi.operation.HiveJDBCTestHelper
 
 class SQLOperationListenerSuite extends WithSparkSQLEngine with HiveJDBCTestHelper {
 
-  override def withKyuubiConf: Map[String, String] = Map.empty
+  override def withKyuubiConf: Map[String, String] = Map(
+    KyuubiConf.ENGINE_SPARK_SHOW_PROGRESS.key -> "true",
+    KyuubiConf.ENGINE_SPARK_SHOW_PROGRESS_UPDATE_INTERVAL.key -> "200")
 
   override protected def jdbcUrl: String = getJdbcUrl
 
@@ -54,6 +57,24 @@ class SQLOperationListenerSuite extends WithSparkSQLEngine with HiveJDBCTestHelper {
     }
   }
 
+  test("operation listener with progress job info") {
+    val sql = "SELECT java_method('java.lang.Thread', 'sleep', 10000l) FROM range(1, 3, 1, 2);"
+    withSessionHandle { (client, handle) =>
+      val req = new TExecuteStatementReq()
+      req.setSessionHandle(handle)
+      req.setStatement(sql)
+      val tExecuteStatementResp = client.ExecuteStatement(req)
+      val opHandle = tExecuteStatementResp.getOperationHandle
+      val fetchResultsReq = new TFetchResultsReq(opHandle, TFetchOrientation.FETCH_NEXT, 1000)
+      fetchResultsReq.setFetchType(1.toShort)
+      eventually(timeout(90.seconds), interval(500.milliseconds)) {
+        val resultsResp = client.FetchResults(fetchResultsReq)
+        val logs = resultsResp.getResults.getColumns.get(0).getStringVal.getValues.asScala
+        assert(logs.exists(_.matches(".*\\[Job .* Stages\\] \\[Stage .*\\]")))
+      }
+    }
+  }
+
   test("SQLOperationListener configurable") {
     val sql = "select /*+ REPARTITION(3, a) */ a from values(1) t(a);"
     withSessionHandle { (client, handle) =>
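Note on the new test: the statement sleeps 10 s per task across two partitions so the job stays alive through several progress-bar updates (200 ms interval, per the new withKyuubiConf), fetchType 1 reads the operation log rather than the result set, and eventually polls for up to 90 s until a log line matches the job-progress pattern. A minimal check of that regex against a sample line shaped like the bar's output (the sample is illustrative, not captured from a run):

```scala
object LogPatternSketch extends App {
  val sample = "[Job 0 (1 / 2) Stages] [Stage 1:==============>       (1 + 1) / 2]"
  assert(sample.matches(".*\\[Job .* Stages\\] \\[Stage .*\\]"))
}
```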
