Commit c79e771

viirya authored and cloud-fan committed
[SPARK-21717][SQL] Decouple consume functions of physical operators in whole-stage codegen
## What changes were proposed in this pull request?

It has been observed in SPARK-21603 that whole-stage codegen suffers performance degradation if the generated functions are too long to be optimized by the JIT. We currently produce a single function that incorporates the generated code of all physical operators in the whole stage, so the generated function can grow past the size threshold beyond which the JIT no longer optimizes it. This patch decouples the row-consuming logic of each physical operator into its own function, avoiding one giant row-processing function.

## How was this patch tested?

Added tests.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #18931 from viirya/SPARK-21717.

(cherry picked from commit d20bbc2)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent e66c66c commit c79e771
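To make the shape of the change concrete before diving into the diff, here is a minimal sketch of the generated-code structure before and after the patch. All operator and variable names below are hypothetical stand-ins, not Spark's actual generated output:

```scala
// A minimal sketch of the generated-code shape before and after this patch.
// Identifiers are hypothetical stand-ins, not Spark's real output.
object SplitConsumeShape extends App {
  // Before: one giant method accumulates every operator's consume logic. A method
  // over HotSpot's HugeMethodLimit (8000 bytes of bytecode) is never JIT-compiled.
  val before =
    """protected void processNext() throws java.io.IOException {
      |  // scan logic ... filter logic ... project logic ... all inlined here
      |}""".stripMargin

  // After: each operator's consume logic can be emitted as its own small method.
  val after =
    """private void project_doConsume(long expr_0, boolean exprIsNull_0)
      |    throws java.io.IOException {
      |  // only the project operator's consume logic
      |}
      |
      |protected void processNext() throws java.io.IOException {
      |  // produce loop ...
      |  project_doConsume(value_0, isNull_0);
      |}""".stripMargin

  println(before)
  println(after)
}
```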

File tree

- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
- sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
- sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
- sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

4 files changed: +203 −29 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 33 additions & 5 deletions
```diff
@@ -1245,6 +1245,31 @@ class CodegenContext {
       ""
     }
   }
+
+  /**
+   * Returns the length of parameters for a Java method descriptor. `this` contributes one unit
+   * and a parameter of type long or double contributes two units. Besides, for a nullable
+   * parameter, we also need to pass a boolean parameter for the null status.
+   */
+  def calculateParamLength(params: Seq[Expression]): Int = {
+    def paramLengthForExpr(input: Expression): Int = {
+      // For a nullable expression, we need to pass in an extra boolean parameter.
+      (if (input.nullable) 1 else 0) + javaType(input.dataType) match {
+        case JAVA_LONG | JAVA_DOUBLE => 2
+        case _ => 1
+      }
+    }
+    // Initial value is 1 for `this`.
+    1 + params.map(paramLengthForExpr(_)).sum
+  }
+
+  /**
+   * In Java, a method descriptor is valid only if it represents method parameters with a total
+   * length no greater than a pre-defined constant.
+   */
+  def isValidParamLength(paramLength: Int): Boolean = {
+    paramLength <= CodeGenerator.MAX_JVM_METHOD_PARAMS_LENGTH
+  }
 }

 /**
```
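As a concrete illustration of the counting rule described in the doc comment above, consider a hypothetical output schema of (nullable long, non-nullable int, nullable double). The sketch below is standalone arithmetic, not Spark code:

```scala
// Worked example of the documented descriptor-length rule (hypothetical schema).
object ParamLengthExample extends App {
  val thisUnit  = 1     // `this` always occupies one slot
  val longCol   = 2 + 1 // nullable long: 2 units for the value, 1 for its isNull boolean
  val intCol    = 1     // non-nullable int: value only, no boolean
  val doubleCol = 2 + 1 // nullable double: 2 units plus isNull boolean
  val total = thisUnit + longCol + intCol + doubleCol
  assert(total == 8)    // valid, since 8 <= 255 (MAX_JVM_METHOD_PARAMS_LENGTH)
  println(s"descriptor length = $total")
}
```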
```diff
@@ -1311,26 +1336,29 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging
 object CodeGenerator extends Logging {

   // This is the value of HugeMethodLimit in the OpenJDK JVM settings
-  val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000
+  final val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000
+
+  // The max valid length of method parameters in JVM.
+  final val MAX_JVM_METHOD_PARAMS_LENGTH = 255

   // This is the threshold over which the methods in an inner class are grouped in a single
   // method which is going to be called by the outer class instead of the many small ones
-  val MERGE_SPLIT_METHODS_THRESHOLD = 3
+  final val MERGE_SPLIT_METHODS_THRESHOLD = 3

   // The number of named constants that can exist in the class is limited by the Constant Pool
   // limit, 65,536. We cannot know how many constants will be inserted for a class, so we use a
   // threshold of 1000k bytes to determine when a function should be inlined to a private, inner
   // class.
-  val GENERATED_CLASS_SIZE_THRESHOLD = 1000000
+  final val GENERATED_CLASS_SIZE_THRESHOLD = 1000000

   // This is the threshold for the number of global variables, whose types are primitive type or
   // complex type (e.g. more than one-dimensional array), that will be placed at the outer class
-  val OUTER_CLASS_VARIABLES_THRESHOLD = 10000
+  final val OUTER_CLASS_VARIABLES_THRESHOLD = 10000

   // This is the maximum number of array elements to keep global variables in one Java array
   // 32767 is the maximum integer value that does not require a constant pool entry in a Java
   // bytecode instruction
-  val MUTABLESTATEARRAY_SIZE_LIMIT = 32768
+  final val MUTABLESTATEARRAY_SIZE_LIMIT = 32768

   /**
    * Compile the Java source code into a Java class, using Janino.
```
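These constants are consumed by callers that decide whether generated code is worth compiling. A minimal sketch, assuming only the `CodeGenerator.compile` signature exercised in the test suite below (it returns the generated class together with the maximum bytecode size among the compiled methods); the helper name and its `CodeAndComment` input are illustrative:

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}

// Hedged sketch: gate codegen on the JIT threshold (helper name is hypothetical).
def isJitFriendly(source: CodeAndComment): Boolean = {
  val (_, maxCodeSize) = CodeGenerator.compile(source)
  // HotSpot skips JIT-compiling methods larger than HugeMethodLimit (8000 bytes).
  maxCodeSize <= CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT
}
```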

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 12 additions & 0 deletions
```diff
@@ -661,6 +661,15 @@ object SQLConf {
     .intConf
     .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT)

+  val WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR =
+    buildConf("spark.sql.codegen.splitConsumeFuncByOperator")
+      .internal()
+      .doc("When true, whole stage codegen would put the logic of consuming rows of each " +
+        "physical operator into individual methods, instead of a single big method. This can " +
+        "be used to avoid oversized functions that miss the opportunity of JIT optimization.")
+      .booleanConf
+      .createWithDefault(true)
+
   val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes")
     .doc("The maximum number of bytes to pack into a single partition when reading files.")
     .longConf
```
```diff
@@ -1263,6 +1272,9 @@ class SQLConf extends Serializable with Logging {

   def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT)

+  def wholeStageSplitConsumeFuncByOperator: Boolean =
+    getConf(WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR)
+
   def tableRelationCacheSize: Int =
     getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE)

```
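The new flag can be toggled like any other SQL conf. A sketch assuming a live `SparkSession` named `spark` (the property is internal, so this is mainly useful for tests and debugging):

```scala
// Disable per-operator consume functions for the current session.
spark.conf.set("spark.sql.codegen.splitConsumeFuncByOperator", "false")

// Or scoped inside a test that mixes in SQLTestUtils, as the suite below does:
withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") {
  // plans compiled here may emit a separate <operator>_doConsume method
}
```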

sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala

Lines changed: 112 additions & 23 deletions
```diff
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution

 import java.util.Locale

+import scala.collection.mutable
+
 import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
```
```diff
@@ -106,6 +108,31 @@ trait CodegenSupport extends SparkPlan {
    */
   protected def doProduce(ctx: CodegenContext): String

+  private def prepareRowVar(ctx: CodegenContext, row: String, colVars: Seq[ExprCode]): ExprCode = {
+    if (row != null) {
+      ExprCode("", "false", row)
+    } else {
+      if (colVars.nonEmpty) {
+        val colExprs = output.zipWithIndex.map { case (attr, i) =>
+          BoundReference(i, attr.dataType, attr.nullable)
+        }
+        val evaluateInputs = evaluateVariables(colVars)
+        // generate the code to create a UnsafeRow
+        ctx.INPUT_ROW = row
+        ctx.currentVars = colVars
+        val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
+        val code = s"""
+          |$evaluateInputs
+          |${ev.code.trim}
+         """.stripMargin.trim
+        ExprCode(code, "false", ev.value)
+      } else {
+        // There are no columns
+        ExprCode("", "false", "unsafeRow")
+      }
+    }
+  }
+
   /**
    * Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`.
    *
```
```diff
@@ -126,28 +153,7 @@ trait CodegenSupport extends SparkPlan {
       }
     }

-    val rowVar = if (row != null) {
-      ExprCode("", "false", row)
-    } else {
-      if (outputVars.nonEmpty) {
-        val colExprs = output.zipWithIndex.map { case (attr, i) =>
-          BoundReference(i, attr.dataType, attr.nullable)
-        }
-        val evaluateInputs = evaluateVariables(outputVars)
-        // generate the code to create a UnsafeRow
-        ctx.INPUT_ROW = row
-        ctx.currentVars = outputVars
-        val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
-        val code = s"""
-          |$evaluateInputs
-          |${ev.code.trim}
-         """.stripMargin.trim
-        ExprCode(code, "false", ev.value)
-      } else {
-        // There is no columns
-        ExprCode("", "false", "unsafeRow")
-      }
-    }
+    val rowVar = prepareRowVar(ctx, row, outputVars)

     // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
     // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
```
```diff
@@ -156,13 +162,96 @@ trait CodegenSupport extends SparkPlan {
     ctx.INPUT_ROW = null
     ctx.freshNamePrefix = parent.variablePrefix
     val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
+
+    // Under certain conditions, we can put the logic of consuming the rows of this operator into
+    // another function, so that the generated function does not grow too long to be optimized by
+    // the JIT. The conditions:
+    // 1. The config "spark.sql.codegen.splitConsumeFuncByOperator" is enabled.
+    // 2. `inputVars` are all materialized. That is guaranteed to be true if the parent plan uses
+    //    all variables in output (see `requireAllOutput`).
+    // 3. The number of output variables must be less than the maximum number of parameters in a
+    //    Java method declaration.
+    val confEnabled = SQLConf.get.wholeStageSplitConsumeFuncByOperator
+    val requireAllOutput = output.forall(parent.usedInputs.contains(_))
+    val paramLength = ctx.calculateParamLength(output) + (if (row != null) 1 else 0)
+    val consumeFunc = if (confEnabled && requireAllOutput && ctx.isValidParamLength(paramLength)) {
+      constructDoConsumeFunction(ctx, inputVars, row)
+    } else {
+      parent.doConsume(ctx, inputVars, rowVar)
+    }
     s"""
       |${ctx.registerComment(s"CONSUME: ${parent.simpleString}")}
       |$evaluated
-      |${parent.doConsume(ctx, inputVars, rowVar)}
+      |$consumeFunc
     """.stripMargin
   }

+  /**
+   * To prevent the concatenated function from growing too long to be optimized by the JIT, we
+   * can separate the parent's `doConsume` code of a `CodegenSupport` operator into a function
+   * to call.
+   */
+  private def constructDoConsumeFunction(
+      ctx: CodegenContext,
+      inputVars: Seq[ExprCode],
+      row: String): String = {
+    val (args, params, inputVarsInFunc) = constructConsumeParameters(ctx, output, inputVars, row)
+    val rowVar = prepareRowVar(ctx, row, inputVarsInFunc)
+
+    val doConsume = ctx.freshName("doConsume")
+    ctx.currentVars = inputVarsInFunc
+    ctx.INPUT_ROW = null
+
+    val doConsumeFuncName = ctx.addNewFunction(doConsume,
+      s"""
+         | private void $doConsume(${params.mkString(", ")}) throws java.io.IOException {
+         |   ${parent.doConsume(ctx, inputVarsInFunc, rowVar)}
+         | }
+       """.stripMargin)
+
+    s"""
+       | $doConsumeFuncName(${args.mkString(", ")});
+     """.stripMargin
+  }
+
+  /**
+   * Returns the arguments for calling the consume function, the parameters of its method
+   * definition, and the list of `ExprCode` for those parameters.
+   */
+  private def constructConsumeParameters(
+      ctx: CodegenContext,
+      attributes: Seq[Attribute],
+      variables: Seq[ExprCode],
+      row: String): (Seq[String], Seq[String], Seq[ExprCode]) = {
+    val arguments = mutable.ArrayBuffer[String]()
+    val parameters = mutable.ArrayBuffer[String]()
+    val paramVars = mutable.ArrayBuffer[ExprCode]()
+
+    if (row != null) {
+      arguments += row
+      parameters += s"InternalRow $row"
+    }
+
+    variables.zipWithIndex.foreach { case (ev, i) =>
+      val paramName = ctx.freshName(s"expr_$i")
+      val paramType = ctx.javaType(attributes(i).dataType)
+
+      arguments += ev.value
+      parameters += s"$paramType $paramName"
+      val paramIsNull = if (!attributes(i).nullable) {
+        // Use constant `false` without passing `isNull` for a non-nullable variable.
+        "false"
+      } else {
+        val isNull = ctx.freshName(s"exprIsNull_$i")
+        arguments += ev.isNull
+        parameters += s"boolean $isNull"
+        isNull
+      }
+
+      paramVars += ExprCode("", paramIsNull, paramName)
+    }
+    (arguments, parameters, paramVars)
+  }
+
   /**
    * Returns source code to evaluate all the variables, and clear the code of them, to prevent
    * them to be evaluated twice.
```
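To see what `constructConsumeParameters` computes, here is a standalone mirror of its parameter-building rule for a hypothetical two-column schema; the names and the simplified `Col` type are stand-ins, not Spark's `ExprCode` machinery:

```scala
// Standalone mirror of the rule above: one value parameter per column, plus a
// boolean null-flag parameter for nullable columns only (all inputs hypothetical).
object ConsumeParamsSketch extends App {
  final case class Col(javaType: String, nullable: Boolean, value: String, isNull: String)
  val cols = Seq(
    Col("long", nullable = true,  value = "range_value",  isNull = "range_isNull"),
    Col("int",  nullable = false, value = "project_expr", isNull = "false"))

  val args   = scala.collection.mutable.ArrayBuffer[String]()
  val params = scala.collection.mutable.ArrayBuffer[String]()
  cols.zipWithIndex.foreach { case (c, i) =>
    args   += c.value
    params += s"${c.javaType} expr_$i"
    if (c.nullable) { // nullable columns also pass their null flag
      args   += c.isNull
      params += s"boolean exprIsNull_$i"
    }
  }
  println(params.mkString(", ")) // long expr_0, boolean exprIsNull_0, int expr_1
  println(args.mkString(", "))   // range_value, range_isNull, project_expr
}
```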

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 46 additions & 1 deletion
```diff
@@ -205,7 +205,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
     val codeWithShortFunctions = genGroupByCode(3)
     val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions)
     assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get)
-    val codeWithLongFunctions = genGroupByCode(20)
+    val codeWithLongFunctions = genGroupByCode(50)
     val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions)
     assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get)
   }
```
```diff
@@ -228,4 +228,49 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("Control splitting consume function by operators with config") {
+    import testImplicits._
+    val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*)
+
+    Seq(true, false).foreach { config =>
+      withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> s"$config") {
+        val plan = df.queryExecution.executedPlan
+        val wholeStageCodeGenExec = plan.find(p => p match {
+          case wp: WholeStageCodegenExec => true
+          case _ => false
+        })
+        assert(wholeStageCodeGenExec.isDefined)
+        val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
+        assert(code.body.contains("project_doConsume") == config)
+      }
+    }
+  }
+
+  test("Skip splitting consume function when parameter number exceeds JVM limit") {
+    import testImplicits._
+
+    Seq((255, false), (254, true)).foreach { case (columnNum, hasSplit) =>
+      withTempPath { dir =>
+        val path = dir.getCanonicalPath
+        spark.range(10).select(Seq.tabulate(columnNum) {i => ('id + i).as(s"c$i")} : _*)
+          .write.mode(SaveMode.Overwrite).parquet(path)
+
+        withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255",
+          SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> "true") {
+          val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i")
+          val df = spark.read.parquet(path).selectExpr(projection: _*)
+
+          val plan = df.queryExecution.executedPlan
+          val wholeStageCodeGenExec = plan.find(p => p match {
+            case wp: WholeStageCodegenExec => true
+            case _ => false
+          })
+          assert(wholeStageCodeGenExec.isDefined)
+          val code = wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
+          assert(code.body.contains("project_doConsume") == hasSplit)
+        }
+      }
+    }
+  }
 }
```
