Commit 6a322a4

Author: Davies Liu (committed)

tests refactor

1 parent 3da44fc · commit 6a322a4

File tree

2 files changed (+21, -13 lines)

python/pyspark/sql/tests.py

Lines changed: 9 additions & 0 deletions
@@ -24,6 +24,7 @@
 import pydoc
 import shutil
 import tempfile
+import pickle

 import py4j

@@ -88,6 +89,14 @@ def __eq__(self, other):
             other.x == self.x and other.y == self.y


+class DataTypeTests(unittest.TestCase):
+    # regression test for SPARK-6055
+    def test_data_type_eq(self):
+        lt = LongType()
+        lt2 = pickle.loads(pickle.dumps(LongType()))
+        self.assertEquals(lt, lt2)
+
+
 class SQLTests(ReusedPySparkTestCase):

     @classmethod
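
The new DataTypeTests case is a regression test for SPARK-6055: equality between DataType instances must survive a pickle round trip. A minimal standalone sketch of the same check is shown below; it assumes a pyspark package of this era is importable on the PYTHONPATH, and the class and method names are illustrative, not part of the commit.

# Standalone sketch of the pickle round-trip check (hypothetical test name;
# assumes pyspark is importable; not part of this commit).
import pickle
import unittest

from pyspark.sql.types import LongType


class DataTypePickleTest(unittest.TestCase):

    def test_long_type_survives_pickle(self):
        lt = LongType()
        lt2 = pickle.loads(pickle.dumps(lt))
        # Equality must hold after serialization (regression: SPARK-6055).
        self.assertEqual(lt, lt2)


if __name__ == "__main__":
    unittest.main()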

python/pyspark/sql/types.py

Lines changed: 12 additions & 13 deletions
@@ -509,30 +509,31 @@ def __eq__(self, other):
 def _parse_datatype_json_string(json_string):
     """Parses the given data type JSON string.
     >>> import pickle
-    >>> LongType() == pickle.loads(pickle.dumps(LongType()))
-    True
     >>> def check_datatype(datatype):
+    ...     pickled = pickle.loads(pickle.dumps(datatype))
+    ...     assert datatype == pickled
     ...     scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
     ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
-    ...     return datatype == python_datatype
-    >>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
-    True
+    ...     assert datatype == python_datatype
+    >>> for cls in _all_primitive_types.values():
+    ...     check_datatype(cls())
+
     >>> # Simple ArrayType.
     >>> simple_arraytype = ArrayType(StringType(), True)
     >>> check_datatype(simple_arraytype)
-    True
+
     >>> # Simple MapType.
     >>> simple_maptype = MapType(StringType(), LongType())
     >>> check_datatype(simple_maptype)
-    True
+
     >>> # Simple StructType.
     >>> simple_structtype = StructType([
     ...     StructField("a", DecimalType(), False),
     ...     StructField("b", BooleanType(), True),
     ...     StructField("c", LongType(), True),
     ...     StructField("d", BinaryType(), False)])
     >>> check_datatype(simple_structtype)
-    True
+
     >>> # Complex StructType.
     >>> complex_structtype = StructType([
     ...     StructField("simpleArray", simple_arraytype, True),
@@ -541,22 +542,20 @@ def _parse_datatype_json_string(json_string):
     ...     StructField("boolean", BooleanType(), False),
     ...     StructField("withMeta", DoubleType(), False, {"name": "age"})])
     >>> check_datatype(complex_structtype)
-    True
+
     >>> # Complex ArrayType.
     >>> complex_arraytype = ArrayType(complex_structtype, True)
     >>> check_datatype(complex_arraytype)
-    True
+
     >>> # Complex MapType.
     >>> complex_maptype = MapType(complex_structtype,
     ...                           complex_arraytype, False)
     >>> check_datatype(complex_maptype)
-    True
+
     >>> check_datatype(ExamplePointUDT())
-    True
     >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
     ...                                   StructField("point", ExamplePointUDT(), False)])
     >>> check_datatype(structtype_with_udt)
-    True
     """
     return _parse_datatype_json_value(json.loads(json_string))
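
The doctest changes above follow one pattern: check_datatype now raises an AssertionError on mismatch instead of returning a boolean, so each call prints nothing on success and the literal True expected-output lines can be dropped. Below is a self-contained sketch of that assert-style doctest pattern, with the sqlCtx Scala round trip swapped for a plain pickle round trip so it runs without a Spark installation.

# Sketch of the assert-style doctest pattern; the sqlCtx.parseDataType round
# trip is replaced by a plain pickle round trip so no Spark install is needed.
import doctest
import pickle


def _pattern_demo():
    """
    >>> def check_roundtrip(value):
    ...     restored = pickle.loads(pickle.dumps(value))
    ...     assert value == restored  # raises on failure; prints nothing on success
    >>> check_roundtrip(42)
    >>> check_roundtrip({"a": [1, 2, 3]})
    """


if __name__ == "__main__":
    # A failing assert surfaces as a doctest failure with a traceback.
    doctest.testmod()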
