diff --git a/pandera/accessors/pyspark_sql_accessor.py b/pandera/accessors/pyspark_sql_accessor.py
index d59dbc99..338ac29e 100644
--- a/pandera/accessors/pyspark_sql_accessor.py
+++ b/pandera/accessors/pyspark_sql_accessor.py
@@ -2,8 +2,10 @@
 """
 
 import warnings
+from packaging import version
 from typing import Optional
 
+import pyspark
 from pandera.api.base.error_handler import ErrorHandler
 from pandera.api.pyspark.container import DataFrameSchema
 
@@ -104,7 +106,7 @@ def decorator(accessor):
 
 def register_dataframe_accessor(name):
     """
-    Register a custom accessor with a DataFrame
+    Register a custom accessor with a classical Spark DataFrame
 
     :param name: name used when calling the accessor after its registered
     :returns: a class decorator callable.
@@ -115,6 +117,19 @@ def register_dataframe_accessor(name):
     return _register_accessor(name, DataFrame)
 
 
+def register_connect_dataframe_accessor(name):
+    """
+    Register a custom accessor with a Spark Connect DataFrame
+
+    :param name: name used when calling the accessor after its registered
+    :returns: a class decorator callable.
+    """
+    # pylint: disable=import-outside-toplevel
+    from pyspark.sql.connect.dataframe import DataFrame as psc_DataFrame
+
+    return _register_accessor(name, psc_DataFrame)
+
+
 class PanderaDataFrameAccessor(PanderaAccessor):
     """Pandera accessor for pyspark DataFrame."""
 
@@ -127,3 +142,6 @@ def check_schema_type(schema):
 
 
 register_dataframe_accessor("pandera")(PanderaDataFrameAccessor)
+# Handle optional Spark Connect imports for pyspark>=3.4 (if available)
+if version.parse(pyspark.__version__) >= version.parse("3.4"):
+    register_connect_dataframe_accessor("pandera")(PanderaDataFrameAccessor)
diff --git a/pandera/api/pyspark/types.py b/pandera/api/pyspark/types.py
index 6ae1afb9..a3283959 100644
--- a/pandera/api/pyspark/types.py
+++ b/pandera/api/pyspark/types.py
@@ -1,14 +1,27 @@
 """Utility functions for pyspark validation."""
 
 from functools import lru_cache
+from numpy import bool_ as np_bool
+from packaging import version
 from typing import List, NamedTuple, Tuple, Type, Union
 
+import pyspark
 import pyspark.sql.types as pst
 from pyspark.sql import DataFrame
 
 from pandera.api.checks import Check
 from pandera.dtypes import DataType
 
+# Handles optional Spark Connect imports for pyspark>=3.4 (if available)
+import pyspark
+
+if version.parse(pyspark.__version__) >= version.parse("3.4"):
+    from pyspark.sql.connect.dataframe import DataFrame as psc_DataFrame
+else:
+    from pyspark.sql import DataFrame as psc_DataFrame
+
+DataFrameTypes = Union[DataFrame, psc_DataFrame]
+
 CheckList = Union[Check, List[Check]]
 
 PysparkDefaultTypes = Union[
@@ -57,7 +70,7 @@ class PysparkDataframeColumnObject(NamedTuple):
     """Pyspark Object which holds dataframe and column value in a named tuble"""
 
-    dataframe: DataFrame
+    dataframe: DataFrameTypes
     column_name: str
 
 
@@ -69,6 +82,7 @@ def supported_types() -> SupportedTypes:
 
     try:
         table_types.append(DataFrame)
+        table_types.append(psc_DataFrame)
 
     except ImportError:  # pragma: no cover
         pass
@@ -89,4 +103,4 @@ def is_table(obj):
 
 def is_bool(x):
     """Verifies whether an object is a boolean type."""
-    return isinstance(x, (bool, type(pst.BooleanType())))
+    return isinstance(x, (bool, type(pst.BooleanType()), np_bool))
diff --git a/pandera/backends/pyspark/base.py b/pandera/backends/pyspark/base.py
index 1a31e792..80f4099d 100644
--- a/pandera/backends/pyspark/base.py
+++ b/pandera/backends/pyspark/base.py
@@ -22,6 +22,7 @@
     scalar_failure_case,
 )
 from pandera.errors import FailureCaseMetadata, SchemaError, SchemaWarning
+from pandera.api.pyspark.types import DataFrameTypes
 
 
 class ColumnInfo(NamedTuple):
@@ -34,7 +35,7 @@ class ColumnInfo(NamedTuple):
     lazy_exclude_column_names: List
 
 
-FieldCheckObj = Union[col, DataFrame]
+FieldCheckObj = Union[col, DataFrameTypes]
 
 T = TypeVar(
     "T",
@@ -50,7 +51,7 @@ class PysparkSchemaBackend(BaseSchemaBackend):
 
     def subsample(
         self,
-        check_obj: DataFrame,
+        check_obj: DataFrameTypes,
         head: Optional[int] = None,
         tail: Optional[int] = None,
         sample: Optional[float] = None,
diff --git a/pandera/backends/pyspark/checks.py b/pandera/backends/pyspark/checks.py
index 5746e50e..9b0c5fc7 100644
--- a/pandera/backends/pyspark/checks.py
+++ b/pandera/backends/pyspark/checks.py
@@ -1,10 +1,7 @@
 """Check backend for pyspark."""
 
 from functools import partial
-from typing import Dict, List, Optional
-
-from multimethod import DispatchError, overload
-from pyspark.sql import DataFrame
+from typing import Dict, List, Optional, Union
 
 from pandera.api.base.checks import CheckResult, GroupbyObject
 from pandera.api.checks import Check
@@ -14,6 +11,7 @@
     is_table,
 )
 from pandera.backends.base import BaseCheckBackend
+from pandera.api.pyspark.types import DataFrameTypes
 
 
 class PySparkCheckBackend(BaseCheckBackend):
@@ -26,7 +24,7 @@ def __init__(self, check: Check):
         self.check = check
         self.check_fn = partial(check._check_fn, **check._check_kwargs)
 
-    def groupby(self, check_obj: DataFrame):  # pragma: no cover
+    def groupby(self, check_obj: DataFrameTypes):  # pragma: no cover
         """Implements groupby behavior for check object."""
         assert self.check.groupby is not None, "Check.groupby must be set."
         if isinstance(self.check.groupby, (str, list)):
@@ -45,61 +43,34 @@ def aggregate(self, check_obj):
     def _format_groupby_input(
         groupby_obj: GroupbyObject,
         groups: Optional[List[str]],
-    ) -> Dict[str, DataFrame]:  # pragma: no cover
+    ) -> Dict[str, DataFrameTypes]:  # pragma: no cover
         raise NotImplementedError
 
-    @overload  # type: ignore [no-redef]
     def preprocess(
         self,
-        check_obj: DataFrame,
+        check_obj: DataFrameTypes,
         key: str,  # type: ignore [valid-type]
-    ) -> DataFrame:
+    ) -> DataFrameTypes:
         return check_obj
 
-    # Workaround for multimethod not supporting Optional arguments
-    # such as `key: Optional[str]` (fails in multimethod)
-    # https://github.com/coady/multimethod/issues/90
-    # FIXME when the multimethod supports Optional args # pylint: disable=fixme
-    @overload  # type: ignore [no-redef]
-    def preprocess(
+    def apply(
         self,
-        check_obj: DataFrame,  # type: ignore [valid-type]
-    ) -> DataFrame:
-        return check_obj
-
-    @overload
-    def apply(self, check_obj):
-        """Apply the check function to a check object."""
-        raise NotImplementedError
-
-    @overload  # type: ignore [no-redef]
-    def apply(self, check_obj: DataFrame):
-        return self.check_fn(check_obj)  # pragma: no cover
-
-    @overload  # type: ignore [no-redef]
-    def apply(self, check_obj: is_table):  # type: ignore [valid-type]
-        return self.check_fn(check_obj)  # pragma: no cover
-
-    @overload  # type: ignore [no-redef]
-    def apply(self, check_obj: DataFrame, column_name: str, kwargs: dict):  # type: ignore [valid-type]
-        # kwargs['column_name'] = column_name
-        # return self.check._check_fn(check_obj, *list(kwargs.values()))
-        check_obj_and_col_name = PysparkDataframeColumnObject(
-            check_obj, column_name
-        )
-        return self.check._check_fn(check_obj_and_col_name, **kwargs)
+        check_obj: Union[DataFrameTypes, is_table],
+        column_name: str = None,
+        kwargs: dict = None,
+    ):
+        if column_name and kwargs:
+            check_obj_and_col_name = PysparkDataframeColumnObject(
+                check_obj, column_name
+            )
+            return self.check._check_fn(check_obj_and_col_name, **kwargs)
 
-    @overload
-    def postprocess(self, check_obj, check_output):
-        """Postprocesses the result of applying the check function."""
-        raise TypeError(  # pragma: no cover
-            f"output type of check_fn not recognized: {type(check_output)}"
-        )
+        else:
+            return self.check_fn(check_obj)  # pragma: no cover
 
-    @overload  # type: ignore [no-redef]
     def postprocess(
         self,
-        check_obj,
+        check_obj: DataFrameTypes,
         check_output: is_bool,  # type: ignore [valid-type]
     ) -> CheckResult:
         """Postprocesses the result of applying the check function."""
@@ -112,29 +83,13 @@ def postprocess(
 
     def __call__(
         self,
-        check_obj: DataFrame,
+        check_obj: DataFrameTypes,
         key: Optional[str] = None,
     ) -> CheckResult:
-        if key is None:
-            # pylint:disable=no-value-for-parameter
-            check_obj = self.preprocess(check_obj)
-        else:
-            check_obj = self.preprocess(check_obj, key)
-
-        try:
-            if key is None:
-                check_output = self.apply(check_obj)
-            else:
-                check_output = (
-                    self.apply(  # pylint:disable=too-many-function-args
-                        check_obj, key, self.check._check_kwargs
-                    )
-                )
-
-        except DispatchError as exc:  # pragma: no cover
-            if exc.__cause__ is not None:
-                raise exc.__cause__
-            raise exc
-        except TypeError as err:
-            raise err
+        check_obj = self.preprocess(check_obj, key)
+
+        check_output = self.apply(  # pylint:disable=too-many-function-args
+            check_obj, key, self.check._check_kwargs
+        )
+
         return self.postprocess(check_obj, check_output)
diff --git a/pandera/backends/pyspark/container.py b/pandera/backends/pyspark/container.py
index 8250f02a..768fb3f0 100644
--- a/pandera/backends/pyspark/container.py
+++ b/pandera/backends/pyspark/container.py
@@ -553,18 +553,6 @@ def unique(
 
         return check_obj
 
-    def _check_uniqueness(
-        self,
-        obj: DataFrame,
-        schema,
-    ) -> DataFrame:
-        """Ensure uniqueness in dataframe columns.
-
-        :param obj: dataframe to check.
-        :param schema: schema object.
-        :returns: dataframe checked.
-        """
-
     ##########
     # Checks #
     ##########
diff --git a/pandera/backends/pyspark/register.py b/pandera/backends/pyspark/register.py
index 4f9d8429..1a6c04f4 100644
--- a/pandera/backends/pyspark/register.py
+++ b/pandera/backends/pyspark/register.py
@@ -1,8 +1,15 @@
 """Register pyspark backends."""
 
 from functools import lru_cache
+from packaging import version
 
-import pyspark.sql as pst
+import pyspark
+import pyspark.sql as ps
+
+# Handles optional Spark Connect imports for pyspark>=3.4 (if available)
+CURRENT_PYSPARK_VERSION = version.parse(pyspark.__version__)
+if CURRENT_PYSPARK_VERSION >= version.parse("3.4"):
+    from pyspark.sql.connect import dataframe as psc
 
 
 @lru_cache
@@ -28,7 +35,14 @@ def register_pyspark_backends():
     from pandera.backends.pyspark.components import ColumnBackend
     from pandera.backends.pyspark.container import DataFrameSchemaBackend
 
-    Check.register_backend(pst.DataFrame, PySparkCheckBackend)
-    ColumnSchema.register_backend(pst.DataFrame, ColumnSchemaBackend)
-    Column.register_backend(pst.DataFrame, ColumnBackend)
-    DataFrameSchema.register_backend(pst.DataFrame, DataFrameSchemaBackend)
+    # Register classical DataFrame
+    Check.register_backend(ps.DataFrame, PySparkCheckBackend)
+    ColumnSchema.register_backend(ps.DataFrame, ColumnSchemaBackend)
+    Column.register_backend(ps.DataFrame, ColumnBackend)
+    DataFrameSchema.register_backend(ps.DataFrame, DataFrameSchemaBackend)
+    # Register Spark Connect DataFrame, if available
+    if CURRENT_PYSPARK_VERSION >= version.parse("3.4"):
+        Check.register_backend(psc.DataFrame, PySparkCheckBackend)
+        ColumnSchema.register_backend(psc.DataFrame, ColumnSchemaBackend)
+        Column.register_backend(psc.DataFrame, ColumnBackend)
+        DataFrameSchema.register_backend(psc.DataFrame, DataFrameSchemaBackend)