Skip to content

Commit

Permalink
tests refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Feb 27, 2015
1 parent 3da44fc commit 6a322a4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
9 changes: 9 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pydoc
import shutil
import tempfile
import pickle

import py4j

Expand Down Expand Up @@ -88,6 +89,14 @@ def __eq__(self, other):
other.x == self.x and other.y == self.y


class DataTypeTests(unittest.TestCase):
    """Tests for the equality semantics of SQL data type objects."""

    # regression test for SPARK-6055
    def test_data_type_eq(self):
        """A LongType must compare equal to a pickle round-tripped LongType."""
        lt = LongType()
        lt2 = pickle.loads(pickle.dumps(LongType()))
        # Use assertEqual: the assertEquals alias is deprecated and was
        # removed from unittest in Python 3.12.
        self.assertEqual(lt, lt2)


class SQLTests(ReusedPySparkTestCase):

@classmethod
Expand Down
25 changes: 12 additions & 13 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,30 +509,31 @@ def __eq__(self, other):
def _parse_datatype_json_string(json_string):
"""Parses the given data type JSON string.
>>> import pickle
>>> LongType() == pickle.loads(pickle.dumps(LongType()))
True
>>> def check_datatype(datatype):
... pickled = pickle.loads(pickle.dumps(datatype))
... assert datatype == pickled
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
... python_datatype = _parse_datatype_json_string(scala_datatype.json())
... return datatype == python_datatype
>>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
True
... assert datatype == python_datatype
>>> for cls in _all_primitive_types.values():
... check_datatype(cls())
>>> # Simple ArrayType.
>>> simple_arraytype = ArrayType(StringType(), True)
>>> check_datatype(simple_arraytype)
True
>>> # Simple MapType.
>>> simple_maptype = MapType(StringType(), LongType())
>>> check_datatype(simple_maptype)
True
>>> # Simple StructType.
>>> simple_structtype = StructType([
... StructField("a", DecimalType(), False),
... StructField("b", BooleanType(), True),
... StructField("c", LongType(), True),
... StructField("d", BinaryType(), False)])
>>> check_datatype(simple_structtype)
True
>>> # Complex StructType.
>>> complex_structtype = StructType([
... StructField("simpleArray", simple_arraytype, True),
Expand All @@ -541,22 +542,20 @@ def _parse_datatype_json_string(json_string):
... StructField("boolean", BooleanType(), False),
... StructField("withMeta", DoubleType(), False, {"name": "age"})])
>>> check_datatype(complex_structtype)
True
>>> # Complex ArrayType.
>>> complex_arraytype = ArrayType(complex_structtype, True)
>>> check_datatype(complex_arraytype)
True
>>> # Complex MapType.
>>> complex_maptype = MapType(complex_structtype,
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
True
>>> check_datatype(ExamplePointUDT())
True
>>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
... StructField("point", ExamplePointUDT(), False)])
>>> check_datatype(structtype_with_udt)
True
"""
return _parse_datatype_json_value(json.loads(json_string))

Expand Down

0 comments on commit 6a322a4

Please sign in to comment.