@@ -528,76 +528,76 @@ def check_datatype(datatype):
528
528
_verify_type (PythonOnlyPoint (1.0 , 2.0 ), PythonOnlyUDT ())
529
529
self .assertRaises (ValueError , lambda : _verify_type ([1.0 , 2.0 ], PythonOnlyUDT ()))
530
530
531
- # def test_infer_schema_with_udt(self):
532
- # from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
533
- # row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
534
- # df = self.sqlCtx.createDataFrame([row])
535
- # schema = df.schema
536
- # field = [f for f in schema.fields if f.name == "point"][0]
537
- # self.assertEqual(type(field.dataType), ExamplePointUDT)
538
- # df.registerTempTable("labeled_point")
539
- # point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
540
- # self.assertEqual(point, ExamplePoint(1.0, 2.0))
541
-
542
- # row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
543
- # df = self.sqlCtx.createDataFrame([row])
544
- # schema = df.schema
545
- # field = [f for f in schema.fields if f.name == "point"][0]
546
- # self.assertEqual(type(field.dataType), PythonOnlyUDT)
547
- # df.registerTempTable("labeled_point")
548
- # point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
549
- # self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
550
-
551
- # def test_apply_schema_with_udt(self):
552
- # from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
553
- # row = (1.0, ExamplePoint(1.0, 2.0))
554
- # schema = StructType([StructField("label", DoubleType(), False),
555
- # StructField("point", ExamplePointUDT(), False)])
556
- # df = self.sqlCtx.createDataFrame([row], schema)
557
- # point = df.head().point
558
- # self.assertEqual(point, ExamplePoint(1.0, 2.0))
559
-
560
- # row = (1.0, PythonOnlyPoint(1.0, 2.0))
561
- # schema = StructType([StructField("label", DoubleType(), False),
562
- # StructField("point", PythonOnlyUDT(), False)])
563
- # df = self.sqlCtx.createDataFrame([row], schema)
564
- # point = df.head().point
565
- # self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
566
-
567
- # def test_udf_with_udt(self):
568
- # from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
569
- # row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
570
- # df = self.sqlCtx.createDataFrame([row])
571
- # self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
572
- # udf = UserDefinedFunction(lambda p: p.y, DoubleType())
573
- # self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
574
- # udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
575
- # self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
576
-
577
- # row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
578
- # df = self.sqlCtx.createDataFrame([row])
579
- # self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
580
- # udf = UserDefinedFunction(lambda p: p.y, DoubleType())
581
- # self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
582
- # udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
583
- # self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
584
-
585
- # def test_parquet_with_udt(self):
586
- # from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
587
- # row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
588
- # df0 = self.sqlCtx.createDataFrame([row])
589
- # output_dir = os.path.join(self.tempdir.name, "labeled_point")
590
- # df0.write.parquet(output_dir)
591
- # df1 = self.sqlCtx.read.parquet(output_dir)
592
- # point = df1.head().point
593
- # self.assertEqual(point, ExamplePoint(1.0, 2.0))
594
-
595
- # row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
596
- # df0 = self.sqlCtx.createDataFrame([row])
597
- # df0.write.parquet(output_dir, mode='overwrite')
598
- # df1 = self.sqlCtx.read.parquet(output_dir)
599
- # point = df1.head().point
600
- # self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
531
+ def test_infer_schema_with_udt (self ):
532
+ from pyspark .sql .tests import ExamplePoint , ExamplePointUDT
533
+ row = Row (label = 1.0 , point = ExamplePoint (1.0 , 2.0 ))
534
+ df = self .sqlCtx .createDataFrame ([row ])
535
+ schema = df .schema
536
+ field = [f for f in schema .fields if f .name == "point" ][0 ]
537
+ self .assertEqual (type (field .dataType ), ExamplePointUDT )
538
+ df .registerTempTable ("labeled_point" )
539
+ point = self .sqlCtx .sql ("SELECT point FROM labeled_point" ).head ().point
540
+ self .assertEqual (point , ExamplePoint (1.0 , 2.0 ))
541
+
542
+ row = Row (label = 1.0 , point = PythonOnlyPoint (1.0 , 2.0 ))
543
+ df = self .sqlCtx .createDataFrame ([row ])
544
+ schema = df .schema
545
+ field = [f for f in schema .fields if f .name == "point" ][0 ]
546
+ self .assertEqual (type (field .dataType ), PythonOnlyUDT )
547
+ df .registerTempTable ("labeled_point" )
548
+ point = self .sqlCtx .sql ("SELECT point FROM labeled_point" ).head ().point
549
+ self .assertEqual (point , PythonOnlyPoint (1.0 , 2.0 ))
550
+
551
+ def test_apply_schema_with_udt (self ):
552
+ from pyspark .sql .tests import ExamplePoint , ExamplePointUDT
553
+ row = (1.0 , ExamplePoint (1.0 , 2.0 ))
554
+ schema = StructType ([StructField ("label" , DoubleType (), False ),
555
+ StructField ("point" , ExamplePointUDT (), False )])
556
+ df = self .sqlCtx .createDataFrame ([row ], schema )
557
+ point = df .head ().point
558
+ self .assertEqual (point , ExamplePoint (1.0 , 2.0 ))
559
+
560
+ row = (1.0 , PythonOnlyPoint (1.0 , 2.0 ))
561
+ schema = StructType ([StructField ("label" , DoubleType (), False ),
562
+ StructField ("point" , PythonOnlyUDT (), False )])
563
+ df = self .sqlCtx .createDataFrame ([row ], schema )
564
+ point = df .head ().point
565
+ self .assertEqual (point , PythonOnlyPoint (1.0 , 2.0 ))
566
+
567
+ def test_udf_with_udt (self ):
568
+ from pyspark .sql .tests import ExamplePoint , ExamplePointUDT
569
+ row = Row (label = 1.0 , point = ExamplePoint (1.0 , 2.0 ))
570
+ df = self .sqlCtx .createDataFrame ([row ])
571
+ self .assertEqual (1.0 , df .rdd .map (lambda r : r .point .x ).first ())
572
+ udf = UserDefinedFunction (lambda p : p .y , DoubleType ())
573
+ self .assertEqual (2.0 , df .select (udf (df .point )).first ()[0 ])
574
+ udf2 = UserDefinedFunction (lambda p : ExamplePoint (p .x + 1 , p .y + 1 ), ExamplePointUDT ())
575
+ self .assertEqual (ExamplePoint (2.0 , 3.0 ), df .select (udf2 (df .point )).first ()[0 ])
576
+
577
+ row = Row (label = 1.0 , point = PythonOnlyPoint (1.0 , 2.0 ))
578
+ df = self .sqlCtx .createDataFrame ([row ])
579
+ self .assertEqual (1.0 , df .rdd .map (lambda r : r .point .x ).first ())
580
+ udf = UserDefinedFunction (lambda p : p .y , DoubleType ())
581
+ self .assertEqual (2.0 , df .select (udf (df .point )).first ()[0 ])
582
+ udf2 = UserDefinedFunction (lambda p : PythonOnlyPoint (p .x + 1 , p .y + 1 ), PythonOnlyUDT ())
583
+ self .assertEqual (PythonOnlyPoint (2.0 , 3.0 ), df .select (udf2 (df .point )).first ()[0 ])
584
+
585
+ def test_parquet_with_udt (self ):
586
+ from pyspark .sql .tests import ExamplePoint , ExamplePointUDT
587
+ row = Row (label = 1.0 , point = ExamplePoint (1.0 , 2.0 ))
588
+ df0 = self .sqlCtx .createDataFrame ([row ])
589
+ output_dir = os .path .join (self .tempdir .name , "labeled_point" )
590
+ df0 .write .parquet (output_dir )
591
+ df1 = self .sqlCtx .read .parquet (output_dir )
592
+ point = df1 .head ().point
593
+ self .assertEqual (point , ExamplePoint (1.0 , 2.0 ))
594
+
595
+ row = Row (label = 1.0 , point = PythonOnlyPoint (1.0 , 2.0 ))
596
+ df0 = self .sqlCtx .createDataFrame ([row ])
597
+ df0 .write .parquet (output_dir , mode = 'overwrite' )
598
+ df1 = self .sqlCtx .read .parquet (output_dir )
599
+ point = df1 .head ().point
600
+ self .assertEqual (point , PythonOnlyPoint (1.0 , 2.0 ))
601
601
602
602
def test_unionAll_with_udt (self ):
603
603
from pyspark .sql .tests import ExamplePoint , ExamplePointUDT
0 commit comments