
Commit 6892614

sameeragarwal authored and rxin committed
[SPARK-16488] Fix codegen variable namespace collision in pmod and partitionBy
This patch fixes a variable namespace collision bug in pmod and partitionBy.

Regression test for one possible occurrence. A more general fix in `ExpressionEvalHelper.checkEvaluation` will be in a subsequent PR.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #14144 from sameeragarwal/codegen-bug.

(cherry picked from commit 9cc74f9)
Signed-off-by: Reynold Xin <rxin@databricks.com>
1 parent b37177c commit 6892614
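
For context on the fix: the old Pmod codegen templates declared a Java local literally named `r`, so when the template's output (or another generated local also named `r`) landed more than once in the same scope — a situation that, per SPARK-16488, a pmod-based partitioning expression combined with partitionBy can provoke — the generated code failed to compile. The patch switches to `ctx.freshName("remainder")`, which hands out a unique identifier per call. The sketch below is a minimal, self-contained illustration of that idea; it mocks the name generator rather than using Spark's real CodegenContext, and every name in it is illustrative.

// Minimal, self-contained sketch (not Spark's real CodegenContext) of why a
// hard-coded temp variable in a codegen template collides when the template is
// expanded more than once in one scope, and how a freshName-style counter avoids it.
object FreshNameSketch {
  private var counter = 0

  // Mimics the idea behind CodegenContext.freshName: a unique identifier per call.
  def freshName(base: String): String = { counter += 1; s"${base}_$counter" }

  // Old-style template: every expansion declares the same Java local `r`.
  def pmodHardcoded(a: String, b: String): String =
    s"int r = $a % $b; if (r < 0) { r = (r + $b) % $b; }"

  // Fixed-style template: each expansion declares its own uniquely named temp.
  def pmodFresh(a: String, b: String): String = {
    val r = freshName("remainder")
    s"int $r = $a % $b; if ($r < 0) { $r = ($r + $b) % $b; }"
  }

  def main(args: Array[String]): Unit = {
    // Two expansions spliced into one scope: the first pair redeclares `r`
    // (that generated Java would not compile); the second pair is collision-free.
    println(pmodHardcoded("x", "2") + "\n" + pmodHardcoded("y", "3"))
    println(pmodFresh("x", "2") + "\n" + pmodFresh("y", "3"))
  }
}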

2 files changed: 27 additions & 12 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 13 additions & 12 deletions
@@ -498,34 +498,35 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic wi
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+      val remainder = ctx.freshName("remainder")
       dataType match {
         case dt: DecimalType =>
           val decimalAdd = "$plus"
           s"""
-            ${ctx.javaType(dataType)} r = $eval1.remainder($eval2);
-            if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
-              ${ev.value} = (r.$decimalAdd($eval2)).remainder($eval2);
+            ${ctx.javaType(dataType)} $remainder = $eval1.remainder($eval2);
+            if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
+              ${ev.value} = ($remainder.$decimalAdd($eval2)).remainder($eval2);
             } else {
-              ${ev.value} = r;
+              ${ev.value} = $remainder;
             }
           """
         // byte and short are casted into int when add, minus, times or divide
         case ByteType | ShortType =>
           s"""
-            ${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2);
-            if (r < 0) {
-              ${ev.value} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2);
+            ${ctx.javaType(dataType)} $remainder = (${ctx.javaType(dataType)})($eval1 % $eval2);
+            if ($remainder < 0) {
+              ${ev.value} = (${ctx.javaType(dataType)})(($remainder + $eval2) % $eval2);
             } else {
-              ${ev.value} = r;
+              ${ev.value} = $remainder;
             }
           """
         case _ =>
           s"""
-            ${ctx.javaType(dataType)} r = $eval1 % $eval2;
-            if (r < 0) {
-              ${ev.value} = (r + $eval2) % $eval2;
+            ${ctx.javaType(dataType)} $remainder = $eval1 % $eval2;
+            if ($remainder < 0) {
+              ${ev.value} = ($remainder + $eval2) % $eval2;
             } else {
-              ${ev.value} = r;
+              ${ev.value} = $remainder;
             }
           """
       }

sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala

Lines changed: 14 additions & 0 deletions
@@ -424,6 +424,20 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
     spark.range(10).write.orc(dir)
   }
 
+  test("pmod with partitionBy") {
+    val spark = this.spark
+    import spark.implicits._
+
+    case class Test(a: Int, b: String)
+    val data = Seq((0, "a"), (1, "b"), (1, "a"))
+    spark.createDataset(data).createOrReplaceTempView("test")
+    sql("select * from test distribute by pmod(_1, 2)")
+      .write
+      .partitionBy("_2")
+      .mode("overwrite")
+      .parquet(dir)
+  }
+
   private def testRead(
       df: => DataFrame,
       expectedResult: Seq[String],
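
As a usage note, the same code path can be exercised outside the test harness; the following is a hedged standalone sketch of roughly that repro (the local master setting and output path are assumptions for illustration, not part of the patch).

// Standalone sketch of the repro path exercised by the regression test above.
// Assumes Spark is on the classpath; master and output path are illustrative.
import org.apache.spark.sql.SparkSession

object PmodPartitionByRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("pmod-partitionby-repro")
      .getOrCreate()
    import spark.implicits._

    // Same shape of data as the regression test: (Int, String) tuples.
    Seq((0, "a"), (1, "b"), (1, "a")).toDS().createOrReplaceTempView("test")

    // distribute by pmod(...) plus partitionBy on write is the combination
    // SPARK-16488 reports as colliding before this patch.
    spark.sql("select * from test distribute by pmod(_1, 2)")
      .write
      .partitionBy("_2")
      .mode("overwrite")
      .parquet("/tmp/pmod-partitionby-repro")

    spark.stop()
  }
}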
