Enhancement: drop invalid rows on validate with new param #1189

Merged: 38 commits, Jun 23, 2023

Commits
72b3eed
Basic ArraySchema default for str series
kykyi Mar 22, 2023
40f851f
Add parameterised test cases for various data types
kykyi Mar 22, 2023
e18aa6c
Ensure column has a default
kykyi Mar 22, 2023
de6e211
Add some tests asserting Column.default works as expected
kykyi Mar 22, 2023
a9c8a40
Add tests asserting default causes an error when there is a dtype mis…
kykyi Mar 22, 2023
2e210fa
Remove inplace=True hardcoding, add default as kwarg across various c…
kykyi Mar 23, 2023
b626cb8
Simplify Column tests to avoid using DataFrameSchema
kykyi Mar 23, 2023
212fdff
Add test to raise error if inplace is False and default is non null
kykyi Mar 23, 2023
096afbb
any -> Any
kykyi Mar 28, 2023
5acb3dd
clean up PR
cosmicBboy Apr 14, 2023
8b709de
remove codecov
cosmicBboy Apr 14, 2023
e66abbd
xfail pyspark tests
cosmicBboy Apr 14, 2023
91e6250
Merge branch 'unionai-oss:main' into main
kykyi May 16, 2023
c2b6e6e
Merge branch 'unionai-oss:main' into main
kykyi Jun 4, 2023
5905b19
Simplify drop_invalid into a kwarg for schema.validate().
kykyi May 22, 2023
f86f279
Update docstrings
kykyi May 22, 2023
5efc041
Add a couple more test cases
kykyi May 23, 2023
87bce7c
Re-raise error on drop_invalid false, move some logic into a private …
kykyi Jun 4, 2023
fa24980
Add drop_invalid for SeriesSchema
kykyi Jun 4, 2023
039fd1c
Add drop_invalid to MultiIndex
kykyi Jun 4, 2023
7686b07
Small changes to fix mypy
kykyi Jun 4, 2023
478fc5e
More mypy fixes
kykyi Jun 4, 2023
bf80ef2
Move run_checks_and_handle_errors into its own method with core chec…
kykyi Jun 4, 2023
1458f6b
Remove try/catch
kykyi Jun 4, 2023
b5de710
Move drop_logic into its own method for array.py and container.py
kykyi Jun 4, 2023
5935b32
drop_invalid -> drop_invalid_data
kykyi Jun 4, 2023
0b2f6fb
Remove main() block from test_schemas.py
kykyi Jun 4, 2023
3180d31
Fix typo
kykyi Jun 4, 2023
95c4413
Add test for ColumnBackend
kykyi Jun 5, 2023
2140396
Move drop_invalid from validation to schema init
kykyi Jun 6, 2023
0a304e9
Stylistic changes
kykyi Jun 6, 2023
39072ff
Remove incorrect rescue logic in ColumnBackend
kykyi Jun 8, 2023
94394f9
Add draft docs
kykyi Jun 9, 2023
1f14cca
Add functionality for drop_invalid on DataFrameModel schemas
kykyi Jun 9, 2023
abc0324
Standardise tests
kykyi Jun 9, 2023
75b3cc7
Update docs for DataFrameModel
kykyi Jun 9, 2023
9dfba4e
Add docstrings
kykyi Jun 9, 2023
e721458
rename of `drop_invalid_rows`, exception handling, update docs
cosmicBboy Jun 23, 2023
99 changes: 99 additions & 0 deletions docs/source/drop_invalid_rows.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
.. currentmodule:: pandera

.. _drop_invalid_rows:

Dropping Invalid Rows
=====================

*New in version 0.16.0*

If you wish to use the validation step to remove invalid data, you can pass
``drop_invalid_rows=True`` to the schema object on creation. On ``schema.validate()``,
if a data-level check fails, the rows that caused the failure will be removed from
the dataframe before it is returned.

``drop_invalid_rows`` prevents data-level schema errors from being raised and
instead removes the offending rows.

This functionality is available on ``DataFrameSchema``, ``SeriesSchema``, ``Column``,
as well as ``DataFrameModel`` schemas.

Dropping invalid rows with :class:`~pandera.api.pandas.container.DataFrameSchema`:

.. testcode:: drop_invalid_rows_data_frame_schema

import pandas as pd
import pandera as pa

from pandera import Check, Column, DataFrameSchema

df = pd.DataFrame({"counter": ["1", "2", "3"]})
schema = DataFrameSchema(
{"counter": Column(int, checks=[Check(lambda x: x >= 3)])},
drop_invalid_rows=True,
)

schema.validate(df, lazy=True)

Dropping invalid rows with :class:`~pandera.api.pandas.array.SeriesSchema`:

.. testcode:: drop_invalid_rows_series_schema

import pandas as pd
import pandera as pa

from pandera import Check, SeriesSchema

series = pd.Series(["1", "2", "3"])
schema = SeriesSchema(
int,
checks=[Check(lambda x: x >= 3)],
drop_invalid_rows=True,
)

schema.validate(series, lazy=True)

Dropping invalid rows with :class:`~pandera.api.pandas.components.Column`:

.. testcode:: drop_invalid_rows_column

import pandas as pd
import pandera as pa

from pandera import Check, Column

df = pd.DataFrame({"counter": ["1", "2", "3"]})
schema = Column(
int,
name="counter",
drop_invalid_rows=True,
checks=[Check(lambda x: x >= 3)]
)

schema.validate(df, lazy=True)

Dropping invalid rows with :class:`~pandera.api.pandas.model.DataFrameModel`:

.. testcode:: drop_invalid_rows_data_frame_model

import pandas as pd
import pandera as pa

from pandera import Check, DataFrameModel, Field

class MySchema(DataFrameModel):
counter: int = Field(in_range={"min_value": 3, "max_value": 5})

class Config:
drop_invalid_rows = True


MySchema.validate(
pd.DataFrame({"counter": [1, 2, 3, 4, 5, 6]}), lazy=True
)

.. note::
    In order to use ``drop_invalid_rows=True``, ``lazy=True`` must be passed to
    ``schema.validate()``. :ref:`lazy_validation` enables all schema errors to be
    collected and raised together, so that all invalid rows can be dropped together.
    This provides a clear API for ensuring the validated dataframe contains only
    valid data.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -360,6 +360,7 @@ page or reach out to the maintainers and pandera community on
hypothesis
dtypes
decorators
drop_invalid_rows
schema_inference
lazy_validation
data_synthesis_strategies
2 changes: 2 additions & 0 deletions pandera/api/base/schema.py
@@ -32,6 +32,7 @@ def __init__(
name=None,
title=None,
description=None,
drop_invalid_rows=False,
):
"""Abstract base schema initializer."""
self.dtype = dtype
@@ -40,6 +41,7 @@ def __init__(
self.name = name
self.title = title
self.description = description
self.drop_invalid_rows = drop_invalid_rows

def validate(
self,
7 changes: 7 additions & 0 deletions pandera/api/pandas/array.py
@@ -37,6 +37,7 @@ def __init__(
title: Optional[str] = None,
description: Optional[str] = None,
default: Optional[Any] = None,
drop_invalid_rows: bool = False,
) -> None:
"""Initialize array schema.

@@ -63,6 +64,8 @@
:param title: A human-readable label for the series.
:param description: An arbitrary textual description of the series.
:param default: The default value for missing values in the series.
:param drop_invalid_rows: if True, drop invalid rows on validation.

"""

super().__init__(
@@ -72,6 +75,7 @@
name=name,
title=title,
description=description,
drop_invalid_rows=drop_invalid_rows,
)

if checks is None:
@@ -300,6 +304,7 @@ def __init__(
title: Optional[str] = None,
description: Optional[str] = None,
default: Optional[Any] = None,
drop_invalid_rows: bool = False,
) -> None:
"""Initialize series schema base object.

@@ -327,6 +332,7 @@
:param title: A human-readable label for the series.
:param description: An arbitrary textual description of the series.
:param default: The default value for missing values in the series.
:param drop_invalid_rows: if True, drop invalid rows on validation.

"""
super().__init__(
@@ -340,6 +346,7 @@
title,
description,
default,
drop_invalid_rows,
)
self.index = index

3 changes: 3 additions & 0 deletions pandera/api/pandas/components.py
@@ -30,6 +30,7 @@ def __init__(
title: Optional[str] = None,
description: Optional[str] = None,
default: Optional[Any] = None,
drop_invalid_rows: bool = False,
) -> None:
"""Create column validator object.

@@ -54,6 +55,7 @@
:param title: A human-readable label for the column.
:param description: An arbitrary textual description of the column.
:param default: The default value for missing values in the column.
:param drop_invalid_rows: if True, drop invalid rows on validation.

:raises SchemaInitError: if impossible to build schema from parameters

@@ -85,6 +87,7 @@
title=title,
description=description,
default=default,
drop_invalid_rows=drop_invalid_rows,
)
if (
name is not None
3 changes: 3 additions & 0 deletions pandera/api/pandas/container.py
@@ -46,6 +46,7 @@
unique_column_names: bool = False,
title: Optional[str] = None,
description: Optional[str] = None,
drop_invalid_rows: bool = False,
) -> None:
"""Initialize DataFrameSchema validator.

@@ -77,6 +78,7 @@
:param unique_column_names: whether or not column names must be unique.
:param title: A human-readable label for the schema.
:param description: An arbitrary textual description of the schema.
:param drop_invalid_rows: if True, drop invalid rows on validation.

:raises SchemaInitError: if impossible to build schema from parameters

@@ -152,6 +154,7 @@
self._unique = unique
self.report_duplicates = report_duplicates
self.unique_column_names = unique_column_names
self.drop_invalid_rows = drop_invalid_rows

# this attribute is not meant to be accessed by users and is explicitly
# set to True in the case that a schema is created by infer_schema.
1 change: 1 addition & 0 deletions pandera/api/pandas/model.py
@@ -268,6 +268,7 @@ def to_schema(cls) -> DataFrameSchema:
"title": cls.__config__.title,
"description": cls.__config__.description or cls.__doc__,
"unique_column_names": cls.__config__.unique_column_names,
"drop_invalid_rows": cls.__config__.drop_invalid_rows,
}
cls.__schema__ = DataFrameSchema(
columns,
1 change: 1 addition & 0 deletions pandera/api/pandas/model_config.py
@@ -21,6 +21,7 @@ class BaseConfig(BaseModelConfig): # pylint:disable=R0903
title: Optional[str] = None #: human-readable label for schema
description: Optional[str] = None #: arbitrary textual description
coerce: bool = False #: coerce types of all schema components
drop_invalid_rows: bool = False #: drop invalid rows on validation

#: make sure certain column combinations are unique
unique: Optional[Union[str, List[str]]] = None
4 changes: 4 additions & 0 deletions pandera/backends/base/__init__.py
@@ -124,6 +124,10 @@ def failure_cases_metadata(
"""Get failure cases metadata for lazy validation."""
raise NotImplementedError

def drop_invalid_rows(self, check_obj, error_handler):
"""Remove invalid elements in a `check_obj` according to failures caught by the `error_handler`."""
raise NotImplementedError


class BaseCheckBackend(ABC):
"""Abstract base class for a check backend implementation."""
57 changes: 47 additions & 10 deletions pandera/backends/pandas/array.py
@@ -20,6 +20,7 @@
SchemaError,
SchemaErrors,
SchemaErrorReason,
SchemaDefinitionError,
)


@@ -45,6 +46,11 @@ def validate(
error_handler = SchemaErrorHandler(lazy)
check_obj = self.preprocess(check_obj, inplace)

if getattr(schema, "drop_invalid_rows", False) and not lazy:
raise SchemaDefinitionError(
"When drop_invalid_rows is True, lazy must be set to True."
)

# fill nans with `default` if it's present
if hasattr(schema, "default") and pd.notna(schema.default):
check_obj.fillna(schema.default, inplace=True)
@@ -55,6 +61,42 @@
except SchemaError as exc:
error_handler.collect_error(exc.reason_code, exc)

# run the core checks
error_handler = self.run_checks_and_handle_errors(
error_handler,
schema,
check_obj,
head,
tail,
sample,
random_state,
)

if lazy and error_handler.collected_errors:
if getattr(schema, "drop_invalid_rows", False):
check_obj = self.drop_invalid_rows(check_obj, error_handler)
return check_obj
else:
raise SchemaErrors(
schema=schema,
schema_errors=error_handler.collected_errors,
data=check_obj,
)

return check_obj

def run_checks_and_handle_errors(
Comment from @kykyi (Contributor, Author), Jun 4, 2023:
@cosmicBboy this method could be moved into the parent class to remove the duplication, but I'm not sure this would be the right move. They are quite different implementations, and I don't want to abstract it to the parent for some vain DRYness 😅

edit: I will move drop_invalid_data into the parent though

self,
error_handler,
schema,
check_obj,
head,
tail,
sample,
random_state,
):
"""Run checks on schema"""
# pylint: disable=too-many-locals
field_obj_subsample = self.subsample(
check_obj if is_field(check_obj) else check_obj[schema.name],
head,
@@ -71,14 +113,15 @@ def validate(
random_state,
)

# run the core checks
for core_check, args in (
core_checks = [
(self.check_name, (field_obj_subsample, schema)),
(self.check_nullable, (field_obj_subsample, schema)),
(self.check_unique, (field_obj_subsample, schema)),
(self.check_dtype, (field_obj_subsample, schema)),
(self.run_checks, (check_obj_subsample, schema)),
):
]

for core_check, args in core_checks:
results = core_check(*args)
if isinstance(results, CoreCheckResult):
results = [results]
@@ -106,13 +149,7 @@ def validate(
original_exc=result.original_exc,
)

if lazy and error_handler.collected_errors:
raise SchemaErrors(
schema=schema,
schema_errors=error_handler.collected_errors,
data=check_obj,
)
return check_obj
return error_handler

def coerce_dtype(
self,
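The refactored `validate()` in this file boils down to: run the core checks while collecting errors in the handler, then either drop the failing rows or raise everything at once. A simplified, stdlib-only sketch of that control flow, with hypothetical names and plain dicts standing in for pandas objects:

```python
from typing import Callable, Dict, List, Set


class SchemaErrors(Exception):
    """Stand-in for pandera.errors.SchemaErrors (all failures, raised together)."""

    def __init__(self, failing: List[Set[int]]) -> None:
        super().__init__(f"{len(failing)} core check(s) failed")
        self.failing = failing


def validate(
    rows: Dict[int, int],
    checks: List[Callable[[int], bool]],
    drop_invalid_rows: bool = False,
) -> Dict[int, int]:
    # Run the core checks, collecting failures instead of raising immediately
    # (this is what lazy validation buys us).
    collected: List[Set[int]] = []
    for check in checks:
        failing = {idx for idx, value in rows.items() if not check(value)}
        if failing:
            collected.append(failing)

    if collected:
        if drop_invalid_rows:
            # Filter out every failing index and return the surviving rows.
            for failing in collected:
                rows = {idx: v for idx, v in rows.items() if idx not in failing}
            return rows
        raise SchemaErrors(collected)
    return rows
```

With `drop_invalid_rows=True`, `validate({0: 1, 1: 2, 2: 3}, [lambda x: x >= 3], drop_invalid_rows=True)` returns only the row that passed; without it, the collected failures are raised as one `SchemaErrors`.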
10 changes: 10 additions & 0 deletions pandera/backends/pandas/base.py
@@ -24,6 +24,7 @@
scalar_failure_case,
)
from pandera.errors import FailureCaseMetadata, SchemaError, SchemaErrorReason
from pandera.error_handlers import SchemaErrorHandler


class ColumnInfo(NamedTuple):
@@ -149,3 +150,12 @@ def failure_cases_metadata(
message=message,
error_counts=error_counts,
)

def drop_invalid_rows(self, check_obj, error_handler: SchemaErrorHandler):
"""Remove invalid elements in a check obj according to failures caught by the error handler."""
errors = error_handler.collected_errors
for err in errors:
check_obj = check_obj.loc[
~check_obj.index.isin(err.failure_cases["index"])
]
return check_obj
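The filtering above relies on pandas boolean indexing: for each collected error, rows whose index appears in `err.failure_cases["index"]` are excluded. A stdlib-only emulation of the same logic, with dicts standing in for the dataframe (names are illustrative, not pandera's API):

```python
def drop_invalid_rows(rows, collected_errors):
    """Emulated version of the backend method: ``rows`` maps index -> row value,
    and each entry in ``collected_errors`` is the set of failing indices for one
    schema error (pandera reads these from ``err.failure_cases["index"]``)."""
    for failing_indices in collected_errors:
        # Same effect as: check_obj.loc[~check_obj.index.isin(failing_indices)]
        rows = {idx: value for idx, value in rows.items() if idx not in failing_indices}
    return rows
```

Each error's failing indices are removed in turn, so a row need only fail one check to be dropped.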