Skip to content

Commit 6df4ec0

Browse files
dsolowmaropudmsolow
committed
[SPARK-34794][SQL] Fix lambda variable name issues in nested DataFrame functions
### What changes were proposed in this pull request? To fix lambda variable name issues in nested DataFrame functions, this PR modifies code to use a global counter for `LambdaVariables` names created by higher order functions. This is the rework of #31887. Closes #31887. ### Why are the changes needed? This moves away from the current hard-coded variable names which break on nested function calls. There is currently a bug where nested transforms in particular fail (the inner variable shadows the outer variable) For this query: ``` val df = Seq( (Seq(1,2,3), Seq("a", "b", "c")) ).toDF("numbers", "letters") df.select( f.flatten( f.transform( $"numbers", (number: Column) => { f.transform( $"letters", (letter: Column) => { f.struct( number.as("number"), letter.as("letter") ) } ) } ) ).as("zipped") ).show(10, false) ``` This is the current (incorrect) output: ``` +------------------------------------------------------------------------+ |zipped | +------------------------------------------------------------------------+ |[{a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}]| +------------------------------------------------------------------------+ ``` And this is the correct output after fix: ``` +------------------------------------------------------------------------+ |zipped | +------------------------------------------------------------------------+ |[{1, a}, {1, b}, {1, c}, {2, a}, {2, b}, {2, c}, {3, a}, {3, b}, {3, c}]| +------------------------------------------------------------------------+ ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added the new test in `DataFrameFunctionsSuite`. Closes #32424 from maropu/pr31887. Lead-authored-by: dsolow <dsolow@sayari.com> Co-authored-by: Takeshi Yamamuro <yamamuro@apache.org> Co-authored-by: dmsolow <dsolow@sayarianalytics.com> Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org> (cherry picked from commit f550e03) Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
1 parent 89f5ec7 commit 6df4ec0

File tree

3 files changed

+40
-7
lines changed

3 files changed

+40
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import java.util.Comparator
21-
import java.util.concurrent.atomic.AtomicReference
21+
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
2222

2323
import scala.collection.mutable
2424

@@ -52,6 +52,16 @@ case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])
5252
override def sql: String = name
5353
}
5454

55+
object UnresolvedNamedLambdaVariable {
56+
57+
// Counter to ensure lambda variable names are unique
58+
private val nextVarNameId = new AtomicInteger(0)
59+
60+
def freshVarName(name: String): String = {
61+
s"${name}_${nextVarNameId.getAndIncrement()}"
62+
}
63+
}
64+
5565
/**
5666
* A named lambda variable.
5767
*/

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3644,22 +3644,22 @@ object functions {
36443644
}
36453645

36463646
private def createLambda(f: Column => Column) = {
3647-
val x = UnresolvedNamedLambdaVariable(Seq("x"))
3647+
val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
36483648
val function = f(Column(x)).expr
36493649
LambdaFunction(function, Seq(x))
36503650
}
36513651

36523652
private def createLambda(f: (Column, Column) => Column) = {
3653-
val x = UnresolvedNamedLambdaVariable(Seq("x"))
3654-
val y = UnresolvedNamedLambdaVariable(Seq("y"))
3653+
val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
3654+
val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y")))
36553655
val function = f(Column(x), Column(y)).expr
36563656
LambdaFunction(function, Seq(x, y))
36573657
}
36583658

36593659
private def createLambda(f: (Column, Column, Column) => Column) = {
3660-
val x = UnresolvedNamedLambdaVariable(Seq("x"))
3661-
val y = UnresolvedNamedLambdaVariable(Seq("y"))
3662-
val z = UnresolvedNamedLambdaVariable(Seq("z"))
3660+
val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
3661+
val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y")))
3662+
val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z")))
36633663
val function = f(Column(x), Column(y), Column(z)).expr
36643664
LambdaFunction(function, Seq(x, y, z))
36653665
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3629,6 +3629,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
36293629
df.select(map(map_entries($"m"), lit(1))),
36303630
Row(Map(Seq(Row(1, "a")) -> 1)))
36313631
}
3632+
3633+
test("SPARK-34794: lambda variable name issues in nested functions") {
3634+
val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("numbers", "letters")
3635+
3636+
checkAnswer(df1.select(flatten(transform($"numbers", (number: Column) =>
3637+
transform($"letters", (letter: Column) =>
3638+
struct(number, letter))))),
3639+
Seq(Row(Seq(Row(1, "a"), Row(1, "b"), Row(2, "a"), Row(2, "b"))))
3640+
)
3641+
checkAnswer(df1.select(flatten(transform($"numbers", (number: Column, i: Column) =>
3642+
transform($"letters", (letter: Column, j: Column) =>
3643+
struct(number + j, concat(letter, i)))))),
3644+
Seq(Row(Seq(Row(1, "a0"), Row(2, "b0"), Row(2, "a1"), Row(3, "b1"))))
3645+
)
3646+
3647+
val df2 = Seq((Map("a" -> 1, "b" -> 2), Map("a" -> 2, "b" -> 3))).toDF("m1", "m2")
3648+
3649+
checkAnswer(df2.select(map_zip_with($"m1", $"m2", (k1: Column, ov1: Column, ov2: Column) =>
3650+
map_zip_with($"m1", $"m2", (k2: Column, iv1: Column, iv2: Column) =>
3651+
ov1 + iv1 + ov2 + iv2))),
3652+
Seq(Row(Map("a" -> Map("a" -> 6, "b" -> 8), "b" -> Map("a" -> 8, "b" -> 10))))
3653+
)
3654+
}
36323655
}
36333656

36343657
object DataFrameFunctionsSuite {

0 commit comments

Comments
 (0)