Skip to content

Commit f42c732

Browse files
author
Davies Liu
committed
improve coverage and tests
1 parent bad6828 commit f42c732

20 files changed

+440
-237
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
4545

4646
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
4747
s"""
48-
final boolean ${ev.nullTerm} = i.isNullAt($ordinal);
49-
final ${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ev.nullTerm} ?
48+
boolean ${ev.isNull} = i.isNullAt($ordinal);
49+
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ev.isNull} ?
5050
${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)});
5151
"""
5252
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ abstract class Expression extends TreeNode[Expression] {
6060
* @return [[GeneratedExpressionCode]]
6161
*/
6262
def gen(ctx: CodeGenContext): GeneratedExpressionCode = {
63-
val nullTerm = ctx.freshName("nullTerm")
64-
val primitiveTerm = ctx.freshName("primitiveTerm")
65-
val ve = GeneratedExpressionCode("", nullTerm, primitiveTerm)
63+
val isNull = ctx.freshName("isNull")
64+
val primitive = ctx.freshName("primitive")
65+
val ve = GeneratedExpressionCode("", isNull, primitive)
6666
ve.code = genCode(ctx, ve)
6767
ve
6868
}
@@ -82,11 +82,11 @@ abstract class Expression extends TreeNode[Expression] {
8282
val objectTerm = ctx.freshName("obj")
8383
s"""
8484
/* expression: ${this} */
85-
final Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
86-
final boolean ${ev.nullTerm} = ${objectTerm} == null;
87-
${ctx.primitiveType(e.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(e.dataType)};
88-
if (!${ev.nullTerm}) {
89-
${ev.primitiveTerm} = (${ctx.boxedType(e.dataType)})${objectTerm};
85+
Object ${objectTerm} = expressions[${ctx.references.size - 1}].eval(i);
86+
boolean ${ev.isNull} = ${objectTerm} == null;
87+
${ctx.primitiveType(e.dataType)} ${ev.primitive} = ${ctx.defaultValue(e.dataType)};
88+
if (!${ev.isNull}) {
89+
${ev.primitive} = (${ctx.boxedType(e.dataType)})${objectTerm};
9090
}
9191
"""
9292
}
@@ -175,18 +175,18 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
175175

176176
val eval1 = left.gen(ctx)
177177
val eval2 = right.gen(ctx)
178-
val resultCode = f(eval1.primitiveTerm, eval2.primitiveTerm)
178+
val resultCode = f(eval1.primitive, eval2.primitive)
179179

180180
s"""
181181
${eval1.code}
182-
boolean ${ev.nullTerm} = ${eval1.nullTerm};
183-
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
184-
if (!${ev.nullTerm}) {
182+
boolean ${ev.isNull} = ${eval1.isNull};
183+
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
184+
if (!${ev.isNull}) {
185185
${eval2.code}
186-
if(!${eval2.nullTerm}) {
187-
${ev.primitiveTerm} = $resultCode;
186+
if(!${eval2.isNull}) {
187+
${ev.primitive} = $resultCode;
188188
} else {
189-
${ev.nullTerm} = true;
189+
${ev.isNull} = true;
190190
}
191191
}
192192
"""
@@ -216,12 +216,12 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
216216
ev: GeneratedExpressionCode,
217217
f: Term => Code): Code = {
218218
val eval = child.gen(ctx)
219-
// reuse the previous nullTerm
220-
ev.nullTerm = eval.nullTerm
219+
// reuse the previous isNull
220+
ev.isNull = eval.isNull
221221
eval.code + s"""
222-
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
223-
if (!${ev.nullTerm}) {
224-
${ev.primitiveTerm} = ${f(eval.primitiveTerm)};
222+
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
223+
if (!${ev.isNull}) {
224+
${ev.primitive} = ${f(eval.primitive)};
225225
}
226226
"""
227227
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 60 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
5050

5151
private lazy val numeric = TypeUtils.getNumeric(dataType)
5252

53+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match {
54+
case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()")
55+
case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)")
56+
}
57+
5358
protected override def evalInternal(evalE: Any) = numeric.negate(evalE)
5459
}
5560

@@ -68,6 +73,21 @@ case class Sqrt(child: Expression) extends UnaryArithmetic {
6873
if (value < 0) null
6974
else math.sqrt(value)
7075
}
76+
77+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
78+
val eval = child.gen(ctx)
79+
eval.code + s"""
80+
boolean ${ev.isNull} = ${eval.isNull};
81+
${ctx.primitiveType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
82+
if (!${ev.isNull}) {
83+
if (${eval.primitive} < 0.0) {
84+
${ev.isNull} = true;
85+
} else {
86+
${ev.primitive} = java.lang.Math.sqrt(${eval.primitive});
87+
}
88+
}
89+
"""
90+
}
7191
}
7292

7393
/**
@@ -216,9 +236,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
216236
val eval1 = left.gen(ctx)
217237
val eval2 = right.gen(ctx)
218238
val test = if (left.dataType.isInstanceOf[DecimalType]) {
219-
s"${eval2.primitiveTerm}.isZero()"
239+
s"${eval2.primitive}.isZero()"
220240
} else {
221-
s"${eval2.primitiveTerm} == 0"
241+
s"${eval2.primitive} == 0"
222242
}
223243
val method = if (left.dataType.isInstanceOf[DecimalType]) {
224244
s".$decimalMethod"
@@ -227,12 +247,12 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
227247
}
228248
eval1.code + eval2.code +
229249
s"""
230-
boolean ${ev.nullTerm} = false;
231-
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)};
232-
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
233-
${ev.nullTerm} = true;
250+
boolean ${ev.isNull} = false;
251+
${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
252+
if (${eval1.isNull} || ${eval2.isNull} || $test) {
253+
${ev.isNull} = true;
234254
} else {
235-
${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm});
255+
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
236256
}
237257
"""
238258
}
@@ -276,9 +296,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
276296
val eval1 = left.gen(ctx)
277297
val eval2 = right.gen(ctx)
278298
val test = if (left.dataType.isInstanceOf[DecimalType]) {
279-
s"${eval2.primitiveTerm}.isZero()"
299+
s"${eval2.primitive}.isZero()"
280300
} else {
281-
s"${eval2.primitiveTerm} == 0"
301+
s"${eval2.primitive} == 0"
282302
}
283303
val method = if (left.dataType.isInstanceOf[DecimalType]) {
284304
s".$decimalMethod"
@@ -287,12 +307,12 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
287307
}
288308
eval1.code + eval2.code +
289309
s"""
290-
boolean ${ev.nullTerm} = false;
291-
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(left.dataType)};
292-
if (${eval1.nullTerm} || ${eval2.nullTerm} || $test) {
293-
${ev.nullTerm} = true;
310+
boolean ${ev.isNull} = false;
311+
${ctx.primitiveType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
312+
if (${eval1.isNull} || ${eval2.isNull} || $test) {
313+
${ev.isNull} = true;
294314
} else {
295-
${ev.primitiveTerm} = ${eval1.primitiveTerm}$method(${eval2.primitiveTerm});
315+
${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive});
296316
}
297317
"""
298318
}
@@ -387,6 +407,10 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
387407
((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any]
388408
}
389409

410+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
411+
defineCodeGen(ctx, ev, c => s"(${ctx.primitiveType(dataType)})~($c)")
412+
}
413+
390414
protected override def evalInternal(evalE: Any) = not(evalE)
391415
}
392416

@@ -419,21 +443,21 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
419443
val eval1 = left.gen(ctx)
420444
val eval2 = right.gen(ctx)
421445
eval1.code + eval2.code + s"""
422-
boolean ${ev.nullTerm} = false;
423-
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
446+
boolean ${ev.isNull} = false;
447+
${ctx.primitiveType(left.dataType)} ${ev.primitive} =
424448
${ctx.defaultValue(left.dataType)};
425449

426-
if (${eval1.nullTerm}) {
427-
${ev.nullTerm} = ${eval2.nullTerm};
428-
${ev.primitiveTerm} = ${eval2.primitiveTerm};
429-
} else if (${eval2.nullTerm}) {
430-
${ev.nullTerm} = ${eval1.nullTerm};
431-
${ev.primitiveTerm} = ${eval1.primitiveTerm};
450+
if (${eval1.isNull}) {
451+
${ev.isNull} = ${eval2.isNull};
452+
${ev.primitive} = ${eval2.primitive};
453+
} else if (${eval2.isNull}) {
454+
${ev.isNull} = ${eval1.isNull};
455+
${ev.primitive} = ${eval1.primitive};
432456
} else {
433-
if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) {
434-
${ev.primitiveTerm} = ${eval1.primitiveTerm};
457+
if (${eval1.primitive} > ${eval2.primitive}) {
458+
${ev.primitive} = ${eval1.primitive};
435459
} else {
436-
${ev.primitiveTerm} = ${eval2.primitiveTerm};
460+
${ev.primitive} = ${eval2.primitive};
437461
}
438462
}
439463
"""
@@ -475,21 +499,21 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
475499
val eval2 = right.gen(ctx)
476500

477501
eval1.code + eval2.code + s"""
478-
boolean ${ev.nullTerm} = false;
479-
${ctx.primitiveType(left.dataType)} ${ev.primitiveTerm} =
502+
boolean ${ev.isNull} = false;
503+
${ctx.primitiveType(left.dataType)} ${ev.primitive} =
480504
${ctx.defaultValue(left.dataType)};
481505

482-
if (${eval1.nullTerm}) {
483-
${ev.nullTerm} = ${eval2.nullTerm};
484-
${ev.primitiveTerm} = ${eval2.primitiveTerm};
485-
} else if (${eval2.nullTerm}) {
486-
${ev.nullTerm} = ${eval1.nullTerm};
487-
${ev.primitiveTerm} = ${eval1.primitiveTerm};
506+
if (${eval1.isNull}) {
507+
${ev.isNull} = ${eval2.isNull};
508+
${ev.primitive} = ${eval2.primitive};
509+
} else if (${eval2.isNull}) {
510+
${ev.isNull} = ${eval1.isNull};
511+
${ev.primitive} = ${eval1.primitive};
488512
} else {
489-
if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) {
490-
${ev.primitiveTerm} = ${eval1.primitiveTerm};
513+
if (${eval1.primitive} < ${eval2.primitive}) {
514+
${ev.primitive} = ${eval1.primitive};
491515
} else {
492-
${ev.primitiveTerm} = ${eval2.primitiveTerm};
516+
${ev.primitive} = ${eval2.primitive};
493517
}
494518
}
495519
"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long]
3535
* Java source for evaluating an [[Expression]] given a [[Row]] of input.
3636
*
3737
* @param code The sequence of statements required to evaluate the expression.
38-
* @param nullTerm A term that holds a boolean value representing whether the expression evaluated
38+
* @param isNull A term that holds a boolean value representing whether the expression evaluated
3939
* to null.
40-
* @param primitiveTerm A term for a possible primitive value of the result of the evaluation. Not
41-
* valid if `nullTerm` is set to `true`.
40+
* @param primitive A term for a possible primitive value of the result of the evaluation. Not
41+
* valid if `isNull` is set to `true`.
4242
*/
43-
case class GeneratedExpressionCode(var code: Code, var nullTerm: Term, var primitiveTerm: Term)
43+
case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term)
4444

4545
/**
4646
* A context for codegen, which is used to bookkeeping the expressions those are not supported
@@ -149,9 +149,9 @@ class CodeGenContext {
149149
def defaultValue(dt: DataType): Term = dt match {
150150
case BooleanType => "false"
151151
case FloatType => "-1.0f"
152-
case ShortType => "-1"
153-
case LongType => "-1"
154-
case ByteType => "-1"
152+
case ShortType => "(short)-1"
153+
case LongType => "-1L"
154+
case ByteType => "(byte)-1"
155155
case DoubleType => "-1.0"
156156
case IntegerType => "-1"
157157
case DateType => "-1"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
4040
val evaluationCode = e.gen(ctx)
4141
evaluationCode.code +
4242
s"""
43-
if(${evaluationCode.nullTerm})
43+
if(${evaluationCode.isNull})
4444
mutableRow.setNullAt($i);
4545
else
46-
mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitiveTerm)};
46+
mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)};
4747
"""
4848
}.mkString("\n")
4949
val code = s"""

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
5959
case BinaryType =>
6060
s"""
6161
{
62-
byte[] x = ${if (asc) evalA.primitiveTerm else evalB.primitiveTerm};
63-
byte[] y = ${if (!asc) evalB.primitiveTerm else evalA.primitiveTerm};
62+
byte[] x = ${if (asc) evalA.primitive else evalB.primitive};
63+
byte[] y = ${if (!asc) evalB.primitive else evalA.primitive};
6464
int j = 0;
6565
while (j < x.length && j < y.length) {
6666
if (x[j] != y[j]) return x[j] - y[j];
@@ -73,16 +73,16 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
7373
}"""
7474
case _: NumericType =>
7575
s"""
76-
if (${evalA.primitiveTerm} != ${evalB.primitiveTerm}) {
77-
if (${evalA.primitiveTerm} > ${evalB.primitiveTerm}) {
76+
if (${evalA.primitive} != ${evalB.primitive}) {
77+
if (${evalA.primitive} > ${evalB.primitive}) {
7878
return ${if (asc) "1" else "-1"};
7979
} else {
8080
return ${if (asc) "-1" else "1"};
8181
}
8282
}"""
8383
case _ =>
8484
s"""
85-
int comp = ${evalA.primitiveTerm}.compare(${evalB.primitiveTerm});
85+
int comp = ${evalA.primitive}.compare(${evalB.primitive});
8686
if (comp != 0) {
8787
return ${if (asc) "comp" else "-comp"};
8888
}"""
@@ -93,11 +93,11 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit
9393
${evalA.code}
9494
i = $b;
9595
${evalB.code}
96-
if (${evalA.nullTerm} && ${evalB.nullTerm}) {
96+
if (${evalA.isNull} && ${evalB.isNull}) {
9797
// Nothing
98-
} else if (${evalA.nullTerm}) {
98+
} else if (${evalA.isNull}) {
9999
return ${if (order.direction == Ascending) "-1" else "1"};
100-
} else if (${evalB.nullTerm}) {
100+
} else if (${evalB.isNull}) {
101101
return ${if (order.direction == Ascending) "1" else "-1"};
102102
} else {
103103
$compare

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (Row) => Boolean] {
5555
@Override
5656
public boolean eval(Row i) {
5757
${eval.code}
58-
return !${eval.nullTerm} && ${eval.primitiveTerm};
58+
return !${eval.isNull} && ${eval.primitive};
5959
}
6060
}"""
6161

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
5555
{
5656
// column$i
5757
${eval.code}
58-
nullBits[$i] = ${eval.nullTerm};
59-
if(!${eval.nullTerm}) {
60-
c$i = ${eval.primitiveTerm};
58+
nullBits[$i] = ${eval.isNull};
59+
if (!${eval.isNull}) {
60+
c$i = ${eval.primitive};
6161
}
6262
}
6363
"""
@@ -122,7 +122,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
122122
case LongType => s"$col ^ ($col >>> 32)"
123123
case FloatType => s"Float.floatToIntBits($col)"
124124
case DoubleType =>
125-
s"Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32)"
125+
s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))"
126126
case _ => s"$col.hashCode()"
127127
}
128128
s"isNullAt($i) ? 0 : ($nonNull)"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,13 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
6262
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
6363
val eval = child.gen(ctx)
6464
eval.code + s"""
65-
boolean ${ev.nullTerm} = ${eval.nullTerm};
66-
${ctx.decimalType} ${ev.primitiveTerm} = null;
65+
boolean ${ev.isNull} = ${eval.isNull};
66+
${ctx.decimalType} ${ev.primitive} = null;
6767

68-
if (!${ev.nullTerm}) {
69-
${ev.primitiveTerm} = (new ${ctx.decimalType}()).setOrNull(
70-
${eval.primitiveTerm}, $precision, $scale);
71-
${ev.nullTerm} = ${ev.primitiveTerm} == null;
68+
if (!${ev.isNull}) {
69+
${ev.primitive} = (new ${ctx.decimalType}()).setOrNull(
70+
${eval.primitive}, $precision, $scale);
71+
${ev.isNull} = ${ev.primitive} == null;
7272
}
7373
"""
7474
}

0 commit comments

Comments
 (0)