Skip to content

Commit

Permalink
Merge pull request #16 from facultyai/refactor
Browse files Browse the repository at this point in the history
Refactor Schemas into module per schema type
  • Loading branch information
zblz authored May 24, 2019
2 parents 0f87b82 + 1a64c0d commit e4e6a02
Show file tree
Hide file tree
Showing 16 changed files with 793 additions and 741 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ matrix:
include:
- name: Lint
python: 3.7
env: TOXENV=black,flake8
env: TOXENV=black,flake8,isort
- python: 3.6
env: TOXENV=py36
- python: 3.7
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,10 @@ exclude = '''
| \.eggs
)/
'''

[tool.isort]
combine_as_imports = true
force_grid_wrap = 0
include_trailing_comma = true
line_length = 79
multi_line_output = 3
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from os import path
from setuptools import setup, find_packages

from setuptools import find_packages, setup

with open(
path.join(path.abspath(path.dirname(__file__)), "README.md"),
Expand Down
4 changes: 3 additions & 1 deletion src/marshmallow_dataframe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# flake8: noqa

from .dataframe import *
from .base import Dtypes
from .records import RecordsDataFrameSchema
from .split import SplitDataFrameSchema
91 changes: 91 additions & 0 deletions src/marshmallow_dataframe/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Dict, List, NamedTuple, Union

import marshmallow as ma
import numpy as np
import pandas as pd


class Dtypes(NamedTuple):
    """Column names paired position-by-position with their dtypes."""

    # Column labels, in order
    columns: List[str]
    # dtype of each column, aligned with ``columns``
    dtypes: List[np.dtype]

    @classmethod
    def from_pandas_dtypes(cls, pd_dtypes: pd.Series) -> "Dtypes":
        """Build a ``Dtypes`` from a pandas dtypes Series.

        The Series index supplies the column names and the Series values
        supply the dtypes.
        """
        names = [*pd_dtypes.index]
        kinds = [*pd_dtypes.values]
        return cls(names, kinds)

    def to_pandas_dtypes(self) -> pd.DataFrame:
        """Return the dtypes as pandas data indexed by column name."""
        return pd.DataFrame(data=self.dtypes, index=self.columns)


def _validate_dtypes(dtypes: Union[Dtypes, pd.Series]) -> Dtypes:
    """Normalise the ``dtypes`` Meta option into a ``Dtypes`` instance.

    Args:
        dtypes: Either a ``Dtypes`` instance, returned unchanged, or a
            pandas Series, converted with ``Dtypes.from_pandas_dtypes``.

    Returns:
        The dtypes as a ``Dtypes`` instance.

    Raises:
        ValueError: If ``dtypes`` is neither a pandas Series nor a
            ``Dtypes`` instance.
    """
    if isinstance(dtypes, Dtypes):
        return dtypes

    if isinstance(dtypes, pd.Series):
        return Dtypes.from_pandas_dtypes(dtypes)

    raise ValueError(
        "The `dtypes` Meta option on a DataFrame Schema must be either a "
        "pandas Series or an instance of marshmallow_dataframe.Dtypes"
    )


class DataFrameSchemaOpts(ma.SchemaOpts):
    """Schema options for DataFrame schemas.

    On top of the standard marshmallow options this reads from ``Meta``:

    - ``dtypes``: validated and normalised when given, otherwise ``None``
    - ``index_dtype``: stored as given, ``None`` when absent
    - ``strict``: defaults to ``True`` when absent
    """

    def __init__(self, meta, *args, **kwargs):
        super().__init__(meta, *args, **kwargs)

        declared_dtypes = getattr(meta, "dtypes", None)
        # Only validate when the option was actually supplied
        self.dtypes = (
            None
            if declared_dtypes is None
            else _validate_dtypes(declared_dtypes)
        )

        self.index_dtype = getattr(meta, "index_dtype", None)
        self.strict = getattr(meta, "strict", True)


class DataFrameSchemaMeta(ma.schema.SchemaMeta):
    """Base metaclass for DataFrame schemas.

    Fields generated by ``get_fields`` from the ``dtypes`` option are
    combined with the explicitly declared fields; a declared field
    replaces a generated field of the same name.
    """

    def __new__(meta, name, bases, class_dict):
        """Create the schema class, validating only our subclasses.

        A class whose bases are exactly ``(ma.Schema,)`` is exempt from
        the check; any other class must define the ``dtypes`` Meta option.

        Raises:
            ValueError: If a non-exempt class does not define ``dtypes``.
        """
        klass = super().__new__(meta, name, bases, class_dict)

        if bases != (ma.Schema,) and klass.opts.dtypes is None:
            raise ValueError(
                "Subclasses of marshmallow_dataframe Schemas must define "
                "the `dtypes` Meta option"
            )

        return klass

    @classmethod
    def get_declared_fields(
        mcs, klass, cls_fields, inherited_fields, dict_cls
    ) -> Dict[str, ma.fields.Field]:
        """
        Updates declared fields with fields generated from DataFrame dtypes

        Returns the generated fields, overridden by any declared field that
        shares a name with a generated one.
        """

        opts = klass.opts
        declared_fields = super().get_declared_fields(
            klass, cls_fields, inherited_fields, dict_cls
        )
        fields = mcs.get_fields(opts, dict_cls)
        # Declared fields are applied last so that they take precedence
        fields.update(declared_fields)
        return fields

    @classmethod
    def get_fields(
        mcs, opts: DataFrameSchemaOpts, dict_cls
    ) -> Dict[str, ma.fields.Field]:
        """
        Generate fields from DataFrame dtypes
        To be implemented in subclasses of DataFrameSchemaMeta

        The base implementation generates nothing. It returns an empty
        ``dict_cls`` rather than ``None`` so that ``get_declared_fields``
        can always call ``update`` on the result.
        """
        return dict_cls()
66 changes: 66 additions & 0 deletions src/marshmallow_dataframe/converters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from collections import defaultdict
from typing import Dict

import marshmallow as ma
import numpy as np
import pandas as pd

# Maps a dtype ``kind`` character to the marshmallow field class used for it.
# Key order is kept stable because the keys are shown in error messages.
DTYPE_KIND_TO_FIELD_CLASS = {
    # "i" and "u": integer kinds
    **dict.fromkeys("iu", ma.fields.Int),
    "f": ma.fields.Float,
    # "O", "U" and "S": kinds handled as strings
    **dict.fromkeys("OUS", ma.fields.Str),
    "b": ma.fields.Bool,
    "M": ma.fields.DateTime,
    "m": ma.fields.TimeDelta,
}

# Options applied to every dtype kind that has no dedicated entry below
_DEFAULT_FIELD_OPTIONS = {"required": True, "allow_none": True}

# Each unlisted kind gets its own copy of the defaults, so mutating the
# options of one kind cannot leak into other kinds or into the defaults
DTYPE_KIND_TO_FIELD_OPTIONS: Dict[str, Dict[str, bool]] = defaultdict(
    lambda: dict(_DEFAULT_FIELD_OPTIONS)
)
# Integer columns in pandas cannot have null values, so we allow_none for all
# types except int
# NOTE(review): kind "u" has no entry here, so unsigned integers fall back to
# the defaults and therefore allow None -- confirm this is intended
DTYPE_KIND_TO_FIELD_OPTIONS["i"] = {"required": True}
DTYPE_KIND_TO_FIELD_OPTIONS["f"] = {
    "allow_nan": True,
    **_DEFAULT_FIELD_OPTIONS,
}


class DtypeToFieldConversionError(Exception):
    """Raised when a dtype cannot be mapped to a marshmallow field."""


def _field_spec_for_kind(dtype, kind):
    """Return the (field class, field options) pair for a dtype kind.

    Raises:
        DtypeToFieldConversionError: If ``kind`` has no known field class.
    """
    try:
        field_class = DTYPE_KIND_TO_FIELD_CLASS[kind]
    except KeyError as exc:
        raise DtypeToFieldConversionError(
            f"The conversion of the dtype {dtype} with kind {kind} "
            "into marshmallow fields is unknown. Known kinds are: "
            f"{DTYPE_KIND_TO_FIELD_CLASS.keys()}"
        ) from exc

    # Looked up only after the class lookup succeeded, so an unknown kind
    # never adds an entry to the options mapping
    return field_class, DTYPE_KIND_TO_FIELD_OPTIONS[kind]


def dtype_to_field(dtype: np.dtype) -> ma.fields.Field:
    """Build the marshmallow field that corresponds to a column dtype.

    Categorical dtypes are mapped through the dtype of their categories
    and get a ``OneOf`` validator restricted to those categories. Every
    other dtype is mapped through its ``kind`` attribute.

    Args:
        dtype: The dtype to convert.

    Returns:
        A marshmallow field instance for the dtype.

    Raises:
        DtypeToFieldConversionError: If the dtype has no ``kind``
            attribute, or if its kind (or the kind of its categories)
            has no known marshmallow field.
    """
    # Categorical dtypes are mapped via the dtype of their categories.
    # The isinstance check replaces the deprecated
    # ``pd.api.types.is_categorical_dtype``.
    if isinstance(dtype, pd.api.types.CategoricalDtype):
        categories = dtype.categories.values.tolist()
        field_class, field_options = _field_spec_for_kind(
            dtype, dtype.categories.dtype.kind
        )
        return field_class(
            validate=ma.validate.OneOf(categories), **field_options
        )

    try:
        kind = dtype.kind
    except AttributeError as exc:
        raise DtypeToFieldConversionError(
            f"The dtype {dtype} does not have a `kind` attribute, "
            "unable to map dtype into marshmallow field type"
        ) from exc

    field_class, field_options = _field_spec_for_kind(dtype, kind)
    return field_class(**field_options)
Loading

0 comments on commit e4e6a02

Please sign in to comment.