
[SPARK-23032][SQL] Add a per-query codegenStageId to WholeStageCodegenExec #20224

Closed · wants to merge 1 commit · changes from all commits
@@ -629,6 +629,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME =
buildConf("spark.sql.codegen.useIdInClassName")
.internal()
.doc("When true, embed the (whole-stage) codegen stage ID into " +
"the class name of the generated class as a suffix")
.booleanConf
.createWithDefault(true)
Member:
Shall we disable the codegen stage id in both the explain result and the generated class name at the same time? It doesn't seem useful to disable it in the class name but keep it in the explain result.

Contributor:
I think it's always good to have the id in both the explain output and the generated classes. The only concern is that we may hit codegen cache issues if we put the id in the class name, so we need a config to turn it off.

Member:
Makes sense to me.
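For reference, a minimal sketch of how the two flags interact at runtime (assuming an active SparkSession named `spark`; not part of this patch):

    // Keep whole-stage codegen on, but drop the stage-id suffix from generated class names.
    spark.conf.set("spark.sql.codegen.wholeStage", "true")
    spark.conf.set("spark.sql.codegen.useIdInClassName", "false")
    // The *(n) prefix still shows up in EXPLAIN output; only the class-name suffix is disabled.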


val WHOLESTAGE_MAX_NUM_FIELDS = buildConf("spark.sql.codegen.maxFields")
.internal()
.doc("The maximum number of fields (including nested fields) that will be supported before" +
@@ -1264,6 +1272,8 @@ class SQLConf extends Serializable with Logging {

def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED)

def wholeStageUseIdInClassName: Boolean = getConf(WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME)

def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS)

def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK)
@@ -324,7 +324,7 @@ case class FileSourceScanExec(
// in the case of fallback, this batched scan should never fail because of:
// 1) only primitive types are supported
// 2) the number of columns should be smaller than spark.sql.codegen.maxFields
-      WholeStageCodegenExec(this).execute()
+      WholeStageCodegenExec(this)(codegenStageId = 0).execute()
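      // codegenStageId = 0 marks a temporary WholeStageCodegenExec created for this
      // batched-scan fallback; see the reserved-ID note in WholeStageCodegenId, added below.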
} else {
val unsafeRows = {
val scan = inputRDD
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution

import java.util.Locale
import java.util.function.Supplier

import scala.collection.mutable

@@ -414,6 +415,58 @@ object WholeStageCodegenExec {
}
}

object WholeStageCodegenId {
// codegenStageId: ID for codegen stages within a query plan.
// It does not affect equality, nor does it participate in destructuring pattern matching
// of WholeStageCodegenExec.
//
// This ID is used to help differentiate between codegen stages. It is included as a part
// of the explain output for physical plans, e.g.
//
// == Physical Plan ==
// *(5) SortMergeJoin [x#3L], [y#9L], Inner
// :- *(2) Sort [x#3L ASC NULLS FIRST], false, 0
// : +- Exchange hashpartitioning(x#3L, 200)
// : +- *(1) Project [(id#0L % 2) AS x#3L]
// : +- *(1) Filter isnotnull((id#0L % 2))
// : +- *(1) Range (0, 5, step=1, splits=8)
// +- *(4) Sort [y#9L ASC NULLS FIRST], false, 0
// +- Exchange hashpartitioning(y#9L, 200)
// +- *(3) Project [(id#6L % 2) AS y#9L]
// +- *(3) Filter isnotnull((id#6L % 2))
// +- *(3) Range (0, 5, step=1, splits=8)
//
// where the ID makes it obvious that not all adjacent codegen'd plan operators are of the
// same codegen stage.
//
// The codegen stage ID is also optionally included in the name of the generated classes as
// a suffix, so that it's easier to associate a generated class back to the physical operator.
// This is controlled by SQLConf: spark.sql.codegen.useIdInClassName
//
// The ID is also included in various log messages.
//
// Within a query, a codegen stage in a plan starts counting from 1, in "insertion order".
// WholeStageCodegenExec operators are inserted into a plan in depth-first post-order.
// See CollapseCodegenStages.insertWholeStageCodegen for the definition of insertion order.
//
// 0 is reserved as a special ID value to indicate a temporary WholeStageCodegenExec object
// is created, e.g. for special fallback handling when an existing WholeStageCodegenExec
// failed to generate/compile code.

Member:
I think we should describe the usage of the codegen stage id here, e.g. that it shows up in the explain string and in the generated class name.

Contributor (author):
Sure thing. Will address it in the next update. Thanks!

private val codegenStageCounter = ThreadLocal.withInitial(new Supplier[Integer] {
override def get() = 1 // TODO: change to Scala lambda syntax when upgraded to Scala 2.12+
})

def resetPerQuery(): Unit = codegenStageCounter.set(1)

def getNextStageId(): Int = {
val counter = codegenStageCounter
val id = counter.get()
counter.set(id + 1)
id
}
}
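For reference, a small sketch of the counter semantics above (hypothetical driver-side usage; the real callers are CollapseCodegenStages and WholeStageCodegenExec):

    WholeStageCodegenId.resetPerQuery()
    assert(WholeStageCodegenId.getNextStageId() == 1)  // counting starts at 1...
    assert(WholeStageCodegenId.getNextStageId() == 2)  // ...and post-increments per call
    WholeStageCodegenId.resetPerQuery()
    assert(WholeStageCodegenId.getNextStageId() == 1)  // each query starts over from 1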

/**
* WholeStageCodegen compiles a subtree of plans that support codegen together into a single Java
* function.
@@ -442,7 +495,8 @@ object WholeStageCodegenExec {
* `doCodeGen()` will create a `CodeGenContext`, which will hold a list of variables for input,
* used to generate code for [[BoundReference]].
*/
-case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
+case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
+    extends UnaryExecNode with CodegenSupport {

override def output: Seq[Attribute] = child.output

@@ -454,6 +508,12 @@ case class WholeStageCodegenExec
"pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
WholeStageCodegenExec.PIPELINE_DURATION_METRIC))

def generatedClassName(): String = if (conf.wholeStageUseIdInClassName) {
s"GeneratedIteratorForCodegenStage$codegenStageId"
} else {
"GeneratedIterator"
}
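A quick way to observe the effect of this method (a sketch using the existing debug helper; output abbreviated):

    import org.apache.spark.sql.execution.debug._
    spark.range(3).selectExpr("id + 2").debugCodegen()
    // useIdInClassName=true  -> ... final class GeneratedIteratorForCodegenStage1 extends BufferedRowIterator ...
    // useIdInClassName=false -> ... final class GeneratedIterator extends BufferedRowIterator ...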

/**
* Generates code for this subtree.
*
@@ -471,19 +531,23 @@ case class WholeStageCodegenExec
}
""", inlineToOuterClass = true)

val className = generatedClassName()

val source = s"""
public Object generate(Object[] references) {
-        return new GeneratedIterator(references);
+        return new $className(references);
}

-      ${ctx.registerComment(s"""Codegend pipeline for\n${child.treeString.trim}""")}
-      final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
+      ${ctx.registerComment(
+        s"""Codegend pipeline for stage (id=$codegenStageId)
+           |${this.treeString.trim}""".stripMargin)}
+      final class $className extends ${classOf[BufferedRowIterator].getName} {

private Object[] references;
private scala.collection.Iterator[] inputs;
${ctx.declareMutableStates()}

-      public GeneratedIterator(Object[] references) {
+      public $className(Object[] references) {
this.references = references;
}

@@ -516,7 +580,7 @@ case class WholeStageCodegenExec
} catch {
case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback =>
// We should have already seen the error message
-        logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString")
+        logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString")
return child.execute()
}

@@ -525,7 +589,7 @@ case class WholeStageCodegenExec
logInfo(s"Found too long generated codes and JIT optimization might not work: " +
s"the bytecode size ($maxCodeSize) is above the limit " +
s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " +
s"for this plan. To avoid this, you can raise the limit " +
s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " +
s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString")
child match {
// The fallback solution of batch file source scan still uses WholeStageCodegenExec
@@ -603,10 +667,12 @@ case class WholeStageCodegenExec
verbose: Boolean,
prefix: String = "",
addSuffix: Boolean = false): StringBuilder = {
-    child.generateTreeString(depth, lastChildren, builder, verbose, "*")
+    child.generateTreeString(depth, lastChildren, builder, verbose, s"*($codegenStageId) ")
}

override def needStopCheck: Boolean = true

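  // Note: codegenStageId sits in a second (curried) constructor parameter list, so it is not part
  // of case-class equality or pattern matching; listing it in otherCopyArgs lets TreeNode.makeCopy
  // re-apply it when the node is copied during plan transformations.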
override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer])
}


@@ -657,13 +723,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] =>
plan.withNewChildren(plan.children.map(insertWholeStageCodegen))
case plan: CodegenSupport if supportCodegen(plan) =>
-      WholeStageCodegenExec(insertInputAdapter(plan))
+      WholeStageCodegenExec(insertInputAdapter(plan))(WholeStageCodegenId.getNextStageId())
case other =>
other.withNewChildren(other.children.map(insertWholeStageCodegen))
}

def apply(plan: SparkPlan): SparkPlan = {
if (conf.wholeStageEnabled) {
WholeStageCodegenId.resetPerQuery()
insertWholeStageCodegen(plan)
} else {
plan
@@ -274,7 +274,7 @@ case class InMemoryTableScanExec(

protected override def doExecute(): RDD[InternalRow] = {
if (supportsBatch) {
-      WholeStageCodegenExec(this).execute()
+      WholeStageCodegenExec(this)(codegenStageId = 0).execute()
} else {
inputRDD
}
@@ -88,7 +88,7 @@ case class DataSourceV2ScanExec(

override protected def doExecute(): RDD[InternalRow] = {
if (supportsBatch) {
-      WholeStageCodegenExec(this).execute()
+      WholeStageCodegenExec(this)(codegenStageId = 0).execute()
} else {
val numOutputRows = longMetric("numOutputRows")
inputRDD.map { r =>
@@ -230,7 +230,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
.replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/")
.replaceAll("Created By.*", s"Created By $notIncludedMsg")
.replaceAll("Created Time.*", s"Created Time $notIncludedMsg")
.replaceAll("Last Access.*", s"Last Access $notIncludedMsg"))
.replaceAll("Last Access.*", s"Last Access $notIncludedMsg")
.replaceAll("\\*\\(\\d+\\) ", "*")) // remove the WholeStageCodegen codegenStageIds

// If the output is not pre-sorted, sort it.
if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
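A self-contained sketch of what the new replaceAll above does to a single explain line (values illustrative):

    val line = "*(2) Project [(id#0L % 2) AS x#3L]"
    val normalized = line.replaceAll("\\*\\(\\d+\\) ", "*")  // same regex as in the patch
    assert(normalized == "*Project [(id#0L % 2) AS x#3L]")   // stage id stripped, "*" kept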
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution

import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.{QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -273,4 +274,37 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
}
}
}

test("codegen stage IDs should be preserved in transformations after CollapseCodegenStages") {
// test case adapted from DataFrameSuite to trigger ReuseExchange
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2") {
val df = spark.range(100)
val join = df.join(df, "id")
val plan = join.queryExecution.executedPlan
assert(!plan.find(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isDefined,
"codegen stage IDs should be preserved through ReuseExchange")
checkAnswer(join, df.toDF)
}
}

test("including codegen stage ID in generated class name should not regress codegen caching") {
import testImplicits._

withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") {
val bytecodeSizeHisto = CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE

// the same query run twice should hit the codegen cache
spark.range(3).select('id + 2).collect
val after1 = bytecodeSizeHisto.getCount
spark.range(3).select('id + 2).collect
val after2 = bytecodeSizeHisto.getCount // same query shape as above, deliberately
// bytecodeSizeHisto's count is always monotonically increasing if new compilation to
// bytecode had occurred. If the count stayed the same that means we've got a cache hit.
assert(after1 == after2, "Should hit codegen cache. No new compilation to bytecode expected")

// a different query can result in codegen cache miss, that's by design
}
}
}
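Related spark-shell sketch (hypothetical; exact ids depend on the plan shape):

    import org.apache.spark.sql.execution.WholeStageCodegenExec
    // Collect the codegen stage ids present in an executed plan; id 0 would indicate a temporary
    // fallback WholeStageCodegenExec, which should never survive into the final plan.
    val plan = spark.range(100).join(spark.range(100), "id").queryExecution.executedPlan
    val ids = plan.collect { case w: WholeStageCodegenExec => w.codegenStageId }
    assert(!ids.contains(0))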
@@ -477,7 +477,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec])

val execPlan = if (enabled == "true") {
-      WholeStageCodegenExec(planBeforeFilter.head)
+      WholeStageCodegenExec(planBeforeFilter.head)(codegenStageId = 0)
} else {
planBeforeFilter.head
}
@@ -154,14 +154,39 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleton
}
}

test("EXPLAIN CODEGEN command") {
checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"),
"WholeStageCodegen",
"Generated code:",
"/* 001 */ public Object generate(Object[] references) {",
"/* 002 */ return new GeneratedIterator(references);",
"/* 003 */ }"
test("explain output of physical plan should contain proper codegen stage ID") {
checkKeywordsExist(sql(
"""
|EXPLAIN SELECT t1.id AS a, t2.id AS b FROM
|(SELECT * FROM range(3)) t1 JOIN
|(SELECT * FROM range(10)) t2 ON t1.id == t2.id % 3
""".stripMargin),
"== Physical Plan ==",
"*(2) Project ",
"+- *(2) BroadcastHashJoin ",
" :- BroadcastExchange ",
" : +- *(1) Range ",
" +- *(2) Range "
)
}

test("EXPLAIN CODEGEN command") {
// the generated class name in this test should stay in sync with
// org.apache.spark.sql.execution.WholeStageCodegenExec.generatedClassName()
for ((useIdInClassName, expectedClassName) <- Seq(
("true", "GeneratedIteratorForCodegenStage1"),
("false", "GeneratedIterator"))) {
withSQLConf(
SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> useIdInClassName) {
checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"),
"WholeStageCodegen",
"Generated code:",
"/* 001 */ public Object generate(Object[] references) {",
s"/* 002 */ return new $expectedClassName(references);",
"/* 003 */ }"
)
}
}

checkKeywordsNotExist(sql("EXPLAIN CODEGEN SELECT 1"),
"== Physical Plan =="