Skip to content

Commit

Permalink
⚡ Added flag for easier validity check
Browse files Browse the repository at this point in the history
  • Loading branch information
Luanee committed Nov 9, 2023
1 parent ca293f1 commit fd50fad
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 6 deletions.
44 changes: 39 additions & 5 deletions pandera_report/validator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import (
Callable,
cast,
Literal,
Optional,
overload,
Type,
TypedDict,
Union,
Expand Down Expand Up @@ -42,6 +44,8 @@ def __init__(
self._col_status = self._columns["status"]
self._parser = parser or DefaultFailureCaseParser()

self._is_valid = None

@property
def columns(self) -> TypedDict:
"""
Expand All @@ -52,7 +56,30 @@ def columns(self) -> TypedDict:
"""
return self._columns

def validate(self, schema: Union[Type[pa.DataFrameModel], pa.DataFrameSchema], df: pd.DataFrame) -> pd.DataFrame:
# Typing stubs for ``validate``. The return shape depends on ``validity_flag``,
# so the flag is annotated with ``Literal`` values: a plain ``bool`` annotation
# on both stubs makes them indistinguishable to a type checker, and the first
# (DataFrame-only) stub would always win.
@overload
def validate(
    self,
    schema: Union[Type[pa.DataFrameModel], pa.DataFrameSchema],
    df: pd.DataFrame,
    validity_flag: Literal[False] = False,
) -> pd.DataFrame:
    # Flag omitted or literally False: only the DataFrame is returned.
    ...

@overload
def validate(
    self,
    schema: Union[Type[pa.DataFrameModel], pa.DataFrameSchema],
    df: pd.DataFrame,
    validity_flag: Literal[True],
) -> tuple[bool, pd.DataFrame]:
    # Flag literally True: an (is_valid, DataFrame) pair is returned.
    # No default here: the implementation's default is False, so a call that
    # omits the flag must resolve to the stub above.
    ...

@overload
def validate(
    self,
    schema: Union[Type[pa.DataFrameModel], pa.DataFrameSchema],
    df: pd.DataFrame,
    validity_flag: bool,
) -> Union[tuple[bool, pd.DataFrame], pd.DataFrame]:
    # Fallback for a flag whose value is only known at runtime.
    ...

def validate(
self,
schema: Union[Type[pa.DataFrameModel], pa.DataFrameSchema],
df: pd.DataFrame,
validity_flag: bool = False,
) -> tuple[bool, pd.DataFrame] | pd.DataFrame:
"""
Validate a DataFrame using a Pandera schema and generate a quality report.
Expand All @@ -67,21 +94,29 @@ def validate(self, schema: Union[Type[pa.DataFrameModel], pa.DataFrameSchema], d
schema = schema.to_schema()

error: Optional[SchemaError | SchemaErrors] = None
is_valid = False

try:
df = schema.validate(df, lazy=self.lazy)
df_failure = pd.DataFrame()
is_valid = True
except (SchemaErrors, SchemaError) as schema_error:
df_failure = cast(pd.DataFrame, schema_error.failure_cases)
error = schema_error

if not self.quality_report:
if error:
raise error

if validity_flag:
return is_valid, df
return df

error = error if isinstance(error, SchemaError) else None
return self.assign_quality_report(df, df_failure, error)
df = self.assign_quality_report(df, df_failure, error)
if validity_flag:
return is_valid, df
return df

def assign_quality_report(
self, df: pd.DataFrame, df_failure: pd.DataFrame, error: Optional[SchemaError]
Expand All @@ -99,10 +134,9 @@ def assign_quality_report(
"""
number_of_rows = df.shape[0] or 1

if df.empty:
df_failure = df_failure[df_failure["schema_context"].str.lower() != "column"]

if not df_failure.empty:
if df.empty:
df_failure = df_failure[df_failure["schema_context"].str.lower() != "column"]
df_failure = self.validate_failure_case_dataframe(df_failure, error)
df_failure = self.transform_failure_cases_dataframe(df_failure, number_of_rows)

Expand Down
42 changes: 41 additions & 1 deletion tests/test_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pandera.errors import SchemaError, SchemaErrors
from pandera.typing import Series

from pandera_report.options import QualityColumnsOptions
from pandera_report.options import QUALITY_COLUMNS_OPTIONS, QualityColumnsOptions
from pandera_report.parser import FailureCaseParser
from pandera_report.validator import DataFrameValidator

Expand Down Expand Up @@ -45,6 +45,12 @@ def custom_check(cls, value: Series[str]) -> bool:
return value.str.split("_", expand=True).shape[1] == 2


class EmptySchemaModel(pa.DataFrameModel):
    # Three nullable columns; only the integer column requests coercion.
    column1: Series[int] = pa.Field(coerce=True, nullable=True)
    column2: Series[float] = pa.Field(nullable=True)
    column3: Series[str] = pa.Field(nullable=True)


# Non-default names for the "issues" and "status" quality columns; passed to
# DataFrameValidator(columns=...) in the tests below.
custom_columns: QualityColumnsOptions = {"issues": "what's that?", "status": "does it work?"}


Expand Down Expand Up @@ -87,3 +93,37 @@ def test_validator_validate(
org_columns += list(validator.columns.values())

assert df.columns.to_list() == org_columns


@pytest.mark.parametrize(
    "df_fixture,schema,validity,expected",
    [
        ("df_valid", schema, True, tuple),
        ("df_valid", schema, False, pd.DataFrame),
    ],
)
def test_validator_validate_flag(
    df_fixture: str,
    schema: Union[Type[pa.DataFrameModel], pa.DataFrameSchema],
    validity: bool,
    expected: Type,
    request,
):
    """Check that ``validity_flag`` selects the type of object ``validate`` returns."""
    # The fixture is looked up by name so it can be driven by parametrize.
    frame = cast(pd.DataFrame, request.getfixturevalue(df_fixture))

    outcome = DataFrameValidator().validate(schema, frame, validity_flag=validity)

    assert isinstance(outcome, expected)


@pytest.mark.parametrize(
    "columns,expected",
    [
        (None, QUALITY_COLUMNS_OPTIONS),
        (custom_columns, custom_columns),
    ],
)
def test_validator_columns(columns: Optional[QualityColumnsOptions], expected: QualityColumnsOptions):
    """Check the ``columns`` property for both the default and a custom column mapping."""
    reported_columns = DataFrameValidator(columns=columns).columns

    assert reported_columns == expected

0 comments on commit fd50fad

Please sign in to comment.