diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 83899ad4b1b12..2720439416682 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -24,6 +24,7 @@ import pydoc import shutil import tempfile +import pickle import py4j @@ -88,6 +89,14 @@ def __eq__(self, other): other.x == self.x and other.y == self.y +class DataTypeTests(unittest.TestCase): + # regression test for SPARK-6055 + def test_data_type_eq(self): + lt = LongType() + lt2 = pickle.loads(pickle.dumps(LongType())) + self.assertEquals(lt, lt2) + + class SQLTests(ReusedPySparkTestCase): @classmethod diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 3e7c1256f3595..31a861e1feb46 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -509,22 +509,23 @@ def __eq__(self, other): def _parse_datatype_json_string(json_string): """Parses the given data type JSON string. >>> import pickle - >>> LongType() == pickle.loads(pickle.dumps(LongType())) - True >>> def check_datatype(datatype): + ... pickled = pickle.loads(pickle.dumps(datatype)) + ... assert datatype == pickled ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) - ... return datatype == python_datatype - >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) - True + ... assert datatype == python_datatype + >>> for cls in _all_primitive_types.values(): + ... check_datatype(cls()) + >>> # Simple ArrayType. >>> simple_arraytype = ArrayType(StringType(), True) >>> check_datatype(simple_arraytype) - True + >>> # Simple MapType. >>> simple_maptype = MapType(StringType(), LongType()) >>> check_datatype(simple_maptype) - True + >>> # Simple StructType. >>> simple_structtype = StructType([ ... StructField("a", DecimalType(), False), @@ -532,7 +533,7 @@ def _parse_datatype_json_string(json_string): ... StructField("c", LongType(), True), ... StructField("d", BinaryType(), False)]) >>> check_datatype(simple_structtype) - True + >>> # Complex StructType. >>> complex_structtype = StructType([ ... StructField("simpleArray", simple_arraytype, True), @@ -541,22 +542,20 @@ def _parse_datatype_json_string(json_string): ... StructField("boolean", BooleanType(), False), ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) >>> check_datatype(complex_structtype) - True + >>> # Complex ArrayType. >>> complex_arraytype = ArrayType(complex_structtype, True) >>> check_datatype(complex_arraytype) - True + >>> # Complex MapType. >>> complex_maptype = MapType(complex_structtype, ... complex_arraytype, False) >>> check_datatype(complex_maptype) - True + >>> check_datatype(ExamplePointUDT()) - True >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), ... StructField("point", ExamplePointUDT(), False)]) >>> check_datatype(structtype_with_udt) - True """ return _parse_datatype_json_value(json.loads(json_string))