Skip to content

Commit

Permalink
fix: address failing 2D array compliance tests in DateArray (#64)
Browse files Browse the repository at this point in the history
* fix: address failing compliance tests in DateArray and TimeArray

test: add a test session with prerelease versions of dependencies

* fix min/max/median for 2D arrays

* fixes except for null contains

* actually use NaT as 'advertised'

* fix!: use `pandas.NaT` for missing values in dbdate and dbtime dtypes

This makes them consistent with other date/time dtypes, as well as internally
consistent with the advertised `dtype.na_value`.

BREAKING-CHANGE: dbdate and dbtime dtypes return NaT instead of None for missing values

Release-As: 0.4.0

* more progress towards compliance

* address errors in TestMethods

* move tests

* add prerelease deps

* fix: address failing tests with pandas 1.5.0

test: add a test session with prerelease versions of dependencies

* fix owlbot config

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* document why microsecond precision is used

* use correct units

* add box_func tests

* typo

* add unit tests

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
tswast and gcf-owl-bot[bot] authored Mar 24, 2022
1 parent 1db1357 commit b771e05
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 23 deletions.
44 changes: 25 additions & 19 deletions db_dtypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,29 +152,35 @@ def min(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs):
result = pandas_backports.nanmin(
values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna
)
return self._box_func(result)
if axis is None or self.ndim == 1:
return self._box_func(result)
return self._from_backing_data(result)

def max(self, *, axis: Optional[int] = None, skipna: bool = True, **kwargs):
    """Return the maximum value, optionally reducing along ``axis``.

    Args:
        axis: Axis to reduce along. ``None`` reduces over the whole array.
        skipna: Forwarded to ``pandas_backports.nanmax`` together with the
            ``isna()`` mask.
        **kwargs: Extra NumPy-style keyword arguments; they are only
            checked by ``pandas_backports.numpy_validate_max``.

    Returns:
        The result boxed with ``_box_func`` when reducing over everything
        or when the array is 1D; otherwise the reduced backing data
        re-wrapped with ``_from_backing_data``.
    """
    pandas_backports.numpy_validate_max((), kwargs)
    result = pandas_backports.nanmax(
        values=self._ndarray, axis=axis, mask=self.isna(), skipna=skipna
    )
    # Reducing a 2D array along a single axis yields an array, not a scalar,
    # so it must be re-wrapped rather than boxed (same rule as min/median).
    if axis is None or self.ndim == 1:
        return self._box_func(result)
    return self._from_backing_data(result)

if pandas_release >= (1, 2):

def median(
self,
*,
axis: Optional[int] = None,
out=None,
overwrite_input: bool = False,
keepdims: bool = False,
skipna: bool = True,
):
pandas_backports.numpy_validate_median(
(),
{"out": out, "overwrite_input": overwrite_input, "keepdims": keepdims},
)
result = pandas_backports.nanmedian(self._ndarray, axis=axis, skipna=skipna)
if axis is None or self.ndim == 1:
return self._box_func(result)
return self._from_backing_data(result)

def median(
    self,
    *,
    axis: Optional[int] = None,
    out=None,
    overwrite_input: bool = False,
    keepdims: bool = False,
    skipna: bool = True,
):
    """Return the median value, optionally reducing along ``axis``.

    Args:
        axis: Axis to reduce along. ``None`` reduces over the whole array.
        out, overwrite_input, keepdims: Only passed to
            ``pandas_backports.numpy_validate_median`` for validation
            (presumably NumPy signature compatibility -- confirm against
            the backport).
        skipna: Forwarded to ``pandas_backports.nanmedian``.

    Returns:
        The result boxed with ``_box_func`` when reducing over everything
        or when the array is 1D; otherwise the reduced backing data
        re-wrapped with ``_from_backing_data``.

    Raises:
        NotImplementedError: If ``pandas_backports`` does not expose
            ``numpy_validate_median``.
    """
    if not hasattr(pandas_backports, "numpy_validate_median"):
        raise NotImplementedError("Need pandas 1.3 or later to calculate median.")

    numpy_only_kwargs = {
        "out": out,
        "overwrite_input": overwrite_input,
        "keepdims": keepdims,
    }
    pandas_backports.numpy_validate_median((), numpy_only_kwargs)

    medians = pandas_backports.nanmedian(self._ndarray, axis=axis, skipna=skipna)

    # Only an axis-wise reduction of a multi-dimensional array stays an array.
    reduces_to_scalar = axis is None or self.ndim == 1
    if not reduces_to_scalar:
        return self._from_backing_data(medians)
    return self._box_func(medians)
4 changes: 0 additions & 4 deletions db_dtypes/pandas_backports.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,8 @@ def __ge__(self, other):
# See: https://github.com/pandas-dev/pandas/pull/45544
@import_default("pandas.core.arrays._mixins", pandas_release < (1, 3))
class NDArrayBackedExtensionArray(pandas.core.arrays.base.ExtensionArray):

ndim = 1

def __init__(self, values, dtype):
assert isinstance(values, numpy.ndarray)
assert values.ndim == 1
self._ndarray = values
self._dtype = dtype

Expand Down
35 changes: 35 additions & 0 deletions tests/compliance/date/test_date_compliance_1_5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for extension interface compliance, inherited from pandas.
See:
https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/decimal/test_decimal.py
and
https://github.com/pandas-dev/pandas/blob/main/pandas/tests/extension/test_period.py
"""

from pandas.tests.extension import base
import pytest

# NDArrayBacked2DTests suite added in https://github.com/pandas-dev/pandas/pull/44974
pytest.importorskip("pandas", minversion="1.5.0dev")


class Test2DCompat(base.NDArrayBacked2DTests):
    """Run pandas' inherited ``NDArrayBacked2DTests`` suite unchanged.

    No tests are added or overridden here; everything comes from the pandas
    base class.
    """


class TestIndex(base.BaseIndexTests):
    """Run pandas' inherited ``BaseIndexTests`` suite unchanged.

    No tests are added or overridden here; everything comes from the pandas
    base class.
    """
150 changes: 150 additions & 0 deletions tests/unit/test_date.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import operator

import numpy
import numpy.testing
import pandas
import pandas.testing
import pytest
Expand Down Expand Up @@ -154,6 +155,100 @@ def test_date_parsing_errors(value, error):
pandas.Series([value], dtype="dbdate")


def test_date_max_2d():
    """Check that ``DateArray.max`` reduces a 3x3 array along each axis."""
    input_array = db_dtypes.DateArray(
        numpy.array(
            [
                ["1970-01-01", "1980-02-02", "1990-03-03"],
                ["1971-02-02", "1981-03-03", "1991-04-04"],
                ["1972-03-03", "1982-04-04", "1992-05-05"],
            ],
            dtype="datetime64[ns]",
        )
    )

    # Dates increase along both rows and columns, so the maxima are the last
    # row (axis=0) and the last column (axis=1).
    expected_by_axis = {
        0: ["1972-03-03", "1982-04-04", "1992-05-05"],
        1: ["1990-03-03", "1991-04-04", "1992-05-05"],
    }
    for axis, expected in expected_by_axis.items():
        numpy.testing.assert_array_equal(
            input_array.max(axis=axis)._ndarray,
            numpy.array(expected, dtype="datetime64[ns]"),
        )


def test_date_min_2d():
    """Check that ``DateArray.min`` reduces a 3x3 array along each axis."""
    input_array = db_dtypes.DateArray(
        numpy.array(
            [
                ["1970-01-01", "1980-02-02", "1990-03-03"],
                ["1971-02-02", "1981-03-03", "1991-04-04"],
                ["1972-03-03", "1982-04-04", "1992-05-05"],
            ],
            dtype="datetime64[ns]",
        )
    )

    # Dates increase along both rows and columns, so the minima are the first
    # row (axis=0) and the first column (axis=1).
    expected_by_axis = {
        0: ["1970-01-01", "1980-02-02", "1990-03-03"],
        1: ["1970-01-01", "1971-02-02", "1972-03-03"],
    }
    for axis, expected in expected_by_axis.items():
        numpy.testing.assert_array_equal(
            input_array.min(axis=axis)._ndarray,
            numpy.array(expected, dtype="datetime64[ns]"),
        )


@pytest.mark.skipif(
not hasattr(pandas_backports, "numpy_validate_median"),
reason="median not available with this version of pandas",
Expand All @@ -178,3 +273,58 @@ def test_date_parsing_errors(value, error):
def test_date_median(values, expected):
series = pandas.Series(values, dtype="dbdate")
assert series.median() == expected


@pytest.mark.skipif(
    not hasattr(pandas_backports, "numpy_validate_median"),
    reason="median not available with this version of pandas",
)
def test_date_median_2d():
    """Check that ``DateArray.median`` reduces a 3x3 array along each axis.

    Unlike the min/max tests, the result is compared as an extension array
    rather than through its backing ``_ndarray``.
    """

    def as_date_array(values):
        # Wrap nanosecond-precision datetimes in the extension array type.
        return db_dtypes.DateArray(numpy.array(values, dtype="datetime64[ns]"))

    input_array = as_date_array(
        [
            ["1970-01-01", "1980-02-02", "1990-03-03"],
            ["1971-02-02", "1981-03-03", "1991-04-04"],
            ["1972-03-03", "1982-04-04", "1992-05-05"],
        ]
    )

    # With three increasing values per row and per column, the medians are
    # the middle row (axis=0) and the middle column (axis=1).
    expected_by_axis = {
        0: ["1971-02-02", "1981-03-03", "1991-04-04"],
        1: ["1980-02-02", "1981-03-03", "1982-04-04"],
    }
    for axis, expected in expected_by_axis.items():
        pandas.testing.assert_extension_array_equal(
            input_array.median(axis=axis),
            as_date_array(expected),
        )

0 comments on commit b771e05

Please sign in to comment.