-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #16 from facultyai/refactor
Refactor Schemas into module per schema type
- Loading branch information
Showing
16 changed files
with
793 additions
and
741 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
# flake8: noqa | ||
|
||
from .dataframe import * | ||
from .base import Dtypes | ||
from .records import RecordsDataFrameSchema | ||
from .split import SplitDataFrameSchema |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from typing import Dict, List, NamedTuple, Union | ||
|
||
import marshmallow as ma | ||
import numpy as np | ||
import pandas as pd | ||
|
||
|
||
class Dtypes(NamedTuple): | ||
columns: List[str] | ||
dtypes: List[np.dtype] | ||
|
||
@classmethod | ||
def from_pandas_dtypes(cls, pd_dtypes: pd.Series) -> "Dtypes": | ||
return cls( | ||
columns=list(pd_dtypes.index), dtypes=list(pd_dtypes.values) | ||
) | ||
|
||
def to_pandas_dtypes(self) -> pd.DataFrame: | ||
return pd.DataFrame(index=self.columns, data=self.dtypes) | ||
|
||
|
||
def _validate_dtypes(dtypes: Union[Dtypes, pd.DataFrame]) -> Dtypes: | ||
if isinstance(dtypes, pd.Series): | ||
dtypes = Dtypes.from_pandas_dtypes(dtypes) | ||
elif not isinstance(dtypes, Dtypes): | ||
raise ValueError( | ||
"The `dtypes` Meta option on a DataFrame Schema must be either a " | ||
"pandas Series or an instance of marshmallow_dataframe.Dtypes" | ||
) | ||
|
||
return dtypes | ||
|
||
|
||
class DataFrameSchemaOpts(ma.SchemaOpts): | ||
"""Options class for BaseDataFrameSchema | ||
Adds the following options: | ||
- ``dtypes`` | ||
- ``index_dtype`` | ||
""" | ||
|
||
def __init__(self, meta, *args, **kwargs): | ||
super().__init__(meta, *args, **kwargs) | ||
self.dtypes = getattr(meta, "dtypes", None) | ||
if self.dtypes is not None: | ||
self.dtypes = _validate_dtypes(self.dtypes) | ||
self.index_dtype = getattr(meta, "index_dtype", None) | ||
self.strict = getattr(meta, "strict", True) | ||
|
||
|
||
class DataFrameSchemaMeta(ma.schema.SchemaMeta): | ||
"""Base metaclass for DataFrame schemas""" | ||
|
||
def __new__(meta, name, bases, class_dict): | ||
"""Only validate subclasses of our schemas""" | ||
klass = super().__new__(meta, name, bases, class_dict) | ||
|
||
if bases != (ma.Schema,) and klass.opts.dtypes is None: | ||
raise ValueError( | ||
"Subclasses of marshmallow_dataframe Schemas must define " | ||
"the `dtypes` Meta option" | ||
) | ||
|
||
return klass | ||
|
||
@classmethod | ||
def get_declared_fields( | ||
mcs, klass, cls_fields, inherited_fields, dict_cls | ||
) -> Dict[str, ma.fields.Field]: | ||
""" | ||
Updates declared fields with fields generated from DataFrame dtypes | ||
""" | ||
|
||
opts = klass.opts | ||
declared_fields = super().get_declared_fields( | ||
klass, cls_fields, inherited_fields, dict_cls | ||
) | ||
fields = mcs.get_fields(opts, dict_cls) | ||
fields.update(declared_fields) | ||
return fields | ||
|
||
@classmethod | ||
def get_fields( | ||
mcs, opts: DataFrameSchemaOpts, dict_cls | ||
) -> Dict[str, ma.fields.Field]: | ||
""" | ||
Generate fields from DataFrame dtypes | ||
To be implemented in subclasses of DataFrameSchemaMeta | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from collections import defaultdict | ||
from typing import Dict | ||
|
||
import marshmallow as ma | ||
import numpy as np | ||
import pandas as pd | ||
|
||
DTYPE_KIND_TO_FIELD_CLASS = { | ||
"i": ma.fields.Int, | ||
"u": ma.fields.Int, | ||
"f": ma.fields.Float, | ||
"O": ma.fields.Str, | ||
"U": ma.fields.Str, | ||
"S": ma.fields.Str, | ||
"b": ma.fields.Bool, | ||
"M": ma.fields.DateTime, | ||
"m": ma.fields.TimeDelta, | ||
} | ||
|
||
_DEFAULT_FIELD_OPTIONS = {"required": True, "allow_none": True} | ||
|
||
DTYPE_KIND_TO_FIELD_OPTIONS: Dict[str, Dict[str, bool]] = defaultdict( | ||
lambda: _DEFAULT_FIELD_OPTIONS | ||
) | ||
# Integer columns in pandas cannot have null values, so we allow_none for all | ||
# types except int | ||
DTYPE_KIND_TO_FIELD_OPTIONS["i"] = {"required": True} | ||
DTYPE_KIND_TO_FIELD_OPTIONS["f"] = { | ||
"allow_nan": True, | ||
**_DEFAULT_FIELD_OPTIONS, | ||
} | ||
|
||
|
||
class DtypeToFieldConversionError(Exception): | ||
pass | ||
|
||
|
||
def dtype_to_field(dtype: np.dtype) -> ma.fields.Field: | ||
# Object dtypes require more detailed mapping | ||
if pd.api.types.is_categorical_dtype(dtype): | ||
categories = dtype.categories.values.tolist() | ||
kind = dtype.categories.dtype.kind | ||
field_class = DTYPE_KIND_TO_FIELD_CLASS[kind] | ||
field_options = DTYPE_KIND_TO_FIELD_OPTIONS[kind] | ||
return field_class( | ||
validate=ma.validate.OneOf(categories), **field_options | ||
) | ||
|
||
try: | ||
kind = dtype.kind | ||
except AttributeError as exc: | ||
raise DtypeToFieldConversionError( | ||
f"The dtype {dtype} does not have a `kind` attribute, " | ||
"unable to map dtype into marshmallow field type" | ||
) from exc | ||
|
||
try: | ||
field_class = DTYPE_KIND_TO_FIELD_CLASS[kind] | ||
field_options = DTYPE_KIND_TO_FIELD_OPTIONS[kind] | ||
return field_class(**field_options) | ||
except KeyError as exc: | ||
raise DtypeToFieldConversionError( | ||
f"The conversion of the dtype {dtype} with kind {dtype.kind} " | ||
"into marshmallow fields is unknown. Known kinds are: " | ||
f"{DTYPE_KIND_TO_FIELD_CLASS.keys()}" | ||
) from exc |
Oops, something went wrong.