
Commit 7278bc7

zhengruifeng authored and HyukjinKwon committed
[SPARK-50489][SQL][PYTHON] Fix self-join after applyInArrow
### What changes were proposed in this pull request?

Fix self-join after `applyInArrow`; the same issue for `applyInPandas` was fixed in #31429.

### Why are the changes needed?

Bug fix. Before:

```
In [1]: import pyarrow as pa

In [2]: df = spark.createDataFrame([(1, 1)], ("k", "v"))

In [3]: def arrow_func(key, table):
   ...:     return pa.Table.from_pydict({"x": [2], "y": [2]})
   ...:

In [4]: df2 = df.groupby("k").applyInArrow(arrow_func, schema="x long, y long")

In [5]: df2.show()
24/12/04 17:47:43 WARN CheckAllocator: More than one DefaultAllocationManager on classpath. Choosing first found
+---+---+
|  x|  y|
+---+---+
|  2|  2|
+---+---+

In [6]: df2.join(df2)
...
Failure when resolving conflicting references in Join:
'Join Inner
:- FlatMapGroupsInArrow [k#0L], arrow_func(k#0L, v#1L)#2, [x#3L, y#4L]
:  +- Project [k#0L, k#0L, v#1L]
:     +- LogicalRDD [k#0L, v#1L], false
+- FlatMapGroupsInArrow [k#12L], arrow_func(k#12L, v#13L)#2, [x#3L, y#4L]
   +- Project [k#12L, k#12L, v#13L]
      +- LogicalRDD [k#12L, v#13L], false

Conflicting attributes: "x", "y". SQLSTATE: XX000
    at org.apache.spark.SparkException$.internalError(SparkException.scala:92)
    at org.apache.spark.SparkException$.internalError(SparkException.scala:79)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$2(CheckAnalysis.scala:798)
```

After:

```
In [6]: df2.join(df2)
Out[6]: DataFrame[x: bigint, y: bigint, x: bigint, y: bigint]

In [7]: df2.join(df2).show()
+---+---+---+---+
|  x|  y|  x|  y|
+---+---+---+---+
|  2|  2|  2|  2|
+---+---+---+---+
```

### Does this PR introduce _any_ user-facing change?

Yes, a bug fix.

### How was this patch tested?

Added tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49056 from zhengruifeng/fix_arrow_join.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent fe904e6 commit 7278bc7
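
For reference, the repro from the description condensed into a standalone script; a minimal sketch assuming a local PySpark build that includes `applyInArrow` (plus `pyarrow` installed), with the session setup added here for self-containment:

```python
# Minimal repro for SPARK-50489, distilled from the PR description above.
# On builds without this fix, the final self-join raises an internal error
# ("Conflicting attributes: x, y"); with the fix it returns one row.
import pyarrow as pa
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame([(1, 1)], ("k", "v"))

def arrow_func(key, table):
    # Ignores the input group and returns a constant one-row Arrow table.
    return pa.Table.from_pydict({"x": [2], "y": [2]})

df2 = df.groupby("k").applyInArrow(arrow_func, schema="x long, y long")

df2.join(df2).show()
# Expected with the fix:
# +---+---+---+---+
# |  x|  y|  x|  y|
# +---+---+---+---+
# |  2|  2|  2|  2|
# +---+---+---+---+
```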

File tree: 3 files changed, +34 −0 lines


python/pyspark/sql/tests/test_arrow_cogrouped_map.py (10 additions, 0 deletions)

```diff
@@ -299,6 +299,16 @@ def summarize(left, right):
             "+---------+------------+----------+-------------+\n",
         )
 
+    def test_self_join(self):
+        df = self.spark.createDataFrame([(1, 1)], ("k", "v"))
+
+        def arrow_func(key, left, right):
+            return pa.Table.from_pydict({"x": [2], "y": [2]})
+
+        df2 = df.groupby("k").cogroup(df.groupby("k")).applyInArrow(arrow_func, "x long, y long")
+
+        self.assertEqual(df2.join(df2).count(), 1)
+
 
 class CogroupedMapInArrowTests(CogroupedMapInArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
```
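
As a usage note, the cogrouped variant exercised by this new test can be run standalone roughly as follows; a sketch mirroring the test above, assuming an existing `SparkSession` bound to `spark`:

```python
# Cogrouped-map self-join, mirroring test_self_join above;
# assumes an existing SparkSession bound to `spark`.
import pyarrow as pa

df = spark.createDataFrame([(1, 1)], ("k", "v"))

def arrow_func(key, left, right):
    # The cogrouped variant receives one Arrow table per side of the cogroup.
    return pa.Table.from_pydict({"x": [2], "y": [2]})

df2 = df.groupby("k").cogroup(df.groupby("k")).applyInArrow(arrow_func, "x long, y long")

# Before this fix, the self-join failed in analysis with conflicting
# "x"/"y" attributes; with it, the join resolves and yields one row.
assert df2.join(df2).count() == 1
```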

python/pyspark/sql/tests/test_arrow_grouped_map.py (10 additions, 0 deletions)

```diff
@@ -255,6 +255,16 @@ def foo(_):
         self.assertEqual(r.a, "hi")
         self.assertEqual(r.b, 1)
 
+    def test_self_join(self):
+        df = self.spark.createDataFrame([(1, 1)], ("k", "v"))
+
+        def arrow_func(key, table):
+            return pa.Table.from_pydict({"x": [2], "y": [2]})
+
+        df2 = df.groupby("k").applyInArrow(arrow_func, schema="x long, y long")
+
+        self.assertEqual(df2.join(df2).count(), 1)
+
 
 class GroupedMapInArrowTests(GroupedMapInArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala (14 additions, 0 deletions)

```diff
@@ -132,13 +132,27 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
         _.output.map(_.exprId.id),
         newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance())))
 
+    case f: FlatMapGroupsInArrow =>
+      deduplicateAndRenew[FlatMapGroupsInArrow](
+        existingRelations,
+        f,
+        _.output.map(_.exprId.id),
+        newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance())))
+
     case f: FlatMapCoGroupsInPandas =>
       deduplicateAndRenew[FlatMapCoGroupsInPandas](
         existingRelations,
         f,
         _.output.map(_.exprId.id),
         newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance())))
 
+    case f: FlatMapCoGroupsInArrow =>
+      deduplicateAndRenew[FlatMapCoGroupsInArrow](
+        existingRelations,
+        f,
+        _.output.map(_.exprId.id),
+        newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance())))
+
     case m: MapInPandas =>
       deduplicateAndRenew[MapInPandas](
         existingRelations,
```
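
The two new cases mirror the existing `FlatMapGroupsInPandas` and `FlatMapCoGroupsInPandas` handling: when the same Arrow plan node appears on both sides of a join, `deduplicateAndRenew` rebuilds its output attributes via `newInstance()`, so the two sides stop sharing expression IDs. One way to observe the effect from Python, as a sketch reusing the `df2` from the repro script above:

```python
# With the fix, the analyzed plan of the self-join should show one
# FlatMapGroupsInArrow side with freshly numbered x/y attribute IDs,
# instead of both sides sharing x#3L/y#4L as in the error above.
df2.join(df2).explain(extended=True)
```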
