Skip to content

Commit 02262c9

Browse files
author
Davies Liu
committed
address comments
1 parent b5d3617 commit 02262c9

File tree

7 files changed

+89
-119
lines changed

7 files changed

+89
-119
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -442,47 +442,35 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
442442
case (BinaryType, StringType) =>
443443
defineCodeGen (ctx, ev, c =>
444444
s"new ${ctx.stringType}().set($c)")
445-
446445
case (DateType, StringType) =>
447446
defineCodeGen(ctx, ev, c =>
448447
s"""new ${ctx.stringType}().set(
449448
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")
450-
451-
case (BooleanType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
452-
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c ? 1 : 0)")
453-
454-
case (_: NumericType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
455-
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")
456-
457-
case (_: DecimalType, ByteType) =>
458-
defineCodeGen(ctx, ev, c => s"($c).toByte()")
459-
460-
case (_: DecimalType, ShortType) =>
461-
defineCodeGen(ctx, ev, c => s"($c).toShort()")
462-
463-
case (_: DecimalType, IntegerType) =>
464-
defineCodeGen(ctx, ev, c => s"($c).toInt()")
465-
466-
case (_: DecimalType, LongType) =>
467-
defineCodeGen(ctx, ev, c => s"($c).toLong()")
468-
469-
case (_: DecimalType, FloatType) =>
470-
defineCodeGen(ctx, ev, c => s"($c).toFloat()")
471-
472-
case (_: DecimalType, DoubleType) =>
473-
defineCodeGen(ctx, ev, c => s"($c).toDouble()")
474-
475-
case (_: DecimalType, dt: NumericType) if !dt.isInstanceOf[DecimalType] =>
476-
defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
477-
478449
// Special handling required for timestamps in hive test cases since the toString function
479450
// does not match the expected output.
480451
case (TimestampType, StringType) =>
481452
super.genCode(ctx, ev)
482-
483453
case (_, StringType) =>
484454
defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))")
485455

456+
// fallback for DecimalType, this must be before other numeric types
457+
case (_, dt: DecimalType) =>
458+
super.genCode(ctx, ev)
459+
460+
case (BooleanType, dt: NumericType) =>
461+
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c ? 1 : 0)")
462+
case (dt: DecimalType, BooleanType) =>
463+
// A decimal casts to true when it is non-zero, consistent with `$c != 0` for other numeric types.
defineCodeGen(ctx, ev, c => s"!$c.isZero()")
464+
case (dt: NumericType, BooleanType) =>
465+
defineCodeGen(ctx, ev, c => s"$c != 0")
466+
467+
case (_: DecimalType, IntegerType) =>
468+
defineCodeGen(ctx, ev, c => s"($c).toInt()")
469+
case (_: DecimalType, dt: NumericType) =>
470+
defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()")
471+
case (_: NumericType, dt: NumericType) =>
472+
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dt)})($c)")
473+
486474
case other =>
487475
super.genCode(ctx, ev)
488476
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
21-
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext}
21+
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term}
2222
import org.apache.spark.sql.catalyst.trees
2323
import org.apache.spark.sql.catalyst.trees.TreeNode
2424
import org.apache.spark.sql.types._
@@ -62,8 +62,7 @@ abstract class Expression extends TreeNode[Expression] {
6262
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
6363
val nullTerm = ctx.freshName("nullTerm")
6464
val primitiveTerm = ctx.freshName("primitiveTerm")
65-
val objectTerm = ctx.freshName("objectTerm")
66-
val ve = GeneratedExpressionCode("", nullTerm, primitiveTerm, objectTerm)
65+
val ve = GeneratedExpressionCode("", nullTerm, primitiveTerm)
6766
ve.code = genCode(ctx, ve)
6867
ve
6968
}
@@ -77,17 +76,18 @@ abstract class Expression extends TreeNode[Expression] {
7776
* @param ev a [[GeneratedExpressionCode]] with unique terms.
7877
* @return Java source code
7978
*/
80-
def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
79+
protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
8180
val e = this.asInstanceOf[Expression]
8281
ctx.references += e
82+
val objectTerm = ctx.freshName("obj")
8383
s"""
84-
/* expression: ${this} */
85-
Object ${ev.objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
86-
boolean ${ev.nullTerm} = ${ev.objectTerm} == null;
87-
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)};
88-
if (!${ev.nullTerm}) {
89-
${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${ev.objectTerm};
90-
}
84+
/* expression: ${this} */
85+
final Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
86+
final boolean ${ev.nullTerm} = ${objectTerm} == null;
87+
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)};
88+
if (!${ev.nullTerm}) {
89+
${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${objectTerm};
90+
}
9191
"""
9292
}
9393

@@ -167,7 +167,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
167167
protected def defineCodeGen(
168168
ctx: CodeGenContext,
169169
ev: GeneratedExpressionCode,
170-
f: (String, String) => String): String = {
170+
f: (Term, Term) => Code): String = {
171171
// TODO: Right now some timestamp tests fail if we enforce this...
172172
if (left.dataType != right.dataType) {
173173
// log.warn(s"${left.dataType} != ${right.dataType}")
@@ -214,10 +214,11 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
214214
protected def defineCodeGen(
215215
ctx: CodeGenContext,
216216
ev: GeneratedExpressionCode,
217-
f: String => String): String = {
217+
f: Term => Code): Code = {
218218
val eval = child.gen(ctx)
219+
// reuse the previous nullTerm
220+
ev.nullTerm = eval.nullTerm
219221
eval.code + s"""
220-
boolean ${ev.nullTerm} = ${eval.nullTerm};
221222
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
222223
if (!${ev.nullTerm}) {
223224
${ev.primitiveTerm} = ${f(eval.primitiveTerm)};

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 39 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,8 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
3939
* to null.
4040
* @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
4141
* valid if `nullTerm` is set to `true`.
42-
* @param objectTerm A possibly boxed version of the result of evaluating this expression.
4342
*/
44-
case class GeneratedExpressionCode(var code: Code,
45-
nullTerm: Term,
46-
primitiveTerm: Term,
47-
objectTerm: Term)
43+
case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, primitiveTerm: Term)
4844

4945
/**
5046
* A context for codegen, which is used for bookkeeping the expressions that are not supported
@@ -73,40 +69,44 @@ class CodeGenContext {
7369
s"$prefix${curId.getAndIncrement}"
7470
}
7571

72+
/**
73+
* Return the code to access a column for the given DataType
74+
*/
7675
def getColumn(dataType: DataType, ordinal: Int): Code = {
77-
dataType match {
78-
case StringType => s"($stringType)i.apply($ordinal)"
79-
case dt: DataType if isNativeType(dt) => s"i.${accessorForType(dt)}($ordinal)"
80-
case _ => s"(${boxedType(dataType)})i.apply($ordinal)"
76+
if (isNativeType(dataType)) {
77+
s"i.${accessorForType(dataType)}($ordinal)"
78+
} else {
79+
s"(${boxedType(dataType)})i.apply($ordinal)"
8180
}
8281
}
8382

84-
def setColumn(destinationRow: Term, dataType: DataType, ordinal: Int, value: Term): Code = {
85-
dataType match {
86-
case StringType => s"$destinationRow.update($ordinal, $value)"
87-
case dt: DataType if isNativeType(dt) =>
88-
s"$destinationRow.${mutatorForType(dt)}($ordinal, $value)"
89-
case _ => s"$destinationRow.update($ordinal, $value)"
83+
/**
84+
* Return the code to update a column in Row for the given DataType
85+
*/
86+
def setColumn(dataType: DataType, ordinal: Int, value: Term): Code = {
87+
if (isNativeType(dataType)) {
88+
s"${mutatorForType(dataType)}($ordinal, $value)"
89+
} else {
90+
s"update($ordinal, $value)"
9091
}
9192
}
9293

94+
/**
95+
* Return the name of accessor in Row for a DataType
96+
*/
9397
def accessorForType(dt: DataType): Term = dt match {
9498
case IntegerType => "getInt"
9599
case other => s"get${boxedType(dt)}"
96100
}
97101

102+
/**
103+
* Return the name of mutator in Row for a DataType
104+
*/
98105
def mutatorForType(dt: DataType): Term = dt match {
99106
case IntegerType => "setInt"
100107
case other => s"set${boxedType(dt)}"
101108
}
102109

103-
def hashSetForType(dt: DataType): Term = dt match {
104-
case IntegerType => classOf[IntegerHashSet].getName
105-
case LongType => classOf[LongHashSet].getName
106-
case unsupportedType =>
107-
sys.error(s"Code generation not support for hashset of type $unsupportedType")
108-
}
109-
110110
/**
111111
* Return the primitive type for a DataType
112112
*/
@@ -123,9 +123,26 @@ class CodeGenContext {
123123
case StringType => stringType
124124
case DateType => "int"
125125
case TimestampType => "java.sql.Timestamp"
126+
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
127+
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
126128
case _ => "Object"
127129
}
128130

131+
/**
132+
* Return the boxed type in Java
133+
*/
134+
def boxedType(dt: DataType): Term = dt match {
135+
case IntegerType => "Integer"
136+
case LongType => "Long"
137+
case ShortType => "Short"
138+
case ByteType => "Byte"
139+
case DoubleType => "Double"
140+
case FloatType => "Float"
141+
case BooleanType => "Boolean"
142+
case DateType => "Integer"
143+
case _ => primitiveType(dt)
144+
}
145+
129146
/**
130147
* Return the representation of default value for given DataType
131148
*/
@@ -138,30 +155,9 @@ class CodeGenContext {
138155
case DoubleType => "-1.0"
139156
case IntegerType => "-1"
140157
case DateType => "-1"
141-
case dt: DecimalType => "null"
142-
case StringType => "null"
143158
case _ => "null"
144159
}
145160

146-
/**
147-
* Return the boxed type in Java
148-
*/
149-
def boxedType(dt: DataType): Term = dt match {
150-
case IntegerType => "Integer"
151-
case LongType => "Long"
152-
case ShortType => "Short"
153-
case ByteType => "Byte"
154-
case DoubleType => "Double"
155-
case FloatType => "Float"
156-
case BooleanType => "Boolean"
157-
case dt: DecimalType => decimalType
158-
case BinaryType => "byte[]"
159-
case StringType => stringType
160-
case DateType => "Integer"
161-
case TimestampType => "java.sql.Timestamp"
162-
case _ => "Object"
163-
}
164-
165161
/**
166162
* Returns a function to generate equal expression in Java
167163
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
4343
if(${evaluationCode.nullTerm})
4444
mutableRow.setNullAt($i);
4545
else
46-
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitiveTerm)};
46+
mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitiveTerm)};
4747
"""
4848
}.mkString("\n")
4949
val code = s"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,11 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
6767
val eval = child.gen(ctx)
6868
eval.code + s"""
6969
boolean ${ev.nullTerm} = ${eval.nullTerm};
70-
org.apache.spark.sql.types.Decimal ${ev.primitiveTerm} = ${ctx.defaultValue(DecimalType())};
70+
${ctx.decimalType} ${ev.primitiveTerm} = null;
7171

7272
if (!${ev.nullTerm}) {
73-
${ev.primitiveTerm} = new org.apache.spark.sql.types.Decimal();
74-
${ev.primitiveTerm} =
75-
${ev.primitiveTerm}.setOrNull(${eval.primitiveTerm}, $precision, $scale);
73+
${ev.primitiveTerm} = (new ${ctx.decimalType}()).setOrNull(
74+
${eval.primitiveTerm}, $precision, $scale);
7675
${ev.nullTerm} = ${ev.primitiveTerm} == null;
7776
}
7877
"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,34 +85,21 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres
8585
if (value == null) {
8686
s"""
8787
final boolean ${ev.nullTerm} = true;
88-
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
88+
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
8989
"""
9090
} else {
91-
// TODO(cg): Add support for more data types.
9291
dataType match {
93-
case StringType =>
94-
val v = value.asInstanceOf[UTF8String]
95-
val arr = s"new byte[]{${v.getBytes.map(_.toString).mkString(", ")}}"
96-
s"""
97-
final boolean ${ev.nullTerm} = false;
98-
${ctx.stringType} ${ev.primitiveTerm} = new ${ctx.stringType}().set(${arr});
99-
"""
10092
case FloatType => // This must go before NumericType
10193
s"""
10294
final boolean ${ev.nullTerm} = false;
103-
float ${ev.primitiveTerm} = ${value}f;
104-
"""
105-
case dt: DecimalType => // This must go before NumericType
106-
s"""
107-
final boolean ${ev.nullTerm} = false;
108-
${ctx.primitiveType(dt)} ${ev.primitiveTerm} =
109-
new ${ctx.primitiveType(dt)}().set($value);
95+
final float ${ev.primitiveTerm} = ${value}f;
11096
"""
111-
case dt: NumericType =>
97+
case dt: NumericType if !dt.isInstanceOf[DecimalType]=>
11298
s"""
11399
final boolean ${ev.nullTerm} = false;
114-
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value;
100+
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = $value;
115101
"""
102+
// eval() version may be faster for non-primitive types
116103
case other =>
117104
super.genCode(ctx, ev)
118105
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ case class NewSet(elementType: DataType) extends LeafExpression {
6666
case IntegerType | LongType =>
6767
s"""
6868
boolean ${ev.nullTerm} = false;
69-
${ctx.hashSetForType(elementType)} ${ev.primitiveTerm} =
70-
new ${ctx.hashSetForType(elementType)}();
69+
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = new ${ctx.primitiveType(dataType)}();
7170
"""
7271
case _ => super.genCode(ctx, ev)
7372
}
@@ -110,14 +109,14 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
110109
case IntegerType | LongType =>
111110
val itemEval = item.gen(ctx)
112111
val setEval = set.gen(ctx)
113-
val htype = ctx.hashSetForType(elementType)
112+
val htype = ctx.primitiveType(dataType)
114113

115114
itemEval.code + setEval.code + s"""
116-
if (!${itemEval.nullTerm} && !${setEval.nullTerm}) {
117-
(($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm});
118-
}
119-
boolean ${ev.nullTerm} = false;
120-
${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm};
115+
if (!${itemEval.nullTerm} && !${setEval.nullTerm}) {
116+
(($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm});
117+
}
118+
boolean ${ev.nullTerm} = false;
119+
${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm};
121120
"""
122121
case _ => super.genCode(ctx, ev)
123122
}
@@ -163,7 +162,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
163162
case IntegerType | LongType =>
164163
val leftEval = left.gen(ctx)
165164
val rightEval = right.gen(ctx)
166-
val htype = ctx.hashSetForType(elementType)
165+
val htype = ctx.primitiveType(dataType)
167166

168167
leftEval.code + rightEval.code + s"""
169168
boolean ${ev.nullTerm} = false;

0 commit comments

Comments
 (0)