Skip to content

Commit f550e03

Browse files
dsolowmaropudmsolow
committed
[SPARK-34794][SQL] Fix lambda variable name issues in nested DataFrame functions
### What changes were proposed in this pull request? To fix lambda variable name issues in nested DataFrame functions, this PR modifies code to use a global counter for `LambdaVariables` names created by higher order functions. This is the rework of apache#31887. Closes apache#31887. ### Why are the changes needed? This moves away from the current hard-coded variable names which break on nested function calls. There is currently a bug where nested transforms in particular fail (the inner variable shadows the outer variable) For this query: ``` val df = Seq( (Seq(1,2,3), Seq("a", "b", "c")) ).toDF("numbers", "letters") df.select( f.flatten( f.transform( $"numbers", (number: Column) => { f.transform( $"letters", (letter: Column) => { f.struct( number.as("number"), letter.as("letter") ) } ) } ) ).as("zipped") ).show(10, false) ``` This is the current (incorrect) output: ``` +------------------------------------------------------------------------+ |zipped | +------------------------------------------------------------------------+ |[{a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}]| +------------------------------------------------------------------------+ ``` And this is the correct output after fix: ``` +------------------------------------------------------------------------+ |zipped | +------------------------------------------------------------------------+ |[{1, a}, {1, b}, {1, c}, {2, a}, {2, b}, {2, c}, {3, a}, {3, b}, {3, c}]| +------------------------------------------------------------------------+ ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added the new test in `DataFrameFunctionsSuite`. Closes apache#32424 from maropu/pr31887. Lead-authored-by: dsolow <dsolow@sayari.com> Co-authored-by: Takeshi Yamamuro <yamamuro@apache.org> Co-authored-by: dmsolow <dsolow@sayarianalytics.com> Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
1 parent 7fd3f8f commit f550e03

File tree

3 files changed

+40
-7
lines changed

3 files changed

+40
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import java.util.Comparator
21-
import java.util.concurrent.atomic.AtomicReference
21+
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
2222

2323
import scala.collection.mutable
2424

@@ -54,6 +54,16 @@ case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])
5454
override def sql: String = name
5555
}
5656

57+
object UnresolvedNamedLambdaVariable {
58+
59+
// Counter to ensure lambda variable names are unique
60+
private val nextVarNameId = new AtomicInteger(0)
61+
62+
def freshVarName(name: String): String = {
63+
s"${name}_${nextVarNameId.getAndIncrement()}"
64+
}
65+
}
66+
5767
/**
5868
* A named lambda variable.
5969
*/

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3800,22 +3800,22 @@ object functions {
38003800
}
38013801

38023802
private def createLambda(f: Column => Column) = {
3803-
val x = UnresolvedNamedLambdaVariable(Seq("x"))
3803+
val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
38043804
val function = f(Column(x)).expr
38053805
LambdaFunction(function, Seq(x))
38063806
}
38073807

38083808
private def createLambda(f: (Column, Column) => Column) = {
3809-
val x = UnresolvedNamedLambdaVariable(Seq("x"))
3810-
val y = UnresolvedNamedLambdaVariable(Seq("y"))
3809+
val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
3810+
val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y")))
38113811
val function = f(Column(x), Column(y)).expr
38123812
LambdaFunction(function, Seq(x, y))
38133813
}
38143814

38153815
private def createLambda(f: (Column, Column, Column) => Column) = {
3816-
val x = UnresolvedNamedLambdaVariable(Seq("x"))
3817-
val y = UnresolvedNamedLambdaVariable(Seq("y"))
3818-
val z = UnresolvedNamedLambdaVariable(Seq("z"))
3816+
val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
3817+
val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y")))
3818+
val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z")))
38193819
val function = f(Column(x), Column(y), Column(z)).expr
38203820
LambdaFunction(function, Seq(x, y, z))
38213821
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3629,6 +3629,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
36293629
df.select(map(map_entries($"m"), lit(1))),
36303630
Row(Map(Seq(Row(1, "a")) -> 1)))
36313631
}
3632+
3633+
test("SPARK-34794: lambda variable name issues in nested functions") {
3634+
val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("numbers", "letters")
3635+
3636+
checkAnswer(df1.select(flatten(transform($"numbers", (number: Column) =>
3637+
transform($"letters", (letter: Column) =>
3638+
struct(number, letter))))),
3639+
Seq(Row(Seq(Row(1, "a"), Row(1, "b"), Row(2, "a"), Row(2, "b"))))
3640+
)
3641+
checkAnswer(df1.select(flatten(transform($"numbers", (number: Column, i: Column) =>
3642+
transform($"letters", (letter: Column, j: Column) =>
3643+
struct(number + j, concat(letter, i)))))),
3644+
Seq(Row(Seq(Row(1, "a0"), Row(2, "b0"), Row(2, "a1"), Row(3, "b1"))))
3645+
)
3646+
3647+
val df2 = Seq((Map("a" -> 1, "b" -> 2), Map("a" -> 2, "b" -> 3))).toDF("m1", "m2")
3648+
3649+
checkAnswer(df2.select(map_zip_with($"m1", $"m2", (k1: Column, ov1: Column, ov2: Column) =>
3650+
map_zip_with($"m1", $"m2", (k2: Column, iv1: Column, iv2: Column) =>
3651+
ov1 + iv1 + ov2 + iv2))),
3652+
Seq(Row(Map("a" -> Map("a" -> 6, "b" -> 8), "b" -> Map("a" -> 8, "b" -> 10))))
3653+
)
3654+
}
36323655
}
36333656

36343657
object DataFrameFunctionsSuite {

0 commit comments

Comments
 (0)