Skip to content

Commit cf0c112

Browse files
committed
Merge pull request #9 from cloud-fan/ds-to-df
Fixes Python UDT
2 parents cd63810 + 4c8d139 commit cf0c112

File tree

2 files changed

+78
-72
lines changed

2 files changed

+78
-72
lines changed

python/pyspark/sql/tests.py

Lines changed: 70 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -528,76 +528,76 @@ def check_datatype(datatype):
528528
_verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
529529
self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
530530

531-
# def test_infer_schema_with_udt(self):
532-
# from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
533-
# row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
534-
# df = self.sqlCtx.createDataFrame([row])
535-
# schema = df.schema
536-
# field = [f for f in schema.fields if f.name == "point"][0]
537-
# self.assertEqual(type(field.dataType), ExamplePointUDT)
538-
# df.registerTempTable("labeled_point")
539-
# point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
540-
# self.assertEqual(point, ExamplePoint(1.0, 2.0))
541-
542-
# row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
543-
# df = self.sqlCtx.createDataFrame([row])
544-
# schema = df.schema
545-
# field = [f for f in schema.fields if f.name == "point"][0]
546-
# self.assertEqual(type(field.dataType), PythonOnlyUDT)
547-
# df.registerTempTable("labeled_point")
548-
# point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
549-
# self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
550-
551-
# def test_apply_schema_with_udt(self):
552-
# from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
553-
# row = (1.0, ExamplePoint(1.0, 2.0))
554-
# schema = StructType([StructField("label", DoubleType(), False),
555-
# StructField("point", ExamplePointUDT(), False)])
556-
# df = self.sqlCtx.createDataFrame([row], schema)
557-
# point = df.head().point
558-
# self.assertEqual(point, ExamplePoint(1.0, 2.0))
559-
560-
# row = (1.0, PythonOnlyPoint(1.0, 2.0))
561-
# schema = StructType([StructField("label", DoubleType(), False),
562-
# StructField("point", PythonOnlyUDT(), False)])
563-
# df = self.sqlCtx.createDataFrame([row], schema)
564-
# point = df.head().point
565-
# self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
566-
567-
# def test_udf_with_udt(self):
568-
# from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
569-
# row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
570-
# df = self.sqlCtx.createDataFrame([row])
571-
# self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
572-
# udf = UserDefinedFunction(lambda p: p.y, DoubleType())
573-
# self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
574-
# udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
575-
# self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
576-
577-
# row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
578-
# df = self.sqlCtx.createDataFrame([row])
579-
# self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
580-
# udf = UserDefinedFunction(lambda p: p.y, DoubleType())
581-
# self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
582-
# udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
583-
# self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
584-
585-
# def test_parquet_with_udt(self):
586-
# from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
587-
# row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
588-
# df0 = self.sqlCtx.createDataFrame([row])
589-
# output_dir = os.path.join(self.tempdir.name, "labeled_point")
590-
# df0.write.parquet(output_dir)
591-
# df1 = self.sqlCtx.read.parquet(output_dir)
592-
# point = df1.head().point
593-
# self.assertEqual(point, ExamplePoint(1.0, 2.0))
594-
595-
# row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
596-
# df0 = self.sqlCtx.createDataFrame([row])
597-
# df0.write.parquet(output_dir, mode='overwrite')
598-
# df1 = self.sqlCtx.read.parquet(output_dir)
599-
# point = df1.head().point
600-
# self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
531+
def test_infer_schema_with_udt(self):
532+
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
533+
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
534+
df = self.sqlCtx.createDataFrame([row])
535+
schema = df.schema
536+
field = [f for f in schema.fields if f.name == "point"][0]
537+
self.assertEqual(type(field.dataType), ExamplePointUDT)
538+
df.registerTempTable("labeled_point")
539+
point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
540+
self.assertEqual(point, ExamplePoint(1.0, 2.0))
541+
542+
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
543+
df = self.sqlCtx.createDataFrame([row])
544+
schema = df.schema
545+
field = [f for f in schema.fields if f.name == "point"][0]
546+
self.assertEqual(type(field.dataType), PythonOnlyUDT)
547+
df.registerTempTable("labeled_point")
548+
point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
549+
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
550+
551+
def test_apply_schema_with_udt(self):
552+
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
553+
row = (1.0, ExamplePoint(1.0, 2.0))
554+
schema = StructType([StructField("label", DoubleType(), False),
555+
StructField("point", ExamplePointUDT(), False)])
556+
df = self.sqlCtx.createDataFrame([row], schema)
557+
point = df.head().point
558+
self.assertEqual(point, ExamplePoint(1.0, 2.0))
559+
560+
row = (1.0, PythonOnlyPoint(1.0, 2.0))
561+
schema = StructType([StructField("label", DoubleType(), False),
562+
StructField("point", PythonOnlyUDT(), False)])
563+
df = self.sqlCtx.createDataFrame([row], schema)
564+
point = df.head().point
565+
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
566+
567+
def test_udf_with_udt(self):
568+
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
569+
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
570+
df = self.sqlCtx.createDataFrame([row])
571+
self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
572+
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
573+
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
574+
udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
575+
self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
576+
577+
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
578+
df = self.sqlCtx.createDataFrame([row])
579+
self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
580+
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
581+
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
582+
udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
583+
self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
584+
585+
def test_parquet_with_udt(self):
586+
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
587+
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
588+
df0 = self.sqlCtx.createDataFrame([row])
589+
output_dir = os.path.join(self.tempdir.name, "labeled_point")
590+
df0.write.parquet(output_dir)
591+
df1 = self.sqlCtx.read.parquet(output_dir)
592+
point = df1.head().point
593+
self.assertEqual(point, ExamplePoint(1.0, 2.0))
594+
595+
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
596+
df0 = self.sqlCtx.createDataFrame([row])
597+
df0.write.parquet(output_dir, mode='overwrite')
598+
df1 = self.sqlCtx.read.parquet(output_dir)
599+
point = df1.head().point
600+
self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
601601

602602
def test_unionAll_with_udt(self):
603603
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ object RowEncoder {
5252
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
5353
FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
5454

55+
case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType)
56+
5557
case udt: UserDefinedType[_] =>
5658
val obj = NewInstance(
5759
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
@@ -151,10 +153,14 @@ object RowEncoder {
151153

152154
private def constructorFor(schema: StructType): Expression = {
153155
val fields = schema.zipWithIndex.map { case (f, i) =>
154-
val field = BoundReference(i, f.dataType, f.nullable)
156+
val dt = f.dataType match {
157+
case p: PythonUserDefinedType => p.sqlType
158+
case other => other
159+
}
160+
val field = BoundReference(i, dt, f.nullable)
155161
If(
156162
IsNull(field),
157-
Literal.create(null, externalDataTypeFor(f.dataType)),
163+
Literal.create(null, externalDataTypeFor(dt)),
158164
constructorFor(field)
159165
)
160166
}

0 commit comments

Comments (0)