Commit bce8e64

Add minimal support for connect_dfs, without changing all type annotations

Signed-off-by: Filipe Oliveira <filipe_oliveira@mckinsey.com>
filipeo2-mck committed Aug 2, 2024
1 parent 97221a6 commit bce8e64
Showing 6 changed files with 83 additions and 93 deletions.
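Not part of the diff below, just a minimal end-to-end sketch of what this commit enables: validating a DataFrame obtained through Spark Connect. The endpoint, schema class, and column names are illustrative, and it assumes pyspark>=3.4 with the Spark Connect client installed; it follows the documented pandera.pyspark DataFrameModel API.

# Illustrative only: validate a Spark Connect DataFrame with pandera.
import pandera.pyspark as pa
import pyspark.sql.types as T
from pyspark.sql import SparkSession

# Hypothetical Connect endpoint; any builder.remote(...) session works the same way.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()


class ProductSchema(pa.DataFrameModel):
    id: T.IntegerType() = pa.Field(gt=0)
    name: T.StringType() = pa.Field()


df = spark.createDataFrame(
    [(1, "widget"), (2, "gadget")], schema="id int, name string"
)
validated = ProductSchema.validate(df)
print(validated.pandera.errors)  # empty when the data passes all checks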
20 changes: 19 additions & 1 deletion pandera/accessors/pyspark_sql_accessor.py
@@ -2,8 +2,10 @@
"""

import warnings
from packaging import version
from typing import Optional

import pyspark
from pandera.api.base.error_handler import ErrorHandler
from pandera.api.pyspark.container import DataFrameSchema

@@ -104,7 +106,7 @@ def decorator(accessor):

def register_dataframe_accessor(name):
"""
Register a custom accessor with a DataFrame
Register a custom accessor with a classical Spark DataFrame
:param name: name used when calling the accessor after its registered
:returns: a class decorator callable.
@@ -115,6 +117,19 @@ def register_dataframe_accessor(name):
return _register_accessor(name, DataFrame)


def register_connect_dataframe_accessor(name):
"""
Register a custom accessor with a Spark Connect DataFrame
:param name: name used when calling the accessor after its registered
:returns: a class decorator callable.
"""
# pylint: disable=import-outside-toplevel
from pyspark.sql.connect.dataframe import DataFrame as psc_DataFrame

return _register_accessor(name, psc_DataFrame)


class PanderaDataFrameAccessor(PanderaAccessor):
"""Pandera accessor for pyspark DataFrame."""

@@ -127,3 +142,6 @@ def check_schema_type(schema):


register_dataframe_accessor("pandera")(PanderaDataFrameAccessor)
# Handle optional Spark Connect imports for pyspark>=3.4 (if available)
if version.parse(pyspark.__version__) >= version.parse("3.4"):
register_connect_dataframe_accessor("pandera")(PanderaDataFrameAccessor)
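For orientation (not part of the commit), the two registration helpers are used the same way; the accessor name and class below are made up for illustration:

# Hypothetical accessor registered for classic and, when available, Connect DataFrames.
from packaging import version

import pyspark

from pandera.accessors.pyspark_sql_accessor import (
    register_connect_dataframe_accessor,
    register_dataframe_accessor,
)


class ShapeAccessor:
    """Toy accessor exposing df.shape_info.n_columns()."""

    def __init__(self, df):
        self._df = df

    def n_columns(self) -> int:
        return len(self._df.columns)


register_dataframe_accessor("shape_info")(ShapeAccessor)
if version.parse(pyspark.__version__) >= version.parse("3.4"):
    register_connect_dataframe_accessor("shape_info")(ShapeAccessor)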
18 changes: 16 additions & 2 deletions pandera/api/pyspark/types.py
@@ -1,14 +1,27 @@
"""Utility functions for pyspark validation."""

from functools import lru_cache
from numpy import bool_ as np_bool
from packaging import version
from typing import List, NamedTuple, Tuple, Type, Union

import pyspark
import pyspark.sql.types as pst
from pyspark.sql import DataFrame

from pandera.api.checks import Check
from pandera.dtypes import DataType

# Handles optional Spark Connect imports for pyspark>=3.4 (if available)
import pyspark

if version.parse(pyspark.__version__) >= version.parse("3.4"):
from pyspark.sql.connect.dataframe import DataFrame as psc_DataFrame
else:
from pyspark.sql import DataFrame as psc_DataFrame

DataFrameTypes = Union[DataFrame, psc_DataFrame]

CheckList = Union[Check, List[Check]]

PysparkDefaultTypes = Union[
@@ -57,7 +70,7 @@
class PysparkDataframeColumnObject(NamedTuple):
"""Pyspark Object which holds dataframe and column value in a named tuble"""

dataframe: DataFrame
dataframe: DataFrameTypes
column_name: str


@@ -69,6 +82,7 @@ def supported_types() -> SupportedTypes:

try:
table_types.append(DataFrame)
table_types.append(psc_DataFrame)

except ImportError: # pragma: no cover
pass
@@ -89,4 +103,4 @@ def is_table(obj):

def is_bool(x):
"""Verifies whether an object is a boolean type."""
return isinstance(x, (bool, type(pst.BooleanType())))
return isinstance(x, (bool, type(pst.BooleanType()), np_bool))
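A short sketch (not in the diff) of how the widened helpers read from calling code: DataFrameTypes covers both DataFrame flavors, and is_bool now also accepts numpy booleans.

# Illustrative use of the widened type helpers.
from numpy import bool_ as np_bool

from pandera.api.pyspark.types import DataFrameTypes, is_bool


def describe(df: DataFrameTypes) -> str:
    # Accepts either a classic pyspark DataFrame or a Spark Connect one.
    return f"table with columns: {df.columns}"


assert is_bool(True)
assert is_bool(np_bool(True))  # numpy booleans now count as booleans too
assert not is_bool("yes")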
5 changes: 3 additions & 2 deletions pandera/backends/pyspark/base.py
@@ -22,6 +22,7 @@
scalar_failure_case,
)
from pandera.errors import FailureCaseMetadata, SchemaError, SchemaWarning
from pandera.api.pyspark.types import DataFrameTypes


class ColumnInfo(NamedTuple):
@@ -34,7 +35,7 @@ class ColumnInfo(NamedTuple):
lazy_exclude_column_names: List


FieldCheckObj = Union[col, DataFrame]
FieldCheckObj = Union[col, DataFrameTypes]

T = TypeVar(
"T",
@@ -50,7 +51,7 @@ class PysparkSchemaBackend(BaseSchemaBackend):

def subsample(
self,
check_obj: DataFrame,
check_obj: DataFrameTypes,
head: Optional[int] = None,
tail: Optional[int] = None,
sample: Optional[float] = None,
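Not the backend's actual subsampling logic, just a sketch of why the union annotation is sufficient here: both DataFrame flavors expose the same limit() and sample() methods, so code written against DataFrameTypes runs unchanged on either. The helper name is hypothetical.

from typing import Optional

from pandera.api.pyspark.types import DataFrameTypes


def head_or_sample(
    check_obj: DataFrameTypes,
    head: Optional[int] = None,
    sample: Optional[float] = None,
) -> DataFrameTypes:
    # Hypothetical helper mirroring the subsample() parameters above.
    if head is not None:
        return check_obj.limit(head)
    if sample is not None:
        return check_obj.sample(fraction=sample)
    return check_obj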
97 changes: 26 additions & 71 deletions pandera/backends/pyspark/checks.py
@@ -1,10 +1,7 @@
"""Check backend for pyspark."""

from functools import partial
from typing import Dict, List, Optional

from multimethod import DispatchError, overload
from pyspark.sql import DataFrame
from typing import Dict, List, Optional, Union

from pandera.api.base.checks import CheckResult, GroupbyObject
from pandera.api.checks import Check
@@ -14,6 +11,7 @@
is_table,
)
from pandera.backends.base import BaseCheckBackend
from pandera.api.pyspark.types import DataFrameTypes


class PySparkCheckBackend(BaseCheckBackend):
@@ -26,7 +24,7 @@ def __init__(self, check: Check):
self.check = check
self.check_fn = partial(check._check_fn, **check._check_kwargs)

def groupby(self, check_obj: DataFrame): # pragma: no cover
def groupby(self, check_obj: DataFrameTypes): # pragma: no cover
"""Implements groupby behavior for check object."""
assert self.check.groupby is not None, "Check.groupby must be set."
if isinstance(self.check.groupby, (str, list)):
@@ -45,61 +43,34 @@ def aggregate(self, check_obj):
def _format_groupby_input(
groupby_obj: GroupbyObject,
groups: Optional[List[str]],
) -> Dict[str, DataFrame]: # pragma: no cover
) -> Dict[str, DataFrameTypes]: # pragma: no cover
raise NotImplementedError

@overload # type: ignore [no-redef]
def preprocess(
self,
check_obj: DataFrame,
check_obj: DataFrameTypes,
key: str, # type: ignore [valid-type]
) -> DataFrame:
) -> DataFrameTypes:
return check_obj

# Workaround for multimethod not supporting Optional arguments
# such as `key: Optional[str]` (fails in multimethod)
# https://github.com/coady/multimethod/issues/90
# FIXME when the multimethod supports Optional args # pylint: disable=fixme
@overload # type: ignore [no-redef]
def preprocess(
def apply(
self,
check_obj: DataFrame, # type: ignore [valid-type]
) -> DataFrame:
return check_obj

@overload
def apply(self, check_obj):
"""Apply the check function to a check object."""
raise NotImplementedError

@overload # type: ignore [no-redef]
def apply(self, check_obj: DataFrame):
return self.check_fn(check_obj) # pragma: no cover

@overload # type: ignore [no-redef]
def apply(self, check_obj: is_table): # type: ignore [valid-type]
return self.check_fn(check_obj) # pragma: no cover

@overload # type: ignore [no-redef]
def apply(self, check_obj: DataFrame, column_name: str, kwargs: dict): # type: ignore [valid-type]
# kwargs['column_name'] = column_name
# return self.check._check_fn(check_obj, *list(kwargs.values()))
check_obj_and_col_name = PysparkDataframeColumnObject(
check_obj, column_name
)
return self.check._check_fn(check_obj_and_col_name, **kwargs)
check_obj: Union[DataFrameTypes, is_table],
column_name: str = None,
kwargs: dict = None,
):
if column_name and kwargs:
check_obj_and_col_name = PysparkDataframeColumnObject(
check_obj, column_name
)
return self.check._check_fn(check_obj_and_col_name, **kwargs)

@overload
def postprocess(self, check_obj, check_output):
"""Postprocesses the result of applying the check function."""
raise TypeError( # pragma: no cover
f"output type of check_fn not recognized: {type(check_output)}"
)
else:
return self.check_fn(check_obj) # pragma: no cover

@overload # type: ignore [no-redef]
def postprocess(
self,
check_obj,
check_obj: DataFrameTypes,
check_output: is_bool, # type: ignore [valid-type]
) -> CheckResult:
"""Postprocesses the result of applying the check function."""
@@ -112,29 +83,13 @@ def postprocess(

def __call__(
self,
check_obj: DataFrame,
check_obj: DataFrameTypes,
key: Optional[str] = None,
) -> CheckResult:
if key is None:
# pylint:disable=no-value-for-parameter
check_obj = self.preprocess(check_obj)
else:
check_obj = self.preprocess(check_obj, key)

try:
if key is None:
check_output = self.apply(check_obj)
else:
check_output = (
self.apply( # pylint:disable=too-many-function-args
check_obj, key, self.check._check_kwargs
)
)

except DispatchError as exc: # pragma: no cover
if exc.__cause__ is not None:
raise exc.__cause__
raise exc
except TypeError as err:
raise err
check_obj = self.preprocess(check_obj, key)

check_output = self.apply( # pylint:disable=too-many-function-args
check_obj, key, self.check._check_kwargs
)

return self.postprocess(check_obj, check_output)
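As a reading aid (not part of the commit), the consolidated apply() boils down to the following dispatch, with the multimethod overloads and DispatchError handling removed; this is a simplified restatement, not the backend class itself.

from pandera.api.pyspark.types import PysparkDataframeColumnObject


def apply_check(check_fn, check_obj, column_name=None, kwargs=None):
    if column_name and kwargs:
        # Column-level check: wrap the dataframe and the target column so the
        # check function can see both.
        return check_fn(
            PysparkDataframeColumnObject(check_obj, column_name), **kwargs
        )
    # DataFrame-level check: hand the whole dataframe to the check function.
    return check_fn(check_obj)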
12 changes: 0 additions & 12 deletions pandera/backends/pyspark/container.py
@@ -553,18 +553,6 @@ def unique(

return check_obj

def _check_uniqueness(
self,
obj: DataFrame,
schema,
) -> DataFrame:
"""Ensure uniqueness in dataframe columns.
:param obj: dataframe to check.
:param schema: schema object.
:returns: dataframe checked.
"""

##########
# Checks #
##########
24 changes: 19 additions & 5 deletions pandera/backends/pyspark/register.py
@@ -1,8 +1,15 @@
"""Register pyspark backends."""

from functools import lru_cache
from packaging import version

import pyspark.sql as pst
import pyspark
import pyspark.sql as ps

# Handles optional Spark Connect imports for pyspark>=3.4 (if available)
CURRENT_PYSPARK_VERSION = version.parse(pyspark.__version__)
if CURRENT_PYSPARK_VERSION >= version.parse("3.4"):
from pyspark.sql.connect import dataframe as psc


@lru_cache
@@ -28,7 +35,14 @@ def register_pyspark_backends():
from pandera.backends.pyspark.components import ColumnBackend
from pandera.backends.pyspark.container import DataFrameSchemaBackend

Check.register_backend(pst.DataFrame, PySparkCheckBackend)
ColumnSchema.register_backend(pst.DataFrame, ColumnSchemaBackend)
Column.register_backend(pst.DataFrame, ColumnBackend)
DataFrameSchema.register_backend(pst.DataFrame, DataFrameSchemaBackend)
# Register classical DataFrame
Check.register_backend(ps.DataFrame, PySparkCheckBackend)
ColumnSchema.register_backend(ps.DataFrame, ColumnSchemaBackend)
Column.register_backend(ps.DataFrame, ColumnBackend)
DataFrameSchema.register_backend(ps.DataFrame, DataFrameSchemaBackend)
# Register Spark Connect DataFrame, if available
if CURRENT_PYSPARK_VERSION >= version.parse("3.4"):
Check.register_backend(psc.DataFrame, PySparkCheckBackend)
ColumnSchema.register_backend(psc.DataFrame, ColumnSchemaBackend)
Column.register_backend(psc.DataFrame, ColumnBackend)
DataFrameSchema.register_backend(psc.DataFrame, DataFrameSchemaBackend)
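Once these registrations run, schema validation resolves a backend from the concrete type of the object it receives, so classic and Connect DataFrames route to the same pyspark backends. A minimal sketch of how calling code relies on that (the registration call is cached by lru_cache, so repeats are no-ops):

# Illustrative: explicit registration is idempotent.
from pandera.backends.pyspark.register import register_pyspark_backends

register_pyspark_backends()
register_pyspark_backends()  # cached by @lru_cache, returns immediately

# From here on, DataFrameSchema.validate() picks the registered pyspark backend
# for a pyspark.sql.DataFrame and, on pyspark>=3.4, for a
# pyspark.sql.connect.dataframe.DataFrame as well.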
