Skip to content

Commit

Permalink
[SPARK-34829][SQL] Fix higher order function results
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
This PR fixes a correctness issue with higher order functions. The results of function expressions need to be copied in some higher order functions, because such an expression can return a reference to an internal buffer, and higher order functions can evaluate the expression multiple times.
The issue was discovered with typed `ScalaUDF`s after apache#28979.

### Why are the changes needed?
To fix a bug.

### Does this PR introduce _any_ user-facing change?
Yes, some queries return the right results again.

### How was this patch tested?
Added new UT.

Closes apache#31955 from peter-toth/SPARK-34829-fix-scalaudf-resultconversion.

Authored-by: Peter Toth <peter.toth@gmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
  • Loading branch information
peter-toth authored and dongjoon-hyun committed Mar 28, 2021
1 parent 540f1fb commit 3382190
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ case class ArrayTransform(
if (indexVar.isDefined) {
indexVar.get.value.set(i)
}
result.update(i, f.eval(inputRow))
val v = InternalRow.copyValue(f.eval(inputRow))
result.update(i, v)
i += 1
}
result
Expand Down Expand Up @@ -805,7 +806,7 @@ case class TransformKeys(
while (i < map.numElements) {
keyVar.value.set(map.keyArray().get(i, keyVar.dataType))
valueVar.value.set(map.valueArray().get(i, valueVar.dataType))
val result = functionForEval.eval(inputRow)
val result = InternalRow.copyValue(functionForEval.eval(inputRow))
resultKeys.update(i, result)
i += 1
}
Expand Down Expand Up @@ -853,7 +854,8 @@ case class TransformValues(
while (i < map.numElements) {
keyVar.value.set(map.keyArray().get(i, keyVar.dataType))
valueVar.value.set(map.valueArray().get(i, valueVar.dataType))
resultValues.update(i, functionForEval.eval(inputRow))
val v = InternalRow.copyValue(functionForEval.eval(inputRow))
resultValues.update(i, v)
i += 1
}
new ArrayBasedMapData(map.keyArray(), resultValues)
Expand Down Expand Up @@ -1035,7 +1037,8 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
value1Var.value.set(v1)
value2Var.value.set(v2)
keys.update(i, key)
values.update(i, functionForEval.eval(inputRow))
val v = InternalRow.copyValue(functionForEval.eval(inputRow))
values.update(i, v)
i += 1
}
new ArrayBasedMapData(keys, values)
Expand Down Expand Up @@ -1108,7 +1111,8 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
} else {
rightElemVar.value.set(null)
}
result.update(i, f.eval(input))
val v = InternalRow.copyValue(f.eval(input))
result.update(i, v)
i += 1
}
result
Expand Down
56 changes: 56 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2780,6 +2780,62 @@ class DataFrameSuite extends QueryTest
)
checkAnswer(test.select($"best_name.name"), Row("bob") :: Row("bob") :: Row("sam") :: Nil)
}

// Regression test for SPARK-34829: results of function expressions evaluated inside
// higher order functions (transform, transform_keys, transform_values, map_zip_with,
// zip_with) must be copied, because a typed ScalaUDF can return a reference to an
// internal buffer that is overwritten on the next evaluation. Without the copy, every
// element of the result ends up holding the value of the LAST evaluation.
// Each pair of sub-cases below exercises a primitive-returning UDF (String) and a
// struct-returning UDF (Bar2) — the struct case is the one that exposed the bug.
test("SPARK-34829: Multiple applications of typed ScalaUDFs in higher order functions work") {
val reverse = udf((s: String) => s.reverse)
val reverse2 = udf((b: Bar2) => Bar2(b.s.reverse))

// transform over an array: each element must be independently reversed.
val df = Seq(Array("abc", "def")).toDF("array")
val test = df.select(transform(col("array"), s => reverse(s)))
checkAnswer(test, Row(Array("cba", "fed")) :: Nil)

// Same, but with a struct element type (Bar2) — the case that regressed.
val df2 = Seq(Array(Bar2("abc"), Bar2("def"))).toDF("array")
val test2 = df2.select(transform(col("array"), b => reverse2(b)))
checkAnswer(test2, Row(Array(Row("cba"), Row("fed"))) :: Nil)

// transform_keys: the UDF is applied once per map key.
val df3 = Seq(Map("abc" -> 1, "def" -> 2)).toDF("map")
val test3 = df3.select(transform_keys(col("map"), (s, _) => reverse(s)))
checkAnswer(test3, Row(Map("cba" -> 1, "fed" -> 2)) :: Nil)

val df4 = Seq(Map(Bar2("abc") -> 1, Bar2("def") -> 2)).toDF("map")
val test4 = df4.select(transform_keys(col("map"), (b, _) => reverse2(b)))
checkAnswer(test4, Row(Map(Row("cba") -> 1, Row("fed") -> 2)) :: Nil)

// transform_values: the UDF is applied once per map value.
val df5 = Seq(Map(1 -> "abc", 2 -> "def")).toDF("map")
val test5 = df5.select(transform_values(col("map"), (_, s) => reverse(s)))
checkAnswer(test5, Row(Map(1 -> "cba", 2 -> "fed")) :: Nil)

val df6 = Seq(Map(1 -> Bar2("abc"), 2 -> Bar2("def"))).toDF("map")
val test6 = df6.select(transform_values(col("map"), (_, b) => reverse2(b)))
checkAnswer(test6, Row(Map(1 -> Row("cba"), 2 -> Row("fed"))) :: Nil)

// Two-argument UDFs for the zip-style higher order functions below.
val reverseThenConcat = udf((s1: String, s2: String) => s1.reverse ++ s2.reverse)
val reverseThenConcat2 = udf((b1: Bar2, b2: Bar2) => Bar2(b1.s.reverse ++ b2.s.reverse))

// map_zip_with: the UDF is evaluated once per shared key of the two maps.
val df7 = Seq((Map(1 -> "abc", 2 -> "def"), Map(1 -> "ghi", 2 -> "jkl"))).toDF("map1", "map2")
val test7 =
df7.select(map_zip_with(col("map1"), col("map2"), (_, s1, s2) => reverseThenConcat(s1, s2)))
checkAnswer(test7, Row(Map(1 -> "cbaihg", 2 -> "fedlkj")) :: Nil)

val df8 = Seq((Map(1 -> Bar2("abc"), 2 -> Bar2("def")),
Map(1 -> Bar2("ghi"), 2 -> Bar2("jkl")))).toDF("map1", "map2")
val test8 =
df8.select(map_zip_with(col("map1"), col("map2"), (_, b1, b2) => reverseThenConcat2(b1, b2)))
checkAnswer(test8, Row(Map(1 -> Row("cbaihg"), 2 -> Row("fedlkj"))) :: Nil)

// zip_with: the UDF is evaluated once per zipped pair of array elements.
val df9 = Seq((Array("abc", "def"), Array("ghi", "jkl"))).toDF("array1", "array2")
val test9 =
df9.select(zip_with(col("array1"), col("array2"), (s1, s2) => reverseThenConcat(s1, s2)))
checkAnswer(test9, Row(Array("cbaihg", "fedlkj")) :: Nil)

val df10 = Seq((Array(Bar2("abc"), Bar2("def")), Array(Bar2("ghi"), Bar2("jkl"))))
.toDF("array1", "array2")
val test10 =
df10.select(zip_with(col("array1"), col("array2"), (b1, b2) => reverseThenConcat2(b1, b2)))
checkAnswer(test10, Row(Array(Row("cbaihg"), Row("fedlkj"))) :: Nil)
}
}

// Test fixture: a simple two-field key type used by DataFrameSuite tests.
case class GroupByKey(a: Int, b: Int)

// Test fixture: a single-field struct type, used in the SPARK-34829 test to exercise
// struct-returning typed ScalaUDFs inside higher order functions.
case class Bar2(s: String)

0 comments on commit 3382190

Please sign in to comment.