Skip to content

Commit bad6828

Browse files
author
Davies Liu
committed
address comments
1 parent e03edaa commit bad6828

File tree

3 files changed

+30
-29
lines changed

3 files changed

+30
-29
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
3838
}
3939

4040
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
41-
val eval = child.gen(ctx)
42-
eval.code + s"""
43-
boolean ${ev.nullTerm} = ${eval.nullTerm};
44-
long ${ev.primitiveTerm} = ${ev.nullTerm} ? -1 : ${eval.primitiveTerm}.toUnscaledLong();
45-
"""
41+
defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()")
4642
}
4743
}
4844

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,12 @@ case class And(left: Expression, right: Expression)
146146
}
147147
}
148148
}
149+
149150
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
150151
val eval1 = left.gen(ctx)
151152
val eval2 = right.gen(ctx)
153+
154+
// The result should be `false`, if any of them is `false` whenever the other is null or not.
152155
s"""
153156
${eval1.code}
154157
boolean ${ev.nullTerm} = false;
@@ -192,20 +195,21 @@ case class Or(left: Expression, right: Expression)
192195
}
193196
}
194197
}
198+
195199
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
196200
val eval1 = left.gen(ctx)
197201
val eval2 = right.gen(ctx)
202+
203+
// The result should be `true`, if any of them is `true` whenever the other is null or not.
198204
s"""
199205
${eval1.code}
200206
boolean ${ev.nullTerm} = false;
201-
boolean ${ev.primitiveTerm} = false;
207+
boolean ${ev.primitiveTerm} = true;
202208

203209
if (!${eval1.nullTerm} && ${eval1.primitiveTerm}) {
204-
${ev.primitiveTerm} = true;
205210
} else {
206211
${eval2.code}
207212
if (!${eval2.nullTerm} && ${eval2.primitiveTerm}) {
208-
${ev.primitiveTerm} = true;
209213
} else if (!${eval1.nullTerm} && !${eval2.nullTerm}) {
210214
${ev.primitiveTerm} = false;
211215
} else {
@@ -218,19 +222,6 @@ case class Or(left: Expression, right: Expression)
218222

219223
abstract class BinaryComparison extends BinaryExpression with Predicate {
220224
self: Product =>
221-
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
222-
left.dataType match {
223-
case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, {
224-
(c1, c3) => s"$c1 $symbol $c3"
225-
})
226-
case TimestampType =>
227-
// java.sql.Timestamp does not have compare()
228-
super.genCode(ctx, ev)
229-
case other => defineCodeGen (ctx, ev, {
230-
(c1, c2) => s"$c1.compare($c2) $symbol 0"
231-
})
232-
}
233-
}
234225

235226
override def checkInputDataTypes(): TypeCheckResult = {
236227
if (left.dataType != right.dataType) {
@@ -258,6 +249,20 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
258249
}
259250
}
260251

252+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = {
253+
left.dataType match {
254+
case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, {
255+
(c1, c3) => s"$c1 $symbol $c3"
256+
})
257+
case TimestampType =>
258+
// java.sql.Timestamp does not have compare()
259+
super.genCode(ctx, ev)
260+
case other => defineCodeGen (ctx, ev, {
261+
(c1, c2) => s"$c1.compare($c2) $symbol 0"
262+
})
263+
}
264+
}
265+
261266
protected def evalInternal(evalE1: Any, evalE2: Any): Any =
262267
sys.error(s"BinaryComparisons must override either eval or evalInternal")
263268
}
@@ -389,9 +394,9 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
389394
val falseEval = falseValue.gen(ctx)
390395

391396
s"""
397+
${condEval.code}
392398
boolean ${ev.nullTerm} = false;
393399
${ctx.primitiveType(dataType)} ${ev.primitiveTerm} = ${ctx.defaultValue(dataType)};
394-
${condEval.code}
395400
if (!${condEval.nullTerm} && ${condEval.primitiveTerm}) {
396401
${trueEval.code}
397402
${ev.nullTerm} = ${trueEval.nullTerm};

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,11 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
112112
val htype = ctx.primitiveType(dataType)
113113

114114
ev.nullTerm = "false"
115+
ev.primitiveTerm = setEval.primitiveTerm
115116
itemEval.code + setEval.code + s"""
116117
if (!${itemEval.nullTerm} && !${setEval.nullTerm}) {
117118
(($htype)${setEval.primitiveTerm}).add(${itemEval.primitiveTerm});
118119
}
119-
${htype} ${ev.primitiveTerm} = ($htype)${setEval.primitiveTerm};
120120
"""
121121
case _ => super.genCode(ctx, ev)
122122
}
@@ -147,10 +147,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
147147
val rightValue = iterator.next()
148148
leftEval.add(rightValue)
149149
}
150-
leftEval
151-
} else {
152-
null
153150
}
151+
leftEval
154152
} else {
155153
null
156154
}
@@ -164,10 +162,12 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
164162
val rightEval = right.gen(ctx)
165163
val htype = ctx.primitiveType(dataType)
166164

167-
ev.nullTerm = "false"
165+
ev.nullTerm = leftEval.nullTerm
166+
ev.primitiveTerm = leftEval.primitiveTerm
168167
leftEval.code + rightEval.code + s"""
169-
${htype} ${ev.primitiveTerm} = (${htype})${leftEval.primitiveTerm};
170-
${ev.primitiveTerm}.union((${htype})${rightEval.primitiveTerm});
168+
if (!${leftEval.nullTerm} && !${rightEval.nullTerm}) {
169+
${leftEval.primitiveTerm}.union((${htype})${rightEval.primitiveTerm});
170+
}
171171
"""
172172
case _ => super.genCode(ctx, ev)
173173
}

0 commit comments

Comments
 (0)