diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 7dd42eecde7f8..4e2d3b9ba42a2 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -130,6 +130,20 @@ def test_self_join_II(self):
         self.assertTrue(df3.columns, ["aa", "b", "a", "b"])
         self.assertTrue(df3.count() == 2)
 
+    def test_self_join_III(self):
+        df1 = self.spark.range(10).withColumn("value", lit(1))
+        df2 = df1.union(df1)
+        df3 = df1.join(df2, df1.id == df2.id, "left")
+        self.assertTrue(df3.columns, ["id", "value", "id", "value"])
+        self.assertTrue(df3.count() == 20)
+
+    def test_self_join_IV(self):
+        df1 = self.spark.range(10).withColumn("value", lit(1))
+        df2 = df1.withColumn("value", lit(2)).union(df1.withColumn("value", lit(3)))
+        df3 = df1.join(df2, df1.id == df2.id, "right")
+        self.assertTrue(df3.columns, ["id", "value", "id", "value"])
+        self.assertTrue(df3.count() == 20)
+
     def test_duplicated_column_names(self):
         df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
         row = df.select("*").first()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index c10e000a098c9..1947c884694bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -585,7 +585,12 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
       }
       (resolved.map(r => (r, currentDepth)), true)
     } else {
-      resolveDataFrameColumnByPlanId(u, id, isMetadataAccess, p.children, currentDepth + 1)
+      val children = p match {
+        // treat Union node as the leaf node
+        case _: Union => Seq.empty[LogicalPlan]
+        case _ => p.children
+      }
+      resolveDataFrameColumnByPlanId(u, id, isMetadataAccess, children, currentDepth + 1)
     }
 
     // In self join case like:
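
For context, a minimal sketch of the scenario the Scala change addresses, assuming a local SparkSession; the data, join, and expected count are taken from test_self_join_III above, and everything else (session setup, the standalone-script framing) is illustrative, not part of the patch. When one side of a self join is a Union built from the other side, per-DataFrame references such as df2.id are resolved by plan id; stopping that recursion at the Union node keeps df2.id bound to the Union's own output instead of letting the resolver descend into the branches and match df1's plan again.

# Minimal reproduction sketch (hypothetical standalone script, not part of the patch).
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit

spark = SparkSession.builder.master("local[1]").getOrCreate()

df1 = spark.range(10).withColumn("value", lit(1))
df2 = df1.union(df1)  # a Union whose branches are both df1

# df1.id and df2.id must resolve to different attributes even though df2
# is derived from df1; with Union treated as a leaf during plan-id
# resolution, df2.id resolves against the Union's output.
df3 = df1.join(df2, df1.id == df2.id, "left")

# Each of df1's 10 ids matches once per union branch, per test_self_join_III.
assert df3.count() == 20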