
Commit 17b59a9

ueshin authored and HyukjinKwon committed
[SPARK-35382][PYTHON] Fix lambda variable name issues in nested DataFrame functions in Python APIs
### What changes were proposed in this pull request?

This PR fixes the same issue as #32424.

```py
from pyspark.sql.functions import flatten, struct, transform

df = spark.sql("SELECT array(1, 2, 3) as numbers, array('a', 'b', 'c') as letters")
df.select(flatten(
    transform(
        "numbers",
        lambda number: transform(
            "letters",
            lambda letter: struct(number.alias("n"), letter.alias("l"))
        )
    )
).alias("zipped")).show(truncate=False)
```

**Before:**

```
+------------------------------------------------------------------------+
|zipped                                                                  |
+------------------------------------------------------------------------+
|[{a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}, {a, a}, {b, b}, {c, c}]|
+------------------------------------------------------------------------+
```

**After:**

```
+------------------------------------------------------------------------+
|zipped                                                                  |
+------------------------------------------------------------------------+
|[{1, a}, {1, b}, {1, c}, {2, a}, {2, b}, {2, c}, {3, a}, {3, b}, {3, c}]|
+------------------------------------------------------------------------+
```

### Why are the changes needed?

To produce the correct results.

### Does this PR introduce _any_ user-facing change?

Yes, it fixes the results to be correct as mentioned above.

### How was this patch tested?

Added a unit test, and also verified manually.

Closes #32523 from ueshin/issues/SPARK-35382/nested_higher_order_functions.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 0ab9bd7 commit 17b59a9

File tree

2 files changed (+26, -1 lines)


python/pyspark/sql/functions.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -4254,7 +4254,10 @@ def _create_lambda(f):
 
     argnames = ["x", "y", "z"]
     args = [
-        _unresolved_named_lambda_variable(arg) for arg in argnames[: len(parameters)]
+        _unresolved_named_lambda_variable(
+            expressions.UnresolvedNamedLambdaVariable.freshVarName(arg)
+        )
+        for arg in argnames[: len(parameters)]
     ]
 
     result = f(*args)
```
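To see why fresh variable names matter, here is a minimal plain-Python sketch (not Spark's actual implementation; `fresh_var_name` and `create_lambda` are hypothetical stand-ins for `UnresolvedNamedLambdaVariable.freshVarName` and `_create_lambda`). Without a unique suffix, both levels of a nested call would bind the literal name `x`, and the inner binding would shadow the outer one during resolution:

```python
import inspect
import itertools

# Process-wide counter, mimicking how a fresh-name generator keeps
# every lambda variable distinct across nested calls.
_fresh = itertools.count()

def fresh_var_name(name):
    # Hypothetical stand-in for freshVarName: append a unique suffix.
    return f"{name}_{next(_fresh)}"

def create_lambda(f):
    # Simplified model of _create_lambda: inspect the Python lambda's
    # arity and bind each parameter to a freshly named variable.
    argnames = ["x", "y", "z"]
    params = inspect.signature(f).parameters
    args = [fresh_var_name(a) for a in argnames[: len(params)]]
    return args, f(*args)

# Nested higher-order call: before the fix, both levels would bind the
# same name "x", so the inner variable would shadow the outer one.
outer_vars, (inner_vars, body) = create_lambda(
    lambda number: create_lambda(lambda letter: (number, letter))
)
print(outer_vars, inner_vars, body)
```

With the suffix, the outer and inner variables get distinct names, so the body can refer to both unambiguously.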

python/pyspark/sql/tests/test_functions.py

Lines changed: 22 additions & 0 deletions

```diff
@@ -493,6 +493,28 @@ def test_higher_order_function_failures(self):
         with self.assertRaises(ValueError):
             transform(col("foo"), lambda x: 1)
 
+    def test_nested_higher_order_function(self):
+        # SPARK-35382: lambda vars must be resolved properly in nested higher order functions
+        from pyspark.sql.functions import flatten, struct, transform
+
+        df = self.spark.sql("SELECT array(1, 2, 3) as numbers, array('a', 'b', 'c') as letters")
+
+        actual = df.select(flatten(
+            transform(
+                "numbers",
+                lambda number: transform(
+                    "letters",
+                    lambda letter: struct(number.alias("n"), letter.alias("l"))
+                )
+            )
+        )).first()[0]
+
+        expected = [(1, "a"), (1, "b"), (1, "c"),
+                    (2, "a"), (2, "b"), (2, "c"),
+                    (3, "a"), (3, "b"), (3, "c")]
+
+        self.assertEquals(actual, expected)
+
     def test_window_functions(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         w = Window.partitionBy("value").orderBy("key")
```
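The expected list in the test above is simply the Cartesian product of the two arrays: the outer `transform` iterates over `numbers`, the inner one over `letters`, and `flatten` concatenates the per-number results. A plain-Python check of that expectation (no Spark needed):

```python
from itertools import product

numbers = [1, 2, 3]
letters = ["a", "b", "c"]

# flatten(transform(numbers, n -> transform(letters, l -> (n, l))))
# yields every (number, letter) pair in order: the Cartesian product.
expected = [(n, l) for n, l in product(numbers, letters)]
print(expected)
```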
