Skip to content

[SPARK-13410][SQL] Support unionAll for DataFrames with UDT columns. #11333

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,24 @@ def test_parquet_with_udt(self):
point = df1.head().point
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))

def test_unionAll_with_udt(self):
    """unionAll must not reject DataFrames whose schemas contain UDT columns (SPARK-13410)."""
    from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
    # Two single-row frames sharing an identical (label, point) schema.
    point_schema = StructType([
        StructField("label", DoubleType(), False),
        StructField("point", ExamplePointUDT(), False),
    ])
    first = self.sqlCtx.createDataFrame([(1.0, ExamplePoint(1.0, 2.0))], point_schema)
    second = self.sqlCtx.createDataFrame([(2.0, ExamplePoint(3.0, 4.0))], point_schema)

    merged = first.unionAll(second).orderBy("label").collect()
    expected = [
        Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
        Row(label=2.0, point=ExamplePoint(3.0, 4.0)),
    ]
    self.assertEqual(merged, expected)

def test_column_operators(self):
ci = self.df.key
cs = self.df.value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {

// A UDT "accepts" another DataType exactly when it is an instance of the very
// same UDT class — compatibility here is class identity, not structural.
override private[sql] def acceptsType(dataType: DataType) =
this.getClass == dataType.getClass

/**
 * Two UDTs are equal when `acceptsType` holds, i.e. they are instances of the
 * exact same UDT class. This lets the analyzer treat two occurrences of the
 * same UDT as the same column type (needed for unionAll over UDT columns).
 */
override def equals(other: Any): Boolean = other match {
  case that: UserDefinedType[_] => this.acceptsType(that)
  case _ => false
}

// equals above is class-identity based, so hashCode must follow suit:
// overriding equals without hashCode breaks hash-based collections
// (e.g. using a UDT as a HashMap key or in a HashSet).
override def hashCode(): Int = getClass.hashCode()
}

/**
Expand All @@ -110,4 +115,9 @@ private[sql] class PythonUserDefinedType(
("serializedClass" -> serializedPyClass) ~
("sqlType" -> sqlType.jsonValue)
}

/**
 * Two Python UDTs are equal when they wrap the same serialized Python class.
 */
override def equals(other: Any): Boolean = other match {
  case that: PythonUserDefinedType => this.pyUDT.equals(that.pyUDT)
  case _ => false
}

// Keep hashCode consistent with the pyUDT-based equals above; defining
// equals without a matching hashCode violates the Object contract.
override def hashCode(): Int = pyUDT.hashCode()
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ import org.apache.spark.sql.types._
* @param y y coordinate
*/
@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable
private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable {
  // Value equality on both coordinates so test assertions can compare points
  // produced by separate round-trips through the UDT.
  override def equals(other: Any): Boolean = other match {
    case that: ExamplePoint => this.x == that.x && this.y == that.y
    case _ => false
  }

  // hashCode must agree with the value-based equals above (Object contract);
  // standard 31-multiplier combination of the two coordinate hashes.
  override def hashCode(): Int = 31 * x.hashCode() + y.hashCode()
}

/**
* User-defined type for [[ExamplePoint]].
Expand Down
16 changes: 16 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,22 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
}

test("unionAll should union DataFrames with UDTs (SPARK-13410)") {
  // Two single-row frames with identical schemas containing a UDT column;
  // unionAll must accept them and preserve the UDT values.
  val pointSchema = StructType(Array(
    StructField("label", IntegerType, false),
    StructField("point", new ExamplePointUDT(), false)))
  val left = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(Row(1, new ExamplePoint(1.0, 2.0)))), pointSchema)
  val right = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(Row(2, new ExamplePoint(3.0, 4.0)))), pointSchema)

  checkAnswer(
    left.unionAll(right).orderBy("label"),
    Seq(Row(1, new ExamplePoint(1.0, 2.0)), Row(2, new ExamplePoint(3.0, 4.0))))
}

ignore("show") {
// This test case is intended ignored, but to make sure it compiles correctly
testData.select($"*").show()
Expand Down