Skip to content

Commit 9cc74f9

Browse files
sameeragarwal authored and rxin committed
[SPARK-16488] Fix codegen variable namespace collision in pmod and partitionBy
## What changes were proposed in this pull request?

This patch fixes a variable namespace collision bug in pmod and partitionBy.

## How was this patch tested?

Regression test for one possible occurrence. A more general fix in `ExpressionEvalHelper.checkEvaluation` will be in a subsequent PR.

Author: Sameer Agarwal <sameer@databricks.com>

Closes apache#14144 from sameeragarwal/codegen-bug.
1 parent e50efd5 commit 9cc74f9

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -498,34 +498,35 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic wi
498498

499499
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
500500
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
501+
val remainder = ctx.freshName("remainder")
501502
dataType match {
502503
case dt: DecimalType =>
503504
val decimalAdd = "$plus"
504505
s"""
505-
${ctx.javaType(dataType)} r = $eval1.remainder($eval2);
506-
if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
507-
${ev.value} = (r.$decimalAdd($eval2)).remainder($eval2);
506+
${ctx.javaType(dataType)} $remainder = $eval1.remainder($eval2);
507+
if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
508+
${ev.value} = ($remainder.$decimalAdd($eval2)).remainder($eval2);
508509
} else {
509-
${ev.value} = r;
510+
${ev.value} = $remainder;
510511
}
511512
"""
512513
// byte and short are casted into int when add, minus, times or divide
513514
case ByteType | ShortType =>
514515
s"""
515-
${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2);
516-
if (r < 0) {
517-
${ev.value} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2);
516+
${ctx.javaType(dataType)} $remainder = (${ctx.javaType(dataType)})($eval1 % $eval2);
517+
if ($remainder < 0) {
518+
${ev.value} = (${ctx.javaType(dataType)})(($remainder + $eval2) % $eval2);
518519
} else {
519-
${ev.value} = r;
520+
${ev.value} = $remainder;
520521
}
521522
"""
522523
case _ =>
523524
s"""
524-
${ctx.javaType(dataType)} r = $eval1 % $eval2;
525-
if (r < 0) {
526-
${ev.value} = (r + $eval2) % $eval2;
525+
${ctx.javaType(dataType)} $remainder = $eval1 % $eval2;
526+
if ($remainder < 0) {
527+
${ev.value} = ($remainder + $eval2) % $eval2;
527528
} else {
528-
${ev.value} = r;
529+
${ev.value} = $remainder;
529530
}
530531
"""
531532
}

sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,20 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
449449
}
450450
}
451451

452+
test("pmod with partitionBy") {
453+
val spark = this.spark
454+
import spark.implicits._
455+
456+
case class Test(a: Int, b: String)
457+
val data = Seq((0, "a"), (1, "b"), (1, "a"))
458+
spark.createDataset(data).createOrReplaceTempView("test")
459+
sql("select * from test distribute by pmod(_1, 2)")
460+
.write
461+
.partitionBy("_2")
462+
.mode("overwrite")
463+
.parquet(dir)
464+
}
465+
452466
private def testRead(
453467
df: => DataFrame,
454468
expectedResult: Seq[String],

0 commit comments

Comments (0)