diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 0c39c83821d6e..a8df4fb6cd9f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -279,7 +279,8 @@ case class ArrayTransform(
       if (indexVar.isDefined) {
         indexVar.get.value.set(i)
       }
-      result.update(i, f.eval(inputRow))
+      val v = InternalRow.copyValue(f.eval(inputRow))
+      result.update(i, v)
       i += 1
     }
     result
@@ -805,7 +806,7 @@ case class TransformKeys(
     while (i < map.numElements) {
       keyVar.value.set(map.keyArray().get(i, keyVar.dataType))
       valueVar.value.set(map.valueArray().get(i, valueVar.dataType))
-      val result = functionForEval.eval(inputRow)
+      val result = InternalRow.copyValue(functionForEval.eval(inputRow))
       resultKeys.update(i, result)
       i += 1
     }
@@ -853,7 +854,8 @@ case class TransformValues(
     while (i < map.numElements) {
       keyVar.value.set(map.keyArray().get(i, keyVar.dataType))
       valueVar.value.set(map.valueArray().get(i, valueVar.dataType))
-      resultValues.update(i, functionForEval.eval(inputRow))
+      val v = InternalRow.copyValue(functionForEval.eval(inputRow))
+      resultValues.update(i, v)
       i += 1
     }
     new ArrayBasedMapData(map.keyArray(), resultValues)
@@ -1035,7 +1037,8 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
       value1Var.value.set(v1)
       value2Var.value.set(v2)
       keys.update(i, key)
-      values.update(i, functionForEval.eval(inputRow))
+      val v = InternalRow.copyValue(functionForEval.eval(inputRow))
+      values.update(i, v)
       i += 1
     }
     new ArrayBasedMapData(keys, values)
@@ -1108,7 +1111,8 @@ case class ZipWith(left: Expression, right: Expression, function: Expression)
       } else {
         rightElemVar.value.set(null)
       }
-      result.update(i, f.eval(input))
+      val v = InternalRow.copyValue(f.eval(input))
+      result.update(i, v)
       i += 1
     }
     result
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 3e0312d11d92e..553823a4c84a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2780,6 +2780,62 @@ class DataFrameSuite extends QueryTest
     )
     checkAnswer(test.select($"best_name.name"), Row("bob") :: Row("bob") :: Row("sam") :: Nil)
   }
+
+  test("SPARK-34829: Multiple applications of typed ScalaUDFs in higher order functions work") {
+    val reverse = udf((s: String) => s.reverse)
+    val reverse2 = udf((b: Bar2) => Bar2(b.s.reverse))
+
+    val df = Seq(Array("abc", "def")).toDF("array")
+    val test = df.select(transform(col("array"), s => reverse(s)))
+    checkAnswer(test, Row(Array("cba", "fed")) :: Nil)
+
+    val df2 = Seq(Array(Bar2("abc"), Bar2("def"))).toDF("array")
+    val test2 = df2.select(transform(col("array"), b => reverse2(b)))
+    checkAnswer(test2, Row(Array(Row("cba"), Row("fed"))) :: Nil)
+
+    val df3 = Seq(Map("abc" -> 1, "def" -> 2)).toDF("map")
+    val test3 = df3.select(transform_keys(col("map"), (s, _) => reverse(s)))
+    checkAnswer(test3, Row(Map("cba" -> 1, "fed" -> 2)) :: Nil)
+
+    val df4 = Seq(Map(Bar2("abc") -> 1, Bar2("def") -> 2)).toDF("map")
+    val test4 = df4.select(transform_keys(col("map"), (b, _) => reverse2(b)))
+    checkAnswer(test4, Row(Map(Row("cba") -> 1, Row("fed") -> 2)) :: Nil)
+
+    val df5 = Seq(Map(1 -> "abc", 2 -> "def")).toDF("map")
+    val test5 = df5.select(transform_values(col("map"), (_, s) => reverse(s)))
+    checkAnswer(test5, Row(Map(1 -> "cba", 2 -> "fed")) :: Nil)
+
+    val df6 = Seq(Map(1 -> Bar2("abc"), 2 -> Bar2("def"))).toDF("map")
+    val test6 = df6.select(transform_values(col("map"), (_, b) => reverse2(b)))
+    checkAnswer(test6, Row(Map(1 -> Row("cba"), 2 -> Row("fed"))) :: Nil)
+
+    val reverseThenConcat = udf((s1: String, s2: String) => s1.reverse ++ s2.reverse)
+    val reverseThenConcat2 = udf((b1: Bar2, b2: Bar2) => Bar2(b1.s.reverse ++ b2.s.reverse))
+
+    val df7 = Seq((Map(1 -> "abc", 2 -> "def"), Map(1 -> "ghi", 2 -> "jkl"))).toDF("map1", "map2")
+    val test7 =
+      df7.select(map_zip_with(col("map1"), col("map2"), (_, s1, s2) => reverseThenConcat(s1, s2)))
+    checkAnswer(test7, Row(Map(1 -> "cbaihg", 2 -> "fedlkj")) :: Nil)
+
+    val df8 = Seq((Map(1 -> Bar2("abc"), 2 -> Bar2("def")),
+      Map(1 -> Bar2("ghi"), 2 -> Bar2("jkl")))).toDF("map1", "map2")
+    val test8 =
+      df8.select(map_zip_with(col("map1"), col("map2"), (_, b1, b2) => reverseThenConcat2(b1, b2)))
+    checkAnswer(test8, Row(Map(1 -> Row("cbaihg"), 2 -> Row("fedlkj"))) :: Nil)
+
+    val df9 = Seq((Array("abc", "def"), Array("ghi", "jkl"))).toDF("array1", "array2")
+    val test9 =
+      df9.select(zip_with(col("array1"), col("array2"), (s1, s2) => reverseThenConcat(s1, s2)))
+    checkAnswer(test9, Row(Array("cbaihg", "fedlkj")) :: Nil)
+
+    val df10 = Seq((Array(Bar2("abc"), Bar2("def")), Array(Bar2("ghi"), Bar2("jkl"))))
+      .toDF("array1", "array2")
+    val test10 =
+      df10.select(zip_with(col("array1"), col("array2"), (b1, b2) => reverseThenConcat2(b1, b2)))
+    checkAnswer(test10, Row(Array(Row("cbaihg"), Row("fedlkj"))) :: Nil)
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)
+
+case class Bar2(s: String)