
Commit 0164e0f

Fix
1 parent e9a398b commit 0164e0f

3 files changed: 40 additions, 38 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 11 additions & 1 deletion

@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.util.Comparator
-import java.util.concurrent.atomic.AtomicReference
+import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
 
 import scala.collection.mutable
 
@@ -54,6 +54,16 @@ case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])
   override def sql: String = name
 }
 
+object UnresolvedNamedLambdaVariable {
+
+  // Counter to ensure lambda variable names are unique
+  private val nextVarNameId = new AtomicInteger(0)
+
+  def freshVarName(name: String): String = {
+    s"${name}_${nextVarNameId.getAndIncrement()}"
+  }
+}
+
 /**
  * A named lambda variable.
  */
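For reference, the behavior of the new freshVarName helper can be sketched in isolation (a minimal standalone Scala sketch with illustrative object names, not the Spark source itself): one JVM-wide counter suffixes every requested base name, so repeated requests for the same base always yield distinct variable names.

import java.util.concurrent.atomic.AtomicInteger

// Standalone sketch (FreshNames/FreshNamesDemo are illustrative names): one shared
// counter appends a monotonically increasing id to every requested base name.
object FreshNames {
  private val nextVarNameId = new AtomicInteger(0)

  def freshVarName(name: String): String =
    s"${name}_${nextVarNameId.getAndIncrement()}"
}

object FreshNamesDemo extends App {
  println(FreshNames.freshVarName("x")) // x_0
  println(FreshNames.freshVarName("x")) // x_1, no collision with the first "x"
  println(FreshNames.freshVarName("y")) // y_2
}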

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 6 additions & 11 deletions

@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql
 
-import java.util.concurrent.atomic.AtomicInteger
-
 import scala.collection.JavaConverters._
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.Try
@@ -3801,26 +3799,23 @@ object functions {
     ArrayExcept(col1.expr, col2.expr)
   }
 
-  // counter to ensure lambda variable names are unique
-  private val lambdaVarNameCounter = new AtomicInteger(0)
-
   private def createLambda(f: Column => Column) = {
-    val x = UnresolvedNamedLambdaVariable(Seq("x_" + lambdaVarNameCounter.incrementAndGet()))
+    val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
     val function = f(Column(x)).expr
     LambdaFunction(function, Seq(x))
   }
 
   private def createLambda(f: (Column, Column) => Column) = {
-    val x = UnresolvedNamedLambdaVariable(Seq("x_" + lambdaVarNameCounter.incrementAndGet()))
-    val y = UnresolvedNamedLambdaVariable(Seq("y_" + lambdaVarNameCounter.incrementAndGet()))
+    val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
+    val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y")))
     val function = f(Column(x), Column(y)).expr
     LambdaFunction(function, Seq(x, y))
   }
 
   private def createLambda(f: (Column, Column, Column) => Column) = {
-    val x = UnresolvedNamedLambdaVariable(Seq("x_" + lambdaVarNameCounter.incrementAndGet()))
-    val y = UnresolvedNamedLambdaVariable(Seq("y_" + lambdaVarNameCounter.incrementAndGet()))
-    val z = UnresolvedNamedLambdaVariable(Seq("z_" + lambdaVarNameCounter.incrementAndGet()))
+    val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x")))
+    val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y")))
+    val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z")))
     val function = f(Column(x), Column(y), Column(z)).expr
     LambdaFunction(function, Seq(x, y, z))
  }
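To see why a single shared fresh-name helper matters for nested lambdas, here is a dependency-free mimic of the createLambda pattern above (the Expr, Var, Add, and Lambda types are hypothetical stand-ins, not Spark's expression classes): each lambda argument is materialized with a freshly generated name, so an inner lambda built while the outer closure runs cannot reuse the outer argument's name.

import java.util.concurrent.atomic.AtomicInteger

// Dependency-free mimic of the createLambda pattern (illustrative types only).
object CreateLambdaSketch {
  sealed trait Expr
  case class Var(name: String) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr
  case class Lambda(body: Expr, args: Seq[Var]) extends Expr

  private val nextVarNameId = new AtomicInteger(0)
  private def freshVarName(base: String): String = s"${base}_${nextVarNameId.getAndIncrement()}"

  // Mirrors the one-argument createLambda: the argument gets a fresh name
  // before the user function is applied to it.
  def createLambda(f: Expr => Expr): Lambda = {
    val x = Var(freshVarName("x"))
    Lambda(f(x), Seq(x))
  }

  def main(args: Array[String]): Unit = {
    // Nesting two lambdas: the inner argument is named independently of the outer one.
    val nested = createLambda(outer => createLambda(inner => Add(outer, inner)))
    println(nested)
    // Lambda(Lambda(Add(Var(x_0),Var(x_1)),List(Var(x_1))),List(Var(x_0)))
  }
}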

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 23 additions & 26 deletions

@@ -2261,32 +2261,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     assert(ex3.getMessage.contains("cannot resolve 'a'"))
   }
 
-  test("nested transform (DSL)") {
-    val df = Seq(
-      (Seq(1, 2, 3), Seq("a", "b", "c"))
-    ).toDF("numbers", "letters")
-
-    checkAnswer(
-      df.select(
-        flatten(
-          transform(
-            $"numbers",
-            (number: Column) => transform(
-              $"letters",
-              (letter: Column) => struct(
-                number.as("number"),
-                letter.as("letter")
-              )
-            )
-          )
-        ).as("zipped")
-      ),
-      Seq(Row(Seq(Row(1, "a"), Row(1, "b"), Row(1, "c"), Row(2, "a"), Row(2, "b"),
-        Row(2, "c"), Row(3, "a"), Row(3, "b"), Row(3, "c")
-      )))
-    )
-  }
-
   test("map_filter") {
     val dfInts = Seq(
       Map(1 -> 10, 2 -> 20, 3 -> 30),
@@ -3655,6 +3629,29 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       df.select(map(map_entries($"m"), lit(1))),
       Row(Map(Seq(Row(1, "a")) -> 1)))
   }
+
+  test("SPARK-34794: lambda variable name issues in nested functions") {
+    val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("numbers", "letters")
+
+    checkAnswer(df1.select(flatten(transform($"numbers", (number: Column) =>
+      transform($"letters", (letter: Column) =>
+        struct(number, letter))))),
+      Seq(Row(Seq(Row(1, "a"), Row(1, "b"), Row(2, "a"), Row(2, "b"))))
+    )
+    checkAnswer(df1.select(flatten(transform($"numbers", (number: Column, i: Column) =>
+      transform($"letters", (letter: Column, j: Column) =>
+        struct(number + j, concat(letter, i)))))),
+      Seq(Row(Seq(Row(1, "a0"), Row(2, "b0"), Row(2, "a1"), Row(3, "b1"))))
+    )
+
+    val df2 = Seq((Map("a" -> 1, "b" -> 2), Map("a" -> 2, "b" -> 3))).toDF("m1", "m2")
+
+    checkAnswer(df2.select(map_zip_with($"m1", $"m2", (k1: Column, ov1: Column, ov2: Column) =>
+      map_zip_with($"m1", $"m2", (k2: Column, iv1: Column, iv2: Column) =>
+        ov1 + iv1 + ov2 + iv2))),
+      Seq(Row(Map("a" -> Map("a" -> 6, "b" -> 8), "b" -> Map("a" -> 8, "b" -> 10))))
+    )
+  }
 }
 
 object DataFrameFunctionsSuite {
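The new test drives the fix through the DataFrame DSL; roughly the same check can be reproduced outside the suite in a small local application (a hedged usage sketch; the session setup, names, and data are illustrative):

import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.{flatten, transform}

// Hedged usage sketch: a tiny local DataFrame and a nested transform through
// the Scala DSL, the pattern whose lambda naming this commit changes.
object NestedTransformDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("nested-transform-demo").getOrCreate()
  import spark.implicits._

  val df = Seq((Seq(1, 2), Seq(10, 20))).toDF("xs", "ys")

  // For each x in xs, add it to every y in ys, then flatten the nested arrays.
  df.select(
    flatten(transform($"xs", (x: Column) =>
      transform($"ys", (y: Column) => x + y))).as("sums")
  ).show(false)
  // Expected single row: [11, 21, 12, 22]

  spark.stop()
}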
