Skip to content

Commit c604d65

Browse files
hvanhovellcloud-fan
authored andcommitted
[SPARK-23951][SQL] Use actual java class instead of string representation.
## What changes were proposed in this pull request? This PR slightly refactors the newly added `ExprValue` API by quite a bit. The following changes are introduced: 1. `ExprValue` now uses the actual class instead of the class name as its type. This should give some more flexibility with generating code in the future. 2. Renamed `StatementValue` to `SimpleExprValue`. The statement concept is broader then an expression (untyped and it cannot be on the right hand side of an assignment), and this was not really what we were using it for. I have added a top level `JavaCode` trait that can be used in the future to reinstate (no pun intended) a statement a-like code fragment. 3. Added factory methods to the `JavaCode` companion object to make it slightly less verbose to create `JavaCode`/`ExprValue` objects. This is also what makes the diff quite large. 4. Added one more factory method to `ExprCode` to make it easier to create code-less expressions. ## How was this patch tested? Existing tests. Author: Herman van Hovell <hvanhovell@databricks.com> Closes #21026 from hvanhovell/SPARK-23951.
1 parent 87611bb commit c604d65

26 files changed

+315
-212
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,9 @@ abstract class Expression extends TreeNode[Expression] {
104104
}.getOrElse {
105105
val isNull = ctx.freshName("isNull")
106106
val value = ctx.freshName("value")
107-
val eval = doGenCode(ctx, ExprCode("",
108-
VariableValue(isNull, CodeGenerator.JAVA_BOOLEAN),
109-
VariableValue(value, CodeGenerator.javaType(dataType))))
107+
val eval = doGenCode(ctx, ExprCode(
108+
JavaCode.isNullVariable(isNull),
109+
JavaCode.variable(value, dataType)))
110110
reduceCodeSize(ctx, eval)
111111
if (eval.code.nonEmpty) {
112112
// Add `this` in the comment.
@@ -123,7 +123,7 @@ abstract class Expression extends TreeNode[Expression] {
123123
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
124124
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
125125
val localIsNull = eval.isNull
126-
eval.isNull = GlobalValue(globalIsNull, CodeGenerator.JAVA_BOOLEAN)
126+
eval.isNull = JavaCode.isNullGlobal(globalIsNull)
127127
s"$globalIsNull = $localIsNull;"
128128
} else {
129129
""
@@ -142,7 +142,7 @@ abstract class Expression extends TreeNode[Expression] {
142142
|}
143143
""".stripMargin)
144144

145-
eval.value = VariableValue(newValue, javaType)
145+
eval.value = JavaCode.variable(newValue, dataType)
146146
eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
147147
}
148148
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -591,8 +591,7 @@ case class Least(children: Seq[Expression]) extends Expression {
591591

592592
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
593593
val evalChildren = children.map(_.genCode(ctx))
594-
ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull),
595-
CodeGenerator.JAVA_BOOLEAN)
594+
ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull))
596595
val evals = evalChildren.map(eval =>
597596
s"""
598597
|${eval.code}
@@ -671,8 +670,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
671670

672671
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
673672
val evalChildren = children.map(_.genCode(ctx))
674-
ev.isNull = GlobalValue(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull),
675-
CodeGenerator.JAVA_BOOLEAN)
673+
ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull))
676674
val evals = evalChildren.map(eval =>
677675
s"""
678676
|${eval.code}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,12 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
5959
case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)
6060

6161
object ExprCode {
62+
def apply(isNull: ExprValue, value: ExprValue): ExprCode = {
63+
ExprCode(code = "", isNull, value)
64+
}
65+
6266
def forNullValue(dataType: DataType): ExprCode = {
63-
val defaultValueLiteral = CodeGenerator.defaultValue(dataType, typedNull = true)
64-
ExprCode(code = "", isNull = TrueLiteral,
65-
value = LiteralValue(defaultValueLiteral, CodeGenerator.javaType(dataType)))
67+
ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
6668
}
6769

6870
def forNonNullValue(value: ExprValue): ExprCode = {
@@ -331,7 +333,7 @@ class CodegenContext {
331333
case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
332334
case _ => s"$value = $initCode;"
333335
}
334-
ExprCode(code, FalseLiteral, GlobalValue(value, javaType(dataType)))
336+
ExprCode(code, FalseLiteral, JavaCode.global(value, dataType))
335337
}
336338

337339
def declareMutableStates(): String = {
@@ -1004,8 +1006,9 @@ class CodegenContext {
10041006
// at least two nodes) as the cost of doing it is expected to be low.
10051007

10061008
subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);"
1007-
val state = SubExprEliminationState(GlobalValue(isNull, JAVA_BOOLEAN),
1008-
GlobalValue(value, javaType(expr.dataType)))
1009+
val state = SubExprEliminationState(
1010+
JavaCode.isNullGlobal(isNull),
1011+
JavaCode.global(value, expr.dataType))
10091012
subExprEliminationExprs ++= e.map(_ -> state).toMap
10101013
}
10111014
}
@@ -1479,6 +1482,26 @@ object CodeGenerator extends Logging {
14791482
case _ => "Object"
14801483
}
14811484

1485+
def javaClass(dt: DataType): Class[_] = dt match {
1486+
case BooleanType => java.lang.Boolean.TYPE
1487+
case ByteType => java.lang.Byte.TYPE
1488+
case ShortType => java.lang.Short.TYPE
1489+
case IntegerType | DateType => java.lang.Integer.TYPE
1490+
case LongType | TimestampType => java.lang.Long.TYPE
1491+
case FloatType => java.lang.Float.TYPE
1492+
case DoubleType => java.lang.Double.TYPE
1493+
case _: DecimalType => classOf[Decimal]
1494+
case BinaryType => classOf[Array[Byte]]
1495+
case StringType => classOf[UTF8String]
1496+
case CalendarIntervalType => classOf[CalendarInterval]
1497+
case _: StructType => classOf[InternalRow]
1498+
case _: ArrayType => classOf[ArrayData]
1499+
case _: MapType => classOf[MapData]
1500+
case udt: UserDefinedType[_] => javaClass(udt.sqlType)
1501+
case ObjectType(cls) => cls
1502+
case _ => classOf[Object]
1503+
}
1504+
14821505
/**
14831506
* Returns the boxed type in Java.
14841507
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/ExprValue.scala

Lines changed: 0 additions & 76 deletions
This file was deleted.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,43 +52,45 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
5252
expressions: Seq[Expression],
5353
useSubexprElimination: Boolean): MutableProjection = {
5454
val ctx = newCodeGenContext()
55-
val (validExpr, index) = expressions.zipWithIndex.filter {
55+
val validExpr = expressions.zipWithIndex.filter {
5656
case (NoOp, _) => false
5757
case _ => true
58-
}.unzip
59-
val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination)
58+
}
59+
val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination)
6060

6161
// 4-tuples: (code for projection, isNull variable name, value variable name, column index)
62-
val projectionCodes: Seq[(String, ExprValue, String, Int)] = exprVals.zip(index).map {
63-
case (ev, i) =>
64-
val e = expressions(i)
65-
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value")
66-
if (e.nullable) {
62+
val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map {
63+
case ((e, i), ev) =>
64+
val value = JavaCode.global(
65+
ctx.addMutableState(CodeGenerator.javaType(e.dataType), "value"),
66+
e.dataType)
67+
val (code, isNull) = if (e.nullable) {
6768
val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "isNull")
6869
(s"""
6970
|${ev.code}
7071
|$isNull = ${ev.isNull};
7172
|$value = ${ev.value};
72-
""".stripMargin, GlobalValue(isNull, CodeGenerator.JAVA_BOOLEAN), value, i)
73+
""".stripMargin, JavaCode.isNullGlobal(isNull))
7374
} else {
7475
(s"""
7576
|${ev.code}
7677
|$value = ${ev.value};
77-
""".stripMargin, ev.isNull, value, i)
78+
""".stripMargin, FalseLiteral)
7879
}
80+
val update = CodeGenerator.updateColumn(
81+
"mutableRow",
82+
e.dataType,
83+
i,
84+
ExprCode(isNull, value),
85+
e.nullable)
86+
(code, update)
7987
}
8088

8189
// Evaluate all the subexpressions.
8290
val evalSubexpr = ctx.subexprFunctions.mkString("\n")
8391

84-
val updates = validExpr.zip(projectionCodes).map {
85-
case (e, (_, isNull, value, i)) =>
86-
val ev = ExprCode("", isNull, GlobalValue(value, CodeGenerator.javaType(e.dataType)))
87-
CodeGenerator.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
88-
}
89-
9092
val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1))
91-
val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates)
93+
val allUpdates = ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2))
9294

9395
val codeBody = s"""
9496
public java.lang.Object generate(Object[] references) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen
1919

2020
import scala.annotation.tailrec
2121

22+
import org.apache.spark.sql.catalyst.InternalRow
2223
import org.apache.spark.sql.catalyst.expressions._
2324
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
24-
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
25+
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
2526
import org.apache.spark.sql.types._
2627

2728
/**
@@ -53,9 +54,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
5354
val rowClass = classOf[GenericInternalRow].getName
5455

5556
val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) =>
56-
val converter = convertToSafe(ctx,
57-
StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString),
58-
CodeGenerator.javaType(dt)), dt)
57+
val converter = convertToSafe(
58+
ctx,
59+
JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt),
60+
dt)
5961
s"""
6062
if (!$tmpInput.isNullAt($i)) {
6163
${converter.code}
@@ -76,7 +78,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
7678
|final InternalRow $output = new $rowClass($values);
7779
""".stripMargin
7880

79-
ExprCode(code, FalseLiteral, VariableValue(output, "InternalRow"))
81+
ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[InternalRow]))
8082
}
8183

8284
private def createCodeForArray(
@@ -91,9 +93,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
9193
val index = ctx.freshName("index")
9294
val arrayClass = classOf[GenericArrayData].getName
9395

94-
val elementConverter = convertToSafe(ctx,
95-
StatementValue(CodeGenerator.getValue(tmpInput, elementType, index),
96-
CodeGenerator.javaType(elementType)), elementType)
96+
val elementConverter = convertToSafe(
97+
ctx,
98+
JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType),
99+
elementType)
97100
val code = s"""
98101
final ArrayData $tmpInput = $input;
99102
final int $numElements = $tmpInput.numElements();
@@ -107,7 +110,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
107110
final ArrayData $output = new $arrayClass($values);
108111
"""
109112

110-
ExprCode(code, FalseLiteral, VariableValue(output, "ArrayData"))
113+
ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[ArrayData]))
111114
}
112115

113116
private def createCodeForMap(
@@ -128,7 +131,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
128131
final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value});
129132
"""
130133

131-
ExprCode(code, FalseLiteral, VariableValue(output, "MapData"))
134+
ExprCode(code, FalseLiteral, JavaCode.variable(output, classOf[MapData]))
132135
}
133136

134137
@tailrec
@@ -140,7 +143,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
140143
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
141144
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
142145
case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
143-
case _ => ExprCode("", FalseLiteral, input)
146+
case _ => ExprCode(FalseLiteral, input)
144147
}
145148

146149
protected def create(expressions: Seq[Expression]): Projection = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
5252
// Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
5353
val tmpInput = ctx.freshName("tmpInput")
5454
val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
55-
ExprCode("", StatementValue(s"$tmpInput.isNullAt($i)", CodeGenerator.JAVA_BOOLEAN),
56-
StatementValue(CodeGenerator.getValue(tmpInput, dt, i.toString),
57-
CodeGenerator.javaType(dt)))
55+
ExprCode(
56+
JavaCode.isNullExpression(s"$tmpInput.isNullAt($i)"),
57+
JavaCode.expression(CodeGenerator.getValue(tmpInput, dt, i.toString), dt))
5858
}
5959

6060
val rowWriterClass = classOf[UnsafeRowWriter].getName
@@ -109,7 +109,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
109109
}
110110

111111
val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter)
112-
if (input.isNull == "false") {
112+
if (input.isNull == FalseLiteral) {
113113
s"""
114114
|${input.code}
115115
|${writeField.trim}
@@ -292,8 +292,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
292292
|$writeExpressions
293293
""".stripMargin
294294
// `rowWriter` is declared as a class field, so we can access it directly in methods.
295-
ExprCode(code, FalseLiteral, StatementValue(s"$rowWriter.getRow()", "UnsafeRow",
296-
canDirectAccess = true))
295+
ExprCode(code, FalseLiteral, JavaCode.expression(s"$rowWriter.getRow()", classOf[UnsafeRow]))
297296
}
298297

299298
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =

0 commit comments

Comments
 (0)