Commit fd8983e

Add a per-query codegenStageId to WholeStageCodegenExec
Also added a new test case to HiveExplainSuite to make sure the codegen stage ID is indeed included in the explain output of the physical plan, and another new test case to WholeStageCodegenSuite to make sure that, with the ID included in the generated class name, the generated code can still hit the codegen cache for the same query.
1 parent 8532e26 commit fd8983e

File tree

9 files changed, +158 -21 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
@@ -629,6 +629,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME =
+    buildConf("spark.sql.codegen.useIdInClassName")
+      .internal()
+      .doc("When true, embed the (whole-stage) codegen stage ID into " +
+        "the class name of the generated class as a suffix")
+      .booleanConf
+      .createWithDefault(true)
+
   val WHOLESTAGE_MAX_NUM_FIELDS = buildConf("spark.sql.codegen.maxFields")
     .internal()
     .doc("The maximum number of fields (including nested fields) that will be supported before" +
@@ -1264,6 +1272,8 @@ class SQLConf extends Serializable with Logging {
 
   def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED)
 
+  def wholeStageUseIdInClassName: Boolean = getConf(WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME)
+
   def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS)
 
   def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK)
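
Since the new flag is registered through buildConf, it can be toggled like any other SQL conf. The key spark.sql.codegen.useIdInClassName comes straight from the diff above; the spark session in this sketch is an assumed spark-shell session, not part of the change:

    // Hedged sketch: flipping the new internal conf at runtime (assumes an active `spark` session).
    spark.conf.set("spark.sql.codegen.useIdInClassName", "false") // plain "GeneratedIterator" class name
    spark.conf.set("spark.sql.codegen.useIdInClassName", "true")  // default: stage ID appended to the class name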

sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

Lines changed: 1 addition & 1 deletion
@@ -324,7 +324,7 @@ case class FileSourceScanExec(
       // in the case of fallback, this batched scan should never fail because of:
       // 1) only primitive types are supported
       // 2) the number of columns should be smaller than spark.sql.codegen.maxFields
-      WholeStageCodegenExec(this).execute()
+      WholeStageCodegenExec(this)(codegenStageId = 0).execute()
     } else {
       val unsafeRows = {
         val scan = inputRDD

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 76 additions & 9 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution
 
 import java.util.Locale
+import java.util.function.Supplier
 
 import scala.collection.mutable
 
@@ -414,6 +415,58 @@ object WholeStageCodegenExec {
   }
 }
 
+object WholeStageCodegenId {
+  // codegenStageId: ID for codegen stages within a query plan.
+  // It does not affect equality, nor does it participate in destructuring pattern matching
+  // of WholeStageCodegenExec.
+  //
+  // This ID is used to help differentiate between codegen stages. It is included as a part
+  // of the explain output for physical plans, e.g.
+  //
+  // == Physical Plan ==
+  // *(5) SortMergeJoin [x#3L], [y#9L], Inner
+  // :- *(2) Sort [x#3L ASC NULLS FIRST], false, 0
+  // :  +- Exchange hashpartitioning(x#3L, 200)
+  // :     +- *(1) Project [(id#0L % 2) AS x#3L]
+  // :        +- *(1) Filter isnotnull((id#0L % 2))
+  // :           +- *(1) Range (0, 5, step=1, splits=8)
+  // +- *(4) Sort [y#9L ASC NULLS FIRST], false, 0
+  //    +- Exchange hashpartitioning(y#9L, 200)
+  //       +- *(3) Project [(id#6L % 2) AS y#9L]
+  //          +- *(3) Filter isnotnull((id#6L % 2))
+  //             +- *(3) Range (0, 5, step=1, splits=8)
+  //
+  // where the ID makes it obvious that not all adjacent codegen'd plan operators are of the
+  // same codegen stage.
+  //
+  // The codegen stage ID is also optionally included in the name of the generated classes as
+  // a suffix, so that it's easier to associate a generated class back to the physical operator.
+  // This is controlled by SQLConf: spark.sql.codegen.useIdInClassName
+  //
+  // The ID is also included in various log messages.
+  //
+  // Within a query, a codegen stage in a plan starts counting from 1, in "insertion order".
+  // WholeStageCodegenExec operators are inserted into a plan in depth-first post-order.
+  // See CollapseCodegenStages.insertWholeStageCodegen for the definition of insertion order.
+  //
+  // 0 is reserved as a special ID value to indicate a temporary WholeStageCodegenExec object
+  // is created, e.g. for special fallback handling when an existing WholeStageCodegenExec
+  // failed to generate/compile code.
+
+  private val codegenStageCounter = ThreadLocal.withInitial(new Supplier[Integer] {
+    override def get() = 1  // TODO: change to Scala lambda syntax when upgraded to Scala 2.12+
+  })
+
+  def resetPerQuery(): Unit = codegenStageCounter.set(1)
+
+  def getNextStageId(): Int = {
+    val counter = codegenStageCounter
+    val id = counter.get()
+    counter.set(id + 1)
+    id
+  }
+}
+
 /**
  * WholeStageCodegen compiles a subtree of plans that support codegen together into single Java
  * function.
@@ -442,7 +495,8 @@ object WholeStageCodegenExec {
  * `doCodeGen()` will create a `CodeGenContext`, which will hold a list of variables for input,
  * used to generated code for [[BoundReference]].
  */
-case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
+case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
+    extends UnaryExecNode with CodegenSupport {
 
   override def output: Seq[Attribute] = child.output
 
@@ -454,6 +508,12 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
     "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext,
       WholeStageCodegenExec.PIPELINE_DURATION_METRIC))
 
+  def generatedClassName(): String = if (conf.wholeStageUseIdInClassName) {
+    s"GeneratedIteratorForCodegenStage$codegenStageId"
+  } else {
+    "GeneratedIterator"
+  }
+
   /**
    * Generates code for this subtree.
    *
@@ -471,19 +531,23 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
       }
     """, inlineToOuterClass = true)
 
+    val className = generatedClassName()
+
     val source = s"""
       public Object generate(Object[] references) {
-        return new GeneratedIterator(references);
+        return new $className(references);
       }
 
-      ${ctx.registerComment(s"""Codegend pipeline for\n${child.treeString.trim}""")}
-      final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
+      ${ctx.registerComment(
+        s"""Codegend pipeline for stage (id=$codegenStageId)
+           |${this.treeString.trim}""".stripMargin)}
+      final class $className extends ${classOf[BufferedRowIterator].getName} {
 
         private Object[] references;
         private scala.collection.Iterator[] inputs;
         ${ctx.declareMutableStates()}
 
-        public GeneratedIterator(Object[] references) {
+        public $className(Object[] references) {
           this.references = references;
         }
 
@@ -516,7 +580,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
     } catch {
       case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback =>
         // We should already saw the error message
-        logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString")
+        logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString")
         return child.execute()
     }
 
@@ -525,7 +589,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
       logInfo(s"Found too long generated codes and JIT optimization might not work: " +
         s"the bytecode size ($maxCodeSize) is above the limit " +
         s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " +
-        s"for this plan. To avoid this, you can raise the limit " +
+        s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " +
         s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString")
       child match {
         // The fallback solution of batch file source scan still uses WholeStageCodegenExec
@@ -603,10 +667,12 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
       verbose: Boolean,
       prefix: String = "",
       addSuffix: Boolean = false): StringBuilder = {
-    child.generateTreeString(depth, lastChildren, builder, verbose, "*")
+    child.generateTreeString(depth, lastChildren, builder, verbose, s"*($codegenStageId) ")
   }
 
   override def needStopCheck: Boolean = true
+
+  override protected def otherCopyArgs: Seq[AnyRef] = Seq(codegenStageId.asInstanceOf[Integer])
 }
 
 
@@ -657,13 +723,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
     case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] =>
       plan.withNewChildren(plan.children.map(insertWholeStageCodegen))
     case plan: CodegenSupport if supportCodegen(plan) =>
-      WholeStageCodegenExec(insertInputAdapter(plan))
+      WholeStageCodegenExec(insertInputAdapter(plan))(WholeStageCodegenId.getNextStageId())
     case other =>
       other.withNewChildren(other.children.map(insertWholeStageCodegen))
   }
 
   def apply(plan: SparkPlan): SparkPlan = {
     if (conf.wholeStageEnabled) {
+      WholeStageCodegenId.resetPerQuery()
       insertWholeStageCodegen(plan)
     } else {
       plan
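
From the user's side, the effect of this change shows up in two places: the explain prefix changes from a bare * to *(n), and the operator itself now carries the ID. A rough spark-shell sketch (the spark session and the sample query are illustrative assumptions, not part of this commit):

    // Hedged sketch: reading the per-stage IDs off an executed plan; assumes an active `spark` session.
    import org.apache.spark.sql.execution.WholeStageCodegenExec

    val df = spark.range(10).selectExpr("id % 2 AS x").groupBy("x").count()
    df.explain()  // codegen'd operators now print with a "*(1)", "*(2)", ... prefix instead of a bare "*"

    val stageIds = df.queryExecution.executedPlan.collect {
      case w: WholeStageCodegenExec => w.codegenStageId
    }
    // stageIds holds the codegen stage numbers of the plan, counted from 1 in insertion order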

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala

Lines changed: 1 addition & 1 deletion
@@ -274,7 +274,7 @@ case class InMemoryTableScanExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     if (supportsBatch) {
-      WholeStageCodegenExec(this).execute()
+      WholeStageCodegenExec(this)(codegenStageId = 0).execute()
     } else {
       inputRDD
     }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ case class DataSourceV2ScanExec(
 
   override protected def doExecute(): RDD[InternalRow] = {
     if (supportsBatch) {
-      WholeStageCodegenExec(this).execute()
+      WholeStageCodegenExec(this)(codegenStageId = 0).execute()
     } else {
       val numOutputRows = longMetric("numOutputRows")
       inputRDD.map { r =>

sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala

Lines changed: 2 additions & 1 deletion
@@ -230,7 +230,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
         .replaceAll("Location.*/sql/core/", s"Location ${notIncludedMsg}sql/core/")
         .replaceAll("Created By.*", s"Created By $notIncludedMsg")
         .replaceAll("Created Time.*", s"Created Time $notIncludedMsg")
-        .replaceAll("Last Access.*", s"Last Access $notIncludedMsg"))
+        .replaceAll("Last Access.*", s"Last Access $notIncludedMsg")
+        .replaceAll("\\*\\(\\d+\\) ", "*"))  // remove the WholeStageCodegen codegenStageIds
 
       // If the output is not pre-sorted, sort it.
       if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
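
The golden result files for SQLQueryTestSuite were recorded before stage IDs existed, so the added replaceAll normalizes the new *(n) prefix back to the old bare * before comparing. A small illustration of what that regex does (the sample string is made up):

    // Hedged sketch: the same normalization applied to a single explain line.
    "*(2) Project [id#0L]".replaceAll("\\*\\(\\d+\\) ", "*")  // => "*Project [id#0L]"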

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 34 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.metrics.source.CodegenMetrics
 import org.apache.spark.sql.{QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -273,4 +274,37 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("codegen stage IDs should be preserved in transformations after CollapseCodegenStages") {
+    // test case adapted from DataFrameSuite to trigger ReuseExchange
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2") {
+      val df = spark.range(100)
+      val join = df.join(df, "id")
+      val plan = join.queryExecution.executedPlan
+      assert(!plan.find(p =>
+        p.isInstanceOf[WholeStageCodegenExec] &&
+          p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isDefined,
+        "codegen stage IDs should be preserved through ReuseExchange")
+      checkAnswer(join, df.toDF)
+    }
+  }
+
+  test("including codegen stage ID in generated class name should not regress codegen caching") {
+    import testImplicits._
+
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") {
+      val bytecodeSizeHisto = CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE
+
+      // the same query run twice should hit the codegen cache
+      spark.range(3).select('id + 2).collect
+      val after1 = bytecodeSizeHisto.getCount
+      spark.range(3).select('id + 2).collect
+      val after2 = bytecodeSizeHisto.getCount  // same query shape as above, deliberately
+      // bytecodeSizeHisto's count is always monotonically increasing if new compilation to
+      // bytecode had occurred. If the count stayed the same that means we've got a cache hit.
+      assert(after1 == after2, "Should hit codegen cache. No new compilation to bytecode expected")
+
+      // a different query can result in codegen cache miss, that's by design
+    }
+  }
 }
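
The ReuseExchange test above relies on the curried codegenStageId surviving plan transformations. Catalyst re-creates nodes through makeCopy from their primary constructor arguments, so the second argument list would otherwise be lost; the otherCopyArgs override added in this commit is what feeds it back in. A rough, purely illustrative spark-shell sketch (LocalTableScanExec is used only as a trivial child here and is not part of this change):

    // Hedged sketch: otherCopyArgs lets makeCopy re-supply the curried stage ID.
    import org.apache.spark.sql.execution.{LocalTableScanExec, WholeStageCodegenExec}

    val child = LocalTableScanExec(Nil, Nil)  // trivial leaf plan, only needed to construct the node
    val wsc = WholeStageCodegenExec(child)(codegenStageId = 3)
    val copied = wsc.makeCopy(Array[AnyRef](child)).asInstanceOf[WholeStageCodegenExec]
    assert(copied.codegenStageId == 3)  // preserved because otherCopyArgs supplies the extra argument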

sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala

Lines changed: 1 addition & 1 deletion
@@ -477,7 +477,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
     assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec])
 
     val execPlan = if (enabled == "true") {
-      WholeStageCodegenExec(planBeforeFilter.head)
+      WholeStageCodegenExec(planBeforeFilter.head)(codegenStageId = 0)
     } else {
       planBeforeFilter.head
     }

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala

Lines changed: 32 additions & 7 deletions
@@ -154,14 +154,39 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
     }
   }
 
-  test("EXPLAIN CODEGEN command") {
-    checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"),
-      "WholeStageCodegen",
-      "Generated code:",
-      "/* 001 */ public Object generate(Object[] references) {",
-      "/* 002 */ return new GeneratedIterator(references);",
-      "/* 003 */ }"
+  test("explain output of physical plan should contain proper codegen stage ID") {
+    checkKeywordsExist(sql(
+      """
+        |EXPLAIN SELECT t1.id AS a, t2.id AS b FROM
+        |(SELECT * FROM range(3)) t1 JOIN
+        |(SELECT * FROM range(10)) t2 ON t1.id == t2.id % 3
+      """.stripMargin),
+      "== Physical Plan ==",
+      "*(2) Project ",
+      "+- *(2) BroadcastHashJoin ",
+      "   :- BroadcastExchange ",
+      "   :  +- *(1) Range ",
+      "   +- *(2) Range "
     )
+  }
+
+  test("EXPLAIN CODEGEN command") {
+    // the generated class name in this test should stay in sync with
+    // org.apache.spark.sql.execution.WholeStageCodegenExec.generatedClassName()
+    for ((useIdInClassName, expectedClassName) <- Seq(
+         ("true", "GeneratedIteratorForCodegenStage1"),
+         ("false", "GeneratedIterator"))) {
+      withSQLConf(
+          SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> useIdInClassName) {
+        checkKeywordsExist(sql("EXPLAIN CODEGEN SELECT 1"),
+          "WholeStageCodegen",
+          "Generated code:",
+          "/* 001 */ public Object generate(Object[] references) {",
+          s"/* 002 */ return new $expectedClassName(references);",
+          "/* 003 */ }"
+        )
+      }
+    }
 
     checkKeywordsNotExist(sql("EXPLAIN CODEGEN SELECT 1"),
       "== Physical Plan =="