Skip to content

Commit e9362c2

Browse files
Ngone51HyukjinKwon
authored andcommitted
[SPARK-34319][SQL] Resolve duplicate attributes for FlatMapCoGroupsInPandas/MapInPandas
### What changes were proposed in this pull request? Resolve duplicate attributes for `FlatMapCoGroupsInPandas`. ### Why are the changes needed? When performing self-join on top of `FlatMapCoGroupsInPandas`, analysis can fail because of conflicting attributes. For example, ```scala df = spark.createDataFrame([(1, 1)], ("column", "value")) row = df.groupby("ColUmn").cogroup( df.groupby("COLUMN") ).applyInPandas(lambda r, l: r + l, "column long, value long") row.join(row).show() ``` error: ```scala ... Conflicting attributes: column#163321L,value#163322L ;; ’Join Inner :- FlatMapCoGroupsInPandas [ColUmn#163312L], [COLUMN#163312L], <lambda>(column#163312L, value#163313L, column#163312L, value#163313L), [column#163321L, value#163322L] : :- Project [ColUmn#163312L, column#163312L, value#163313L] : : +- LogicalRDD [column#163312L, value#163313L], false : +- Project [COLUMN#163312L, column#163312L, value#163313L] : +- LogicalRDD [column#163312L, value#163313L], false +- FlatMapCoGroupsInPandas [ColUmn#163312L], [COLUMN#163312L], <lambda>(column#163312L, value#163313L, column#163312L, value#163313L), [column#163321L, value#163322L] :- Project [ColUmn#163312L, column#163312L, value#163313L] : +- LogicalRDD [column#163312L, value#163313L], false +- Project [COLUMN#163312L, column#163312L, value#163313L] +- LogicalRDD [column#163312L, value#163313L], false ... ``` ### Does this PR introduce _any_ user-facing change? yes, the query like the above example won't fail. ### How was this patch tested? Adde unit tests. Closes #31429 from Ngone51/fix-conflcting-attrs-of-FlatMapCoGroupsInPandas. Lead-authored-by: yi.wu <yi.wu@databricks.com> Co-authored-by: wuyi <yi.wu@databricks.com> Signed-off-by: HyukjinKwon <gurwls223@apache.org>
1 parent 521397f commit e9362c2

File tree

4 files changed

+70
-0
lines changed

4 files changed

+70
-0
lines changed

python/pyspark/sql/tests/test_pandas_cogrouped_map.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,18 @@ def test_case_insensitive_grouping_column(self):
203203
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
204204
self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
205205

206+
def test_self_join(self):
207+
# SPARK-34319: self-join with FlatMapCoGroupsInPandas
208+
df = self.spark.createDataFrame([(1, 1)], ("column", "value"))
209+
210+
row = df.groupby("ColUmn").cogroup(
211+
df.groupby("COLUMN")
212+
).applyInPandas(lambda r, l: r + l, "column long, value long")
213+
214+
row = row.join(row).first()
215+
216+
self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
217+
206218
@staticmethod
207219
def _test_with_key(left, right, isLeft):
208220

python/pyspark/sql/tests/test_pandas_map.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,14 @@ def func(iterator):
112112
expected = df.collect()
113113
self.assertEqual(actual, expected)
114114

115+
def test_self_join(self):
116+
# SPARK-34319: self-join with MapInPandas
117+
df1 = self.spark.range(10)
118+
df2 = df1.mapInPandas(lambda iter: iter, 'id long')
119+
actual = df2.join(df2).collect()
120+
expected = df1.join(df1).collect()
121+
self.assertEqual(sorted(actual), sorted(expected))
122+
115123

116124
if __name__ == "__main__":
117125
from pyspark.sql.tests.test_pandas_map import * # noqa: F401

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,6 +1405,14 @@ class Analyzer(override val catalogManager: CatalogManager)
14051405
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
14061406
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
14071407

1408+
case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
1409+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1410+
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1411+
1412+
case oldVersion @ MapInPandas(_, output, _)
1413+
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
1414+
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
1415+
14081416
case oldVersion: Generate
14091417
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
14101418
val newOutput = oldVersion.generatorOutput.map(_.newInstance())

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,48 @@ class AnalysisSuite extends AnalysisTest with Matchers {
630630
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
631631
}
632632

633+
test("SPARK-34319: analysis fails on self-join with FlatMapCoGroupsInPandas") {
634+
val pythonUdf = PythonUDF("pyUDF", null,
635+
StructType(Seq(StructField("a", LongType))),
636+
Seq.empty,
637+
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
638+
true)
639+
val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes
640+
val project1 = Project(Seq(UnresolvedAttribute("a")), testRelation)
641+
val project2 = Project(Seq(UnresolvedAttribute("a")), testRelation2)
642+
val flatMapGroupsInPandas = FlatMapCoGroupsInPandas(
643+
Seq(UnresolvedAttribute("a")),
644+
Seq(UnresolvedAttribute("a")),
645+
pythonUdf,
646+
output,
647+
project1,
648+
project2)
649+
val left = SubqueryAlias("temp0", flatMapGroupsInPandas)
650+
val right = SubqueryAlias("temp1", flatMapGroupsInPandas)
651+
val join = Join(left, right, Inner, None, JoinHint.NONE)
652+
assertAnalysisSuccess(
653+
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
654+
}
655+
656+
test("SPARK-34319: analysis fails on self-join with MapInPandas") {
657+
val pythonUdf = PythonUDF("pyUDF", null,
658+
StructType(Seq(StructField("a", LongType))),
659+
Seq.empty,
660+
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
661+
true)
662+
val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes
663+
val project = Project(Seq(UnresolvedAttribute("a")), testRelation)
664+
val mapInPandas = MapInPandas(
665+
pythonUdf,
666+
output,
667+
project)
668+
val left = SubqueryAlias("temp0", mapInPandas)
669+
val right = SubqueryAlias("temp1", mapInPandas)
670+
val join = Join(left, right, Inner, None, JoinHint.NONE)
671+
assertAnalysisSuccess(
672+
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
673+
}
674+
633675
test("SPARK-24488 Generator with multiple aliases") {
634676
assertAnalysisSuccess(
635677
listRelation.select(Explode($"list").as("first_alias").as("second_alias")))

0 commit comments

Comments
 (0)