fix: use public pandas APIs where possible (#60)
* refactor: use public pandas APIs where possible

* no need to override take

* backport take implementation

* move remaining private pandas methods to backports

* add note about _validate_scalar to docstring

* comment why we can't use public mixin
tswast authored Jan 26, 2022
1 parent 5cb2c6b commit e9d41d1
Showing 3 changed files with 69 additions and 74 deletions.
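
As an aside (not part of this commit), the public entry point adopted here is, to the best of my knowledge, a re-export of the private name it replaces, so the decorator swap should be behavior-preserving. A minimal sketch of that assumption:

# Illustrative check only; assumes pandas re-exports register_extension_dtype
# through pandas.api.extensions (the private path is the one this commit drops).
import pandas.api.extensions
import pandas.core.dtypes.dtypes

assert (
    pandas.api.extensions.register_extension_dtype
    is pandas.core.dtypes.dtypes.register_extension_dtype
)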
16 changes: 5 additions & 11 deletions db_dtypes/__init__.py
@@ -22,13 +22,7 @@
 import numpy
 import packaging.version
 import pandas
-import pandas.compat.numpy.function
-import pandas.core.algorithms
-import pandas.core.arrays
-import pandas.core.dtypes.base
-import pandas.core.dtypes.dtypes
-import pandas.core.dtypes.generic
-import pandas.core.nanops
+import pandas.api.extensions
 import pyarrow
 import pyarrow.compute
 
@@ -44,7 +38,7 @@
 pandas_release = packaging.version.parse(pandas.__version__).release
 
 
-@pandas.core.dtypes.dtypes.register_extension_dtype
+@pandas.api.extensions.register_extension_dtype
 class TimeDtype(core.BaseDatetimeDtype):
     """
     Extension dtype for time data.
@@ -113,7 +107,7 @@ def _datetime(
                 .as_py()
             )
 
-        if scalar is None:
+        if pandas.isna(scalar):
             return None
         if isinstance(scalar, datetime.time):
             return pandas.Timestamp(
@@ -194,7 +188,7 @@ def __arrow_array__(self, type=None):
         )
 
 
-@pandas.core.dtypes.dtypes.register_extension_dtype
+@pandas.api.extensions.register_extension_dtype
 class DateDtype(core.BaseDatetimeDtype):
     """
     Extension dtype for time data.
@@ -238,7 +232,7 @@ def _datetime(
         if isinstance(scalar, (pyarrow.Date32Scalar, pyarrow.Date64Scalar)):
            scalar = scalar.as_py()
 
-        if scalar is None:
+        if pandas.isna(scalar):
            return None
        elif isinstance(scalar, datetime.date):
            return pandas.Timestamp(
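
The `if scalar is None:` to `if pandas.isna(scalar):` changes above broaden the missing-value check: pandas.isna also recognizes NaN and NaT, not just None. A small sketch illustrating the difference (not part of the diff):

# pandas.isna treats every common missing-value sentinel as missing,
# whereas an `is None` check only catches None itself.
import numpy
import pandas

for scalar in (None, float("nan"), numpy.datetime64("NaT"), pandas.NaT):
    assert pandas.isna(scalar)  # all four count as missing
    # `scalar is None` would be True only for the first value.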
80 changes: 18 additions & 62 deletions db_dtypes/core.py
@@ -12,20 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Optional, Sequence
+from typing import Optional
 
 import numpy
 import pandas
-from pandas._libs import NaT
+from pandas import NaT
 import pandas.api.extensions
-import pandas.compat.numpy.function
-import pandas.core.algorithms
-import pandas.core.arrays
-import pandas.core.dtypes.base
-from pandas.core.dtypes.common import is_dtype_equal, is_list_like, pandas_dtype
-import pandas.core.dtypes.dtypes
-import pandas.core.dtypes.generic
-import pandas.core.nanops
+from pandas.api.types import is_dtype_equal, is_list_like, pandas_dtype
 
 from db_dtypes import pandas_backports
 
@@ -107,42 +100,11 @@ def isna(self):
         return pandas.isna(self._ndarray)
 
     def _validate_scalar(self, value):
-        if pandas.isna(value):
-            return None
-
-        if not isinstance(value, self.dtype.type):
-            raise ValueError(value)
-
-        return value
-
-    def take(
-        self,
-        indices: Sequence[int],
-        *,
-        allow_fill: bool = False,
-        fill_value: Any = None,
-    ):
-        indices = numpy.asarray(indices, dtype=numpy.intp)
-        data = self._ndarray
-        if allow_fill:
-            fill_value = self._validate_scalar(fill_value)
-            fill_value = (
-                numpy.datetime64() if fill_value is None else self._datetime(fill_value)
-            )
-            if (indices < -1).any():
-                raise ValueError(
-                    "take called with negative indexes other than -1,"
-                    " when a fill value is provided."
-                )
-        out = data.take(indices)
-        if allow_fill:
-            out[indices == -1] = fill_value
-
-        return self.__class__(out)
-
-    # TODO: provide implementations of dropna, fillna, unique,
-    # factorize, argsort, searchsoeted for better performance over
-    # abstract implementations.
+        """
+        Validate and convert a scalar value to datetime64[ns] for storage in
+        backing NumPy array.
+        """
+        return self._datetime(value)
 
     def any(
         self,
@@ -152,10 +114,8 @@ def any(
         keepdims: bool = False,
         skipna: bool = True,
     ):
-        pandas.compat.numpy.function.validate_any(
-            (), {"out": out, "keepdims": keepdims}
-        )
-        result = pandas.core.nanops.nanany(self._ndarray, axis=axis, skipna=skipna)
+        pandas_backports.numpy_validate_any((), {"out": out, "keepdims": keepdims})
+        result = pandas_backports.nanany(self._ndarray, axis=axis, skipna=skipna)
         return result
 
     def all(
@@ -166,22 +126,20 @@
         keepdims: bool = False,
         skipna: bool = True,
     ):
-        pandas.compat.numpy.function.validate_all(
-            (), {"out": out, "keepdims": keepdims}
-        )
-        result = pandas.core.nanops.nanall(self._ndarray, axis=axis, skipna=skipna)
+        pandas_backports.numpy_validate_all((), {"out": out, "keepdims": keepdims})
+        result = pandas_backports.nanall(self._ndarray, axis=axis, skipna=skipna)
         return result
 
     def min(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs):
-        pandas.compat.numpy.function.validate_min((), kwargs)
-        result = pandas.core.nanops.nanmin(
+        pandas_backports.numpy_validate_min((), kwargs)
+        result = pandas_backports.nanmin(
             values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna
         )
         return self._box_func(result)
 
     def max(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs):
-        pandas.compat.numpy.function.validate_max((), kwargs)
-        result = pandas.core.nanops.nanmax(
+        pandas_backports.numpy_validate_max((), kwargs)
+        result = pandas_backports.nanmax(
            values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna
        )
        return self._box_func(result)
@@ -197,11 +155,9 @@ def median(
         keepdims: bool = False,
         skipna: bool = True,
     ):
-        pandas.compat.numpy.function.validate_median(
+        pandas_backports.numpy_validate_median(
             (),
             {"out": out, "overwrite_input": overwrite_input, "keepdims": keepdims},
         )
-        result = pandas.core.nanops.nanmedian(
-            self._ndarray, axis=axis, skipna=skipna
-        )
+        result = pandas_backports.nanmedian(self._ndarray, axis=axis, skipna=skipna)
         return self._box_func(result)
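
For context, a hypothetical usage sketch of the code paths changed above: ExtensionArray.take (now backported in pandas_backports, see below) validates its fill value through _validate_scalar, which simply delegates to _datetime. The "dbdate" dtype name is an assumption about db_dtypes' registered dtype strings; it is not shown in this diff:

# Hypothetical usage sketch; assumes db_dtypes registers a "dbdate" dtype.
import datetime

import pandas

import db_dtypes  # noqa: F401  (imported for its dtype registrations)

dates = pandas.array(
    [datetime.date(2022, 1, 26), datetime.date(2022, 1, 27)], dtype="dbdate"
)

# allow_fill routes fill_value through _validate_scalar -> _datetime, so a
# plain datetime.date (or None, for NaT) is accepted as the fill value.
filled = dates.take([0, -1], allow_fill=True, fill_value=datetime.date(2021, 1, 1))
print(filled)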
47 changes: 46 additions & 1 deletion db_dtypes/pandas_backports.py
@@ -20,15 +20,32 @@
 """
 
 import operator
+from typing import Any
 
 import numpy
 import packaging.version
 import pandas
-from pandas._libs.lib import is_integer
+from pandas.api.types import is_integer
+import pandas.compat.numpy.function
+import pandas.core.nanops
 
 
 pandas_release = packaging.version.parse(pandas.__version__).release
 
+# Create aliases for private methods in case they move in a future version.
+nanall = pandas.core.nanops.nanall
+nanany = pandas.core.nanops.nanany
+nanmax = pandas.core.nanops.nanmax
+nanmin = pandas.core.nanops.nanmin
+numpy_validate_all = pandas.compat.numpy.function.validate_all
+numpy_validate_any = pandas.compat.numpy.function.validate_any
+numpy_validate_max = pandas.compat.numpy.function.validate_max
+numpy_validate_min = pandas.compat.numpy.function.validate_min
+
+if pandas_release >= (1, 2):
+    nanmedian = pandas.core.nanops.nanmedian
+    numpy_validate_median = pandas.compat.numpy.function.validate_median
+
 
 def import_default(module_name, force=False, default=None):
     """
@@ -55,6 +72,10 @@ def import_default(module_name, force=False, default=None):
     return getattr(module, name, default)
 
 
+# pandas.core.arraylike.OpsMixin is private, but the related public API
+# "ExtensionScalarOpsMixin" is not sufficient for adding dates to times.
+# It results in unsupported operand type(s) for +: 'datetime.time' and
+# 'datetime.date'
 @import_default("pandas.core.arraylike")
 class OpsMixin:
     def _cmp_method(self, other, op):  # pragma: NO COVER
@@ -81,6 +102,8 @@ def __ge__(self, other):
     __add__ = __radd__ = __sub__ = lambda self, other: NotImplemented
 
 
+# TODO: use public API once pandas 1.5 / 2.x is released.
+# See: https://github.com/pandas-dev/pandas/pull/45544
 @import_default("pandas.core.arrays._mixins", pandas_release < (1, 3))
 class NDArrayBackedExtensionArray(pandas.core.arrays.base.ExtensionArray):
 
@@ -130,6 +153,28 @@ def copy(self):
     def repeat(self, n):
         return self.__class__(self._ndarray.repeat(n), self._dtype)
 
+    def take(
+        self,
+        indices,
+        *,
+        allow_fill: bool = False,
+        fill_value: Any = None,
+        axis: int = 0,
+    ):
+        from pandas.core.algorithms import take
+
+        if allow_fill:
+            fill_value = self._validate_scalar(fill_value)
+
+        new_data = take(
+            self._ndarray,
+            indices,
+            allow_fill=allow_fill,
+            fill_value=fill_value,
+            axis=axis,
+        )
+        return self._from_backing_data(new_data)
+
     @classmethod
     def _concat_same_type(cls, to_concat, axis=0):
         dtypes = {str(x.dtype) for x in to_concat}
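
Finally, a condensed sketch of the import_default pattern that pandas_backports is built around (simplified; only the helper's last line is visible in this diff, so the body below is an assumption about its shape): the decorated class is kept only as a fallback when the named pandas module does not provide an object with the same name, or when the second argument forces the backport.

# Simplified sketch of the import_default fallback pattern; details may
# differ from the real helper in db_dtypes/pandas_backports.py.
import importlib


def import_default(module_name, force=False, default=None):
    """Use ``default`` unless ``module_name`` provides an object with the
    same name (and ``force`` is false), in which case use pandas' version."""
    if default is None:  # used as a decorator: @import_default("pkg.mod")
        return lambda default: import_default(module_name, force, default)
    if force:
        return default
    name = default.__name__
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return default
    return getattr(module, name, default)


@import_default("pandas.core.arrays._mixins", force=False)
class NDArrayBackedExtensionArray:
    """Local fallback; replaced by pandas' own class when it is importable."""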
