Skip to content

[Python] (PySpark) Support for subclasses in type_verifier #50726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 299 additions & 0 deletions python/pyspark/sql/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2700,6 +2700,305 @@ def test_struct_field_from_json(self):
self.assertEqual(repr(struct_field), "StructField('c1', StringType(), True)")


class ExtendedDataTypeVerificationTests(unittest.TestCase, PySparkErrorTestUtils):
    """Exercise ``_make_type_verifier`` with subclasses of the built-in data types.

    Every nested ``_Extended*Type`` class is an empty subclass of the matching
    built-in type.  The tests below feed instances of these subclasses to
    ``_make_type_verifier`` and expect the same accept/reject outcome that the
    cases list for each value.
    """

    class _ExtendedBooleanType(BooleanType):
        ...

    class _ExtendedByteType(ByteType):
        ...

    class _ExtendedShortType(ShortType):
        ...

    class _ExtendedIntegerType(IntegerType):
        ...

    class _ExtendedLongType(LongType):
        ...

    class _ExtendedFloatType(FloatType):
        ...

    class _ExtendedDoubleType(DoubleType):
        ...

    class _ExtendedDecimalType(DecimalType):
        ...

    class _ExtendedStringType(StringType):
        ...

    class _ExtendedCharType(CharType):
        ...

    class _ExtendedVarcharType(VarcharType):
        ...

    class _ExtendedBinaryType(BinaryType):
        ...

    class _ExtendedDateType(DateType):
        ...

    class _ExtendedTimestampType(TimestampType):
        ...

    class _ExtendedArrayType(ArrayType):
        ...

    class _ExtendedMapType(MapType):
        ...

    class _ExtendedStructType(StructType):
        ...

    def _assert_verifies(self, obj, data_type, nullable):
        """Fail the test if verifying ``obj`` against ``data_type`` raises.

        The failure is chained to the original exception so the report shows
        why the verifier rejected the value instead of only that it did.
        """
        try:
            _make_type_verifier(data_type, nullable=nullable)(obj)
        except Exception as exc:
            raise self.failureException(
                "verify_type(%s, %s, nullable=%s) raised %r" % (obj, data_type, nullable, exc)
            ) from exc

    def test_verify_type_exception_msg(self):
        """Error class and message parameters are reported for extended types."""
        # tests similar to DataTypeVerificationTests.test_verify_type_exception_msg
        with self.assertRaises(PySparkValueError) as pe:
            _make_type_verifier(self._ExtendedStringType(), nullable=False, name="test_name")(None)

        self.check_error(
            exception=pe.exception,
            errorClass="FIELD_NOT_NULLABLE_WITH_NAME",
            messageParameters={
                "field_name": "test_name",
            },
        )

        # A nested struct: the error must name the full path to the bad field.
        schema = self._ExtendedStructType(
            [
                StructField(
                    "a", self._ExtendedStructType([StructField("b", self._ExtendedIntegerType())])
                )
            ]
        )
        with self.assertRaises(PySparkTypeError) as pe:
            _make_type_verifier(schema)([["data"]])

        self.check_error(
            exception=pe.exception,
            errorClass="FIELD_DATA_TYPE_UNACCEPTABLE_WITH_NAME",
            messageParameters={
                "data_type": "_ExtendedIntegerType()",
                "field_name": "field b in field a",
                "obj": "'data'",
                "obj_type": "<class 'str'>",
            },
        )

    def test_verify_type_ok_nullable(self):
        """``None`` is accepted by a nullable verifier for each extended type."""
        # tests similar to DataTypeVerificationTests.test_verify_type_ok_nullable
        obj = None
        types = [
            self._ExtendedIntegerType(),
            self._ExtendedFloatType(),
            self._ExtendedStringType(),
            self._ExtendedStructType([]),
        ]
        for data_type in types:
            # subTest keeps going after a failure so every bad type is reported.
            with self.subTest(data_type=data_type):
                self._assert_verifies(obj, data_type, nullable=True)

    def test_verify_type_not_nullable(self):
        """Non-nullable verifiers accept/reject values as listed in the two spec tables."""
        # tests similar to DataTypeVerificationTests.test_verify_type_not_nullable
        import array
        import datetime
        import decimal

        schema = self._ExtendedStructType(
            [
                StructField("s", StringType(), nullable=False),
                StructField("i", IntegerType(), nullable=True),
            ]
        )

        class MyObj:
            """Plain object whose attributes are set from keyword arguments."""

            def __init__(self, **kwargs):
                for k, v in kwargs.items():
                    setattr(self, k, v)

        # obj, data_type
        success_spec = [
            # String
            ("", self._ExtendedStringType()),
            (1, self._ExtendedStringType()),
            (1.0, self._ExtendedStringType()),
            ([], self._ExtendedStringType()),
            ({}, self._ExtendedStringType()),
            ("", self._ExtendedStringType("UTF8_LCASE")),
            # Char
            ("", self._ExtendedCharType(10)),
            (1, self._ExtendedCharType(10)),
            (1.0, self._ExtendedCharType(10)),
            ([], self._ExtendedCharType(10)),
            ({}, self._ExtendedCharType(10)),
            # Varchar
            ("", self._ExtendedVarcharType(10)),
            (1, self._ExtendedVarcharType(10)),
            (1.0, self._ExtendedVarcharType(10)),
            ([], self._ExtendedVarcharType(10)),
            ({}, self._ExtendedVarcharType(10)),
            # Boolean
            (True, self._ExtendedBooleanType()),
            # Byte
            (-(2**7), self._ExtendedByteType()),
            (2**7 - 1, self._ExtendedByteType()),
            # Short
            (-(2**15), self._ExtendedShortType()),
            (2**15 - 1, self._ExtendedShortType()),
            # Integer
            (-(2**31), self._ExtendedIntegerType()),
            (2**31 - 1, self._ExtendedIntegerType()),
            # Long
            (-(2**63), self._ExtendedLongType()),
            (2**63 - 1, self._ExtendedLongType()),
            # Float & Double
            (1.0, self._ExtendedFloatType()),
            (1.0, self._ExtendedDoubleType()),
            # Decimal
            (decimal.Decimal("1.0"), self._ExtendedDecimalType()),
            # Binary
            (bytearray([1, 2]), self._ExtendedBinaryType()),
            # Date/Timestamp
            (datetime.date(2000, 1, 2), self._ExtendedDateType()),
            (datetime.datetime(2000, 1, 2, 3, 4), self._ExtendedDateType()),
            (datetime.datetime(2000, 1, 2, 3, 4), self._ExtendedTimestampType()),
            # Array
            ([], self._ExtendedArrayType(IntegerType())),
            (
                ["1", None],
                self._ExtendedArrayType(self._ExtendedStringType(), containsNull=True),
            ),
            ([1, 2], self._ExtendedArrayType(self._ExtendedIntegerType())),
            ((1, 2), self._ExtendedArrayType(self._ExtendedIntegerType())),
            (
                array.array("h", [1, 2]),
                self._ExtendedArrayType(self._ExtendedIntegerType()),
            ),
            # Map
            (
                {},
                self._ExtendedMapType(self._ExtendedStringType(), self._ExtendedIntegerType()),
            ),
            (
                {"a": 1},
                self._ExtendedMapType(self._ExtendedStringType(), self._ExtendedIntegerType()),
            ),
            (
                {"a": None},
                self._ExtendedMapType(
                    self._ExtendedStringType(),
                    self._ExtendedIntegerType(),
                    valueContainsNull=True,
                ),
            ),
            # Struct
            ({"s": "a", "i": 1}, schema),
            ({"s": "a", "i": None}, schema),
            ({"s": "a"}, schema),
            ({"s": "a", "f": 1.0}, schema),
            (Row(s="a", i=1), schema),
            (Row(s="a", i=None), schema),
            (["a", 1], schema),
            (["a", None], schema),
            (("a", 1), schema),
            (MyObj(s="a", i=1), schema),
            (MyObj(s="a", i=None), schema),
            (MyObj(s="a"), schema),
        ]

        # obj, data_type, exception class
        failure_spec = [
            # String (match anything but None)
            (None, self._ExtendedStringType(), ValueError),
            (None, self._ExtendedStringType("UTF8_LCASE"), ValueError),
            # CharType (match anything but None)
            (None, self._ExtendedCharType(10), ValueError),
            # VarcharType (match anything but None)
            (None, self._ExtendedVarcharType(10), ValueError),
            # UDT
            (ExamplePoint(1.0, 2.0), PythonOnlyUDT(), ValueError),
            # Boolean
            (1, self._ExtendedBooleanType(), TypeError),
            ("True", self._ExtendedBooleanType(), TypeError),
            ([1], self._ExtendedBooleanType(), TypeError),
            # Byte
            (-(2**7) - 1, self._ExtendedByteType(), ValueError),
            (2**7, self._ExtendedByteType(), ValueError),
            ("1", self._ExtendedByteType(), TypeError),
            (1.0, self._ExtendedByteType(), TypeError),
            # Short
            (-(2**15) - 1, self._ExtendedShortType(), ValueError),
            (2**15, self._ExtendedShortType(), ValueError),
            # Integer
            (-(2**31) - 1, self._ExtendedIntegerType(), ValueError),
            (2**31, self._ExtendedIntegerType(), ValueError),
            # Float & Double
            (1, self._ExtendedFloatType(), TypeError),
            (1, self._ExtendedDoubleType(), TypeError),
            # Decimal
            (1.0, self._ExtendedDecimalType(), TypeError),
            (1, self._ExtendedDecimalType(), TypeError),
            ("1.0", self._ExtendedDecimalType(), TypeError),
            # Binary
            (1, self._ExtendedBinaryType(), TypeError),
            # Date/Timestamp
            ("2000-01-02", self._ExtendedDateType(), TypeError),
            (946811040, self._ExtendedTimestampType(), TypeError),
            # Array
            (
                ["1", None],
                self._ExtendedArrayType(self._ExtendedStringType(), containsNull=False),
                ValueError,
            ),
            ([1, "2"], self._ExtendedArrayType(self._ExtendedIntegerType()), TypeError),
            # Map
            (
                {"a": 1},
                self._ExtendedMapType(self._ExtendedIntegerType(), self._ExtendedIntegerType()),
                TypeError,
            ),
            (
                {"a": "1"},
                self._ExtendedMapType(self._ExtendedStringType(), self._ExtendedIntegerType()),
                TypeError,
            ),
            (
                {"a": None},
                self._ExtendedMapType(
                    self._ExtendedStringType(),
                    self._ExtendedIntegerType(),
                    valueContainsNull=False,
                ),
                ValueError,
            ),
            # Struct
            ({"s": "a", "i": "1"}, schema, TypeError),
            (Row(s="a"), schema, ValueError),  # Row can't have missing field
            (Row(s="a", i="1"), schema, TypeError),
            (["a"], schema, ValueError),
            (["a", "1"], schema, TypeError),
            (MyObj(s="a", i="1"), schema, TypeError),
            (MyObj(s=None, i="1"), schema, ValueError),
            ([1], schema, ValueError),
        ]

        # Check success cases
        for obj, data_type in success_spec:
            with self.subTest(obj=obj, data_type=data_type):
                self._assert_verifies(obj, data_type, nullable=False)

        # Check failure cases
        for obj, data_type, exp in failure_spec:
            msg = "verify_type(%s, %s, nullable=False) == %s" % (obj, data_type, exp)
            with self.subTest(obj=obj, data_type=data_type, expected=exp):
                with self.assertRaises(exp, msg=msg):
                    _make_type_verifier(data_type, nullable=False)(obj)


class TypesTests(TypesTestsMixin, ReusedSQLTestCase):
    """Combine ``TypesTestsMixin`` with ``ReusedSQLTestCase``; defines nothing of its own."""

Expand Down
12 changes: 8 additions & 4 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2701,16 +2701,20 @@ def verify_nullability(obj: Any) -> bool:
else:
return False

_type = type(dataType)

def assert_acceptable_types(obj: Any) -> None:
assert _type in _acceptable_types, new_msg(
assert any(isinstance(dataType, _type) for _type in _acceptable_types), new_msg(
"unknown datatype: %s for object %r" % (dataType, obj)
)

def _get_supported_types() -> Tuple[Any, ...]:
for _type, data_types in _acceptable_types.items():
if isinstance(dataType, _type):
return data_types
raise PySparkTypeError("unknown datatype: %s" % dataType)

def verify_acceptable_types(obj: Any) -> None:
# subclass of them can not be fromInternal in JVM
if type(obj) not in _acceptable_types[_type]:
if type(obj) not in _get_supported_types():
if name is not None:
raise PySparkTypeError(
errorClass="FIELD_DATA_TYPE_UNACCEPTABLE_WITH_NAME",
Expand Down