# [SPARK-26205][SQL] Optimize InSet Expression for bytes, shorts, ints, dates
## What changes were proposed in this pull request?

This PR optimizes `InSet` expressions for byte, short, integer, and date types. It is a follow-up to PR #21442 from dbtsai.

`In` expressions are compiled into a sequence of if-else statements, which results in O(n) time complexity. `InSet` is an optimized version of `In` that is supposed to improve performance when all values are literals and the number of elements is big enough. However, `InSet` actually worsens performance in many cases due to various reasons.

The main idea of this PR is to use Java `switch` statements to significantly improve the performance of `InSet` expressions for bytes, shorts, ints, and dates. All `switch` statements are compiled into `tableswitch` and `lookupswitch` bytecode instructions. We get O(1) time complexity if the case values are compact and `tableswitch` can be used; otherwise, `lookupswitch` gives us O(log n). See [here](https://docs.oracle.com/javase/specs/jvms/se7/html/jvms-3.html#jvms-3.10) and [here](https://stackoverflow.com/questions/10287700/difference-between-jvms-lookupswitch-and-tableswitch) for more information.

Locally, I tried Spark's `OpenHashSet` and primitive collections from `fastutils` in order to solve the boxing issue in `InSet`. Both options significantly decreased memory consumption, and `fastutils` improved the time compared to Scala's `HashSet`. However, the switch-based approach was still more than two times faster, even on 500+ non-compact elements. I also noticed that applying the switch-based approach to fewer than 10 elements gives a relatively minor improvement over the if-else approach. Therefore, I placed the switch-based logic into `InSet` and added a new config to control when it is applied. Even if we migrate to primitive collections at some point, the switch logic will still be faster unless the number of elements is really big. Another option is to have a separate `InSwitch` expression; however, this would mean modifying other places (e.g., `DataSourceStrategy`).

This PR does not cover long values, as Java `switch` statements cannot be used on them. However, we can have a follow-up PR with an approach similar to binary search.

## How was this patch tested?

There are new tests that verify the logic of the proposed optimization. The performance was evaluated using existing benchmarks. This PR was also tested on an EC2 instance (OpenJDK 64-Bit Server VM 1.8.0_191-b12 on Linux 4.14.77-70.59.amzn1.x86_64, Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz).

## Notes

- [This link](http://hg.openjdk.java.net/jdk8/jdk8/langtools/file/30db5e0aaf83/src/share/classes/com/sun/tools/javac/jvm/Gen.java#l1153) contains the source code that decides between `tableswitch` and `lookupswitch`. The same logic was reused in the benchmarks; see the `isLookupSwitch` method.

Closes #23171 from aokolnychyi/spark-26205.

Lead-authored-by: Anton Okolnychyi <aokolnychyi@apple.com>
Co-authored-by: Dongjoon Hyun <dhyun@apple.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
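To make the two shapes concrete, here is a minimal standalone Scala sketch (an illustration, not code from this PR). A Scala `match` on integer literals compiles to the same `tableswitch`/`lookupswitch` instructions as a Java `switch`, so it shows the dispatch being contrasted:

```scala
object InShapes {
  // What `In` codegen effectively produces: a chain of comparisons,
  // so lookup cost grows linearly with the number of values.
  def inIfElse(v: Int): Boolean =
    if (v == 1) true
    else if (v == 3) true
    else if (v == 7) true
    else false

  // What switch-based `InSet` codegen effectively produces: compiles to
  // tableswitch for compact values (O(1)) or lookupswitch (O(log n)).
  def inSwitch(v: Int): Boolean = v match {
    case 1 | 3 | 7 => true
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    assert(inIfElse(7) && inSwitch(7))
    assert(!inIfElse(4) && !inSwitch(4))
  }
}
```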
Commit 0c23a39 (parent: 0deebd3)

7 files changed: +695 −328 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -224,7 +224,7 @@ object Block {
     } else {
       args.foreach {
         case _: ExprValue | _: Inline | _: Block =>
-        case _: Int | _: Long | _: Float | _: Double | _: String =>
+        case _: Boolean | _: Int | _: Long | _: Float | _: Double | _: String =>
         case other => throw new IllegalArgumentException(
           s"Can not interpolate ${other.getClass.getName} into code block.")
       }
```
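This one-line change exists because the new `genCodeWithSwitch` (in the predicates.scala diff below) splices the Boolean `hasNull` flag into a `code` template, which the interpolator previously rejected. A minimal sketch of what it now permits, assuming `spark-catalyst` from this commit on the classpath:

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.Block._

// Before this change, a Boolean argument hit the `case other` branch:
//   IllegalArgumentException: Can not interpolate java.lang.Boolean into code block.
val hasNull = false
val block = code"boolean isNull = $hasNull;"
println(block)
```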

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 42 additions & 0 deletions
```diff
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 
@@ -375,6 +376,19 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    if (canBeComputedUsingSwitch && hset.size <= SQLConf.get.optimizerInSetSwitchThreshold) {
+      genCodeWithSwitch(ctx, ev)
+    } else {
+      genCodeWithSet(ctx, ev)
+    }
+  }
+
+  private def canBeComputedUsingSwitch: Boolean = child.dataType match {
+    case ByteType | ShortType | IntegerType | DateType => true
+    case _ => false
+  }
+
+  private def genCodeWithSet(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, c => {
       val setTerm = ctx.addReferenceObj("set", set)
       val setIsNull = if (hasNull) {
@@ -389,6 +403,34 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
     })
   }
 
+  // spark.sql.optimizer.inSetSwitchThreshold has an appropriate upper limit,
+  // so the code size should not exceed 64KB
+  private def genCodeWithSwitch(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val caseValuesGen = hset.filter(_ != null).map(Literal(_).genCode(ctx))
+    val valueGen = child.genCode(ctx)
+
+    val caseBranches = caseValuesGen.map(literal =>
+      code"""
+        case ${literal.value}:
+          ${ev.value} = true;
+          break;
+      """)
+
+    ev.copy(code =
+      code"""
+        ${valueGen.code}
+        ${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${valueGen.isNull};
+        ${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false;
+        if (!${valueGen.isNull}) {
+          switch (${valueGen.value}) {
+            ${caseBranches.mkString("\n")}
+            default:
+              ${ev.isNull} = $hasNull;
+          }
+        }
+      """)
+  }
+
   override def sql: String = {
     val valueSQL = child.sql
     val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ")
```
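Note the `default` branch in the generated switch: when no case matches, the result is `NULL` if the original set contained a `NULL` (which might have matched), and `false` otherwise. A plain-Scala reference sketch of these three-valued `IN` semantics (an illustration, with `None` modeling SQL `NULL`):

```scala
// Reference semantics only, not Spark code.
def sqlIn(value: Option[Int], set: Set[Option[Int]]): Option[Boolean] =
  value match {
    case None => None                            // NULL IN (...) is NULL
    case Some(v) if set.contains(Some(v)) => Some(true)
    case _ if set.contains(None) => None         // absent, but NULL might match
    case _ => Some(false)
  }

assert(sqlIn(Some(1), Set(Some(1), None)).contains(true))
assert(sqlIn(Some(2), Set(Some(1))).contains(false))
assert(sqlIn(Some(2), Set(Some(1), None)).isEmpty)
assert(sqlIn(None, Set(Some(1))).isEmpty)
```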

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 12 additions & 0 deletions
```diff
@@ -171,6 +171,16 @@ object SQLConf {
     .intConf
     .createWithDefault(10)
 
+  val OPTIMIZER_INSET_SWITCH_THRESHOLD =
+    buildConf("spark.sql.optimizer.inSetSwitchThreshold")
+      .internal()
+      .doc("Configures the max set size in InSet for which Spark will generate code with " +
+        "switch statements. This is applicable only to bytes, shorts, ints, dates.")
+      .intConf
+      .checkValue(threshold => threshold >= 0 && threshold <= 600, "The max set size " +
+        "for using switch statements in InSet must be non-negative and less than or equal to 600")
+      .createWithDefault(400)
+
   val OPTIMIZER_PLAN_CHANGE_LOG_LEVEL = buildConf("spark.sql.optimizer.planChangeLog.level")
     .internal()
     .doc("Configures the log level for logging the change from the original plan to the new " +
@@ -1725,6 +1735,8 @@ class SQLConf extends Serializable with Logging {
 
   def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD)
 
+  def optimizerInSetSwitchThreshold: Int = getConf(OPTIMIZER_INSET_SWITCH_THRESHOLD)
+
   def optimizerPlanChangeLogLevel: String = getConf(OPTIMIZER_PLAN_CHANGE_LOG_LEVEL)
 
   def optimizerPlanChangeRules: Option[String] = getConf(OPTIMIZER_PLAN_CHANGE_LOG_RULES)
```
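The conf is marked `.internal()`, so it does not appear in user-facing docs, but it can still be set by key. A hedged usage sketch (`spark` is an existing `SparkSession`):

```scala
// Sets of up to 500 bytes/shorts/ints/dates now take the switch path
// (the checkValue above caps the threshold at 600).
spark.conf.set("spark.sql.optimizer.inSetSwitchThreshold", "500")

// Setting it to 0 disables switch-based codegen, always falling back
// to the set-based genCodeWithSet.
spark.conf.set("spark.sql.optimizer.inSetSwitchThreshold", "0")
```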

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala

Lines changed: 48 additions & 1 deletion
```diff
@@ -23,11 +23,12 @@ import scala.collection.immutable.HashSet
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.RandomDataGenerator
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 
@@ -241,6 +242,52 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
+  test("switch statements in InSet for bytes, shorts, ints, dates") {
+    val byteValues = Set[Any](1.toByte, 2.toByte, Byte.MinValue, Byte.MaxValue)
+    val shortValues = Set[Any](-10.toShort, 20.toShort, Short.MinValue, Short.MaxValue)
+    val intValues = Set[Any](20, -100, 30, Int.MinValue, Int.MaxValue)
+    val dateValues = Set[Any](
+      CatalystTypeConverters.convertToCatalyst(Date.valueOf("2017-01-01")),
+      CatalystTypeConverters.convertToCatalyst(Date.valueOf("1950-01-02")))
+
+    def check(presentValue: Expression, absentValue: Expression, values: Set[Any]): Unit = {
+      require(presentValue.dataType == absentValue.dataType)
+
+      val nullLiteral = Literal(null, presentValue.dataType)
+
+      checkEvaluation(InSet(nullLiteral, values), expected = null)
+      checkEvaluation(InSet(nullLiteral, values + null), expected = null)
+      checkEvaluation(InSet(presentValue, values), expected = true)
+      checkEvaluation(InSet(presentValue, values + null), expected = true)
+      checkEvaluation(InSet(absentValue, values), expected = false)
+      checkEvaluation(InSet(absentValue, values + null), expected = null)
+    }
+
+    def checkAllTypes(): Unit = {
+      check(presentValue = Literal(2.toByte), absentValue = Literal(3.toByte), byteValues)
+      check(presentValue = Literal(Byte.MinValue), absentValue = Literal(5.toByte), byteValues)
+      check(presentValue = Literal(20.toShort), absentValue = Literal(-14.toShort), shortValues)
+      check(presentValue = Literal(Short.MaxValue), absentValue = Literal(30.toShort), shortValues)
+      check(presentValue = Literal(20), absentValue = Literal(-14), intValues)
+      check(presentValue = Literal(Int.MinValue), absentValue = Literal(2), intValues)
+      check(
+        presentValue = Literal(Date.valueOf("2017-01-01")),
+        absentValue = Literal(Date.valueOf("2017-01-02")),
+        dateValues)
+      check(
+        presentValue = Literal(Date.valueOf("1950-01-02")),
+        absentValue = Literal(Date.valueOf("2017-10-02")),
+        dateValues)
+    }
+
+    withSQLConf(SQLConf.OPTIMIZER_INSET_SWITCH_THRESHOLD.key -> "0") {
+      checkAllTypes()
+    }
+    withSQLConf(SQLConf.OPTIMIZER_INSET_SWITCH_THRESHOLD.key -> "20") {
+      checkAllTypes()
+    }
+  }
+
   test("SPARK-22501: In should not generate codes beyond 64KB") {
     val N = 3000
     val sets = (1 to N).map(i => Literal(i.toDouble))
```
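For an end-to-end look outside the suite, a spark-shell sketch (assuming default confs, where `In` is rewritten to `InSet` above the 10-element `spark.sql.optimizer.inSetConversionThreshold` and the switch path applies below the 400-element switch threshold):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = (1 to 100).toDF("id")

// 11 literals: enough to trigger the In -> InSet conversion, and far
// below the switch threshold, so the filter should codegen a switch.
df.filter($"id".isin(1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21)).explain(true)
```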
