-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
REF: back IntervalArray by a single ndarray #37047
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
34b44cb
d4074c7
014c13e
8a4f7d4
b93fbc0
e9e9b7c
763a384
a9adb76
66c70e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from operator import le, lt | ||
import textwrap | ||
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast | ||
|
||
import numpy as np | ||
|
||
|
@@ -11,14 +12,17 @@ | |
IntervalMixin, | ||
intervals_to_interval_bounds, | ||
) | ||
from pandas._typing import ArrayLike, Dtype | ||
from pandas.compat.numpy import function as nv | ||
from pandas.util._decorators import Appender | ||
|
||
from pandas.core.dtypes.cast import maybe_convert_platform | ||
from pandas.core.dtypes.common import ( | ||
is_categorical_dtype, | ||
is_datetime64_any_dtype, | ||
is_dtype_equal, | ||
is_float_dtype, | ||
is_integer, | ||
is_integer_dtype, | ||
is_interval_dtype, | ||
is_list_like, | ||
|
@@ -45,6 +49,10 @@ | |
from pandas.core.indexers import check_array_indexer | ||
from pandas.core.indexes.base import ensure_index | ||
|
||
if TYPE_CHECKING: | ||
from pandas import Index | ||
from pandas.core.arrays import DatetimeArray, TimedeltaArray | ||
|
||
_interval_shared_docs = {} | ||
|
||
_shared_docs_kwargs = dict( | ||
|
@@ -169,6 +177,17 @@ def __new__( | |
left = data._left | ||
right = data._right | ||
closed = closed or data.closed | ||
|
||
if dtype is None or data.dtype == dtype: | ||
# This path will preserve id(result._combined) | ||
# TODO: could also validate dtype before going to simple_new | ||
combined = data._combined | ||
if copy: | ||
combined = combined.copy() | ||
result = cls._simple_new(combined, closed=closed) | ||
if verify_integrity: | ||
result._validate() | ||
return result | ||
else: | ||
|
||
# don't allow scalars | ||
|
@@ -186,83 +205,22 @@ def __new__( | |
) | ||
closed = closed or infer_closed | ||
|
||
return cls._simple_new( | ||
left, | ||
right, | ||
closed, | ||
copy=copy, | ||
dtype=dtype, | ||
verify_integrity=verify_integrity, | ||
) | ||
closed = closed or "right" | ||
left, right = _maybe_cast_inputs(left, right, copy, dtype) | ||
combined = _get_combined_data(left, right) | ||
result = cls._simple_new(combined, closed=closed) | ||
if verify_integrity: | ||
result._validate() | ||
return result | ||
|
||
@classmethod | ||
def _simple_new( | ||
cls, left, right, closed=None, copy=False, dtype=None, verify_integrity=True | ||
): | ||
def _simple_new(cls, data, closed="right"): | ||
result = IntervalMixin.__new__(cls) | ||
|
||
closed = closed or "right" | ||
left = ensure_index(left, copy=copy) | ||
right = ensure_index(right, copy=copy) | ||
|
||
if dtype is not None: | ||
# GH 19262: dtype must be an IntervalDtype to override inferred | ||
dtype = pandas_dtype(dtype) | ||
if not is_interval_dtype(dtype): | ||
msg = f"dtype must be an IntervalDtype, got {dtype}" | ||
raise TypeError(msg) | ||
elif dtype.subtype is not None: | ||
left = left.astype(dtype.subtype) | ||
right = right.astype(dtype.subtype) | ||
|
||
# coerce dtypes to match if needed | ||
if is_float_dtype(left) and is_integer_dtype(right): | ||
right = right.astype(left.dtype) | ||
elif is_float_dtype(right) and is_integer_dtype(left): | ||
left = left.astype(right.dtype) | ||
|
||
if type(left) != type(right): | ||
msg = ( | ||
f"must not have differing left [{type(left).__name__}] and " | ||
f"right [{type(right).__name__}] types" | ||
) | ||
raise ValueError(msg) | ||
elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype): | ||
# GH 19016 | ||
msg = ( | ||
"category, object, and string subtypes are not supported " | ||
"for IntervalArray" | ||
) | ||
raise TypeError(msg) | ||
elif isinstance(left, ABCPeriodIndex): | ||
msg = "Period dtypes are not supported, use a PeriodIndex instead" | ||
raise ValueError(msg) | ||
elif isinstance(left, ABCDatetimeIndex) and str(left.tz) != str(right.tz): | ||
msg = ( | ||
"left and right must have the same time zone, got " | ||
f"'{left.tz}' and '{right.tz}'" | ||
) | ||
raise ValueError(msg) | ||
|
||
# For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray | ||
from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array | ||
|
||
left = maybe_upcast_datetimelike_array(left) | ||
left = extract_array(left, extract_numpy=True) | ||
right = maybe_upcast_datetimelike_array(right) | ||
right = extract_array(right, extract_numpy=True) | ||
|
||
lbase = getattr(left, "_ndarray", left).base | ||
rbase = getattr(right, "_ndarray", right).base | ||
if lbase is not None and lbase is rbase: | ||
# If these share data, then setitem could corrupt our IA | ||
right = right.copy() | ||
|
||
result._left = left | ||
result._right = right | ||
result._combined = data | ||
result._left = data[:, 0] | ||
result._right = data[:, 1] | ||
result._closed = closed | ||
if verify_integrity: | ||
result._validate() | ||
return result | ||
|
||
@classmethod | ||
|
@@ -397,10 +355,16 @@ def from_breaks(cls, breaks, closed="right", copy=False, dtype=None): | |
def from_arrays(cls, left, right, closed="right", copy=False, dtype=None): | ||
left = maybe_convert_platform_interval(left) | ||
right = maybe_convert_platform_interval(right) | ||
if len(left) != len(right): | ||
raise ValueError("left and right must have the same length") | ||
|
||
return cls._simple_new( | ||
left, right, closed, copy=copy, dtype=dtype, verify_integrity=True | ||
) | ||
closed = closed or "right" | ||
left, right = _maybe_cast_inputs(left, right, copy, dtype) | ||
combined = _get_combined_data(left, right) | ||
|
||
result = cls._simple_new(combined, closed) | ||
result._validate() | ||
return result | ||
|
||
_interval_shared_docs["from_tuples"] = textwrap.dedent( | ||
""" | ||
|
@@ -506,19 +470,6 @@ def _validate(self): | |
msg = "left side of interval must be <= right side" | ||
raise ValueError(msg) | ||
|
||
def _shallow_copy(self, left, right): | ||
""" | ||
Return a new IntervalArray with the replacement attributes | ||
|
||
Parameters | ||
---------- | ||
left : Index | ||
Values to be used for the left-side of the intervals. | ||
right : Index | ||
Values to be used for the right-side of the intervals. | ||
""" | ||
return self._simple_new(left, right, closed=self.closed, verify_integrity=False) | ||
|
||
# --------------------------------------------------------------------- | ||
# Descriptive | ||
|
||
|
@@ -546,18 +497,20 @@ def __len__(self) -> int: | |
|
||
def __getitem__(self, key): | ||
key = check_array_indexer(self, key) | ||
left = self._left[key] | ||
right = self._right[key] | ||
|
||
if not isinstance(left, (np.ndarray, ExtensionArray)): | ||
# scalar | ||
if is_scalar(left) and isna(left): | ||
result = self._combined[key] | ||
|
||
if is_integer(key): | ||
left, right = result[0], result[1] | ||
if isna(left): | ||
return self._fill_value | ||
return Interval(left, right, self.closed) | ||
if np.ndim(left) > 1: | ||
|
||
# TODO: need to watch out for incorrectly-reducing getitem | ||
if np.ndim(result) > 2: | ||
# GH#30588 multi-dimensional indexer disallowed | ||
raise ValueError("multi-dimensional indexing not allowed") | ||
return self._shallow_copy(left, right) | ||
return type(self)._simple_new(result, closed=self.closed) | ||
|
||
def __setitem__(self, key, value): | ||
value_left, value_right = self._validate_setitem_value(value) | ||
|
@@ -651,7 +604,8 @@ def fillna(self, value=None, method=None, limit=None): | |
|
||
left = self.left.fillna(value=value_left) | ||
right = self.right.fillna(value=value_right) | ||
return self._shallow_copy(left, right) | ||
combined = _get_combined_data(left, right) | ||
return type(self)._simple_new(combined, closed=self.closed) | ||
|
||
def astype(self, dtype, copy=True): | ||
""" | ||
|
@@ -693,7 +647,9 @@ def astype(self, dtype, copy=True): | |
f"Cannot convert {self.dtype} to {dtype}; subtypes are incompatible" | ||
) | ||
raise TypeError(msg) from err | ||
return self._shallow_copy(new_left, new_right) | ||
# TODO: do astype directly on self._combined | ||
combined = _get_combined_data(new_left, new_right) | ||
return type(self)._simple_new(combined, closed=self.closed) | ||
elif is_categorical_dtype(dtype): | ||
return Categorical(np.asarray(self)) | ||
elif isinstance(dtype, StringDtype): | ||
|
@@ -734,9 +690,11 @@ def _concat_same_type(cls, to_concat): | |
raise ValueError("Intervals must all be closed on the same side.") | ||
closed = closed.pop() | ||
|
||
# TODO: will this mess up on dt64tz? | ||
left = np.concatenate([interval.left for interval in to_concat]) | ||
right = np.concatenate([interval.right for interval in to_concat]) | ||
return cls._simple_new(left, right, closed=closed, copy=False) | ||
combined = _get_combined_data(left, right) # TODO: 1-stage concat | ||
return cls._simple_new(combined, closed=closed) | ||
|
||
def copy(self): | ||
""" | ||
|
@@ -746,11 +704,8 @@ def copy(self): | |
------- | ||
IntervalArray | ||
""" | ||
left = self._left.copy() | ||
right = self._right.copy() | ||
closed = self.closed | ||
# TODO: Could skip verify_integrity here. | ||
return type(self).from_arrays(left, right, closed=closed) | ||
combined = self._combined.copy() | ||
return type(self)._simple_new(combined, closed=self.closed) | ||
|
||
def isna(self) -> np.ndarray: | ||
return isna(self._left) | ||
|
@@ -843,7 +798,8 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs): | |
self._right, indices, allow_fill=allow_fill, fill_value=fill_right | ||
) | ||
|
||
return self._shallow_copy(left_take, right_take) | ||
combined = _get_combined_data(left_take, right_take) | ||
return type(self)._simple_new(combined, closed=self.closed) | ||
|
||
def _validate_listlike(self, value): | ||
# list-like of intervals | ||
|
@@ -1170,10 +1126,7 @@ def set_closed(self, closed): | |
if closed not in VALID_CLOSED: | ||
msg = f"invalid option for 'closed': {closed}" | ||
raise ValueError(msg) | ||
|
||
return type(self)._simple_new( | ||
left=self._left, right=self._right, closed=closed, verify_integrity=False | ||
) | ||
return type(self)._simple_new(self._combined, closed=closed) | ||
|
||
_interval_shared_docs[ | ||
"is_non_overlapping_monotonic" | ||
|
@@ -1314,9 +1267,8 @@ def to_tuples(self, na_tuple=True): | |
@Appender(_extension_array_shared_docs["repeat"] % _shared_docs_kwargs) | ||
def repeat(self, repeats, axis=None): | ||
nv.validate_repeat(tuple(), dict(axis=axis)) | ||
left_repeat = self.left.repeat(repeats) | ||
right_repeat = self.right.repeat(repeats) | ||
return self._shallow_copy(left=left_repeat, right=right_repeat) | ||
combined = self._combined.repeat(repeats, 0) | ||
return type(self)._simple_new(combined, closed=self.closed) | ||
|
||
_interval_shared_docs["contains"] = textwrap.dedent( | ||
""" | ||
|
@@ -1399,3 +1351,92 @@ def maybe_convert_platform_interval(values): | |
values = np.asarray(values) | ||
|
||
return maybe_convert_platform(values) | ||
|
||
|
||
def _maybe_cast_inputs( | ||
left_orig: Union["Index", ArrayLike], | ||
right_orig: Union["Index", ArrayLike], | ||
copy: bool, | ||
dtype: Optional[Dtype], | ||
) -> Tuple["Index", "Index"]: | ||
# Normalize and validate the left/right endpoint inputs of an interval array. | ||
# | ||
# Parameters: | ||
#   left_orig, right_orig: endpoint values; each is passed through | ||
#     ensure_index, forwarding ``copy``. | ||
#   copy: forwarded to ensure_index for both sides. | ||
#   dtype: optional; when given it is resolved with pandas_dtype and must be | ||
#     an interval dtype. If its subtype is not None, both sides are cast to it. | ||
# | ||
# Returns: | ||
#   The (left, right) pair after the casts below. | ||
# | ||
# Raises: | ||
#   TypeError: ``dtype`` is given but is not an interval dtype, or the left | ||
#     dtype is categorical or string. | ||
#   ValueError: left and right end up as different types, left is a | ||
#     PeriodIndex, or left is a DatetimeIndex whose dtype differs from right's. | ||
#   The checks form one if/elif chain, so only the first failing one raises. | ||
left = ensure_index(left_orig, copy=copy) | ||
right = ensure_index(right_orig, copy=copy) | ||
|
||
if dtype is not None: | ||
# GH#19262: dtype must be an IntervalDtype to override inferred | ||
dtype = pandas_dtype(dtype) | ||
if not is_interval_dtype(dtype): | ||
msg = f"dtype must be an IntervalDtype, got {dtype}" | ||
raise TypeError(msg) | ||
# typing.cast only narrows the type for the checker; no runtime conversion | ||
dtype = cast(IntervalDtype, dtype) | ||
if dtype.subtype is not None: | ||
left = left.astype(dtype.subtype) | ||
right = right.astype(dtype.subtype) | ||
|
||
# coerce dtypes to match if needed | ||
# (when one side is float and the other integer, the integer side is cast | ||
# to the float side's dtype) | ||
if is_float_dtype(left) and is_integer_dtype(right): | ||
right = right.astype(left.dtype) | ||
elif is_float_dtype(right) and is_integer_dtype(left): | ||
left = left.astype(right.dtype) | ||
|
||
# exact type comparison of the two objects, not an isinstance check | ||
if type(left) != type(right): | ||
msg = ( | ||
f"must not have differing left [{type(left).__name__}] and " | ||
f"right [{type(right).__name__}] types" | ||
) | ||
raise ValueError(msg) | ||
elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype): | ||
# GH#19016 | ||
msg = ( | ||
"category, object, and string subtypes are not supported " | ||
"for IntervalArray" | ||
) | ||
raise TypeError(msg) | ||
elif isinstance(left, ABCPeriodIndex): | ||
msg = "Period dtypes are not supported, use a PeriodIndex instead" | ||
raise ValueError(msg) | ||
# any dtype mismatch between the two datetime sides is reported as a | ||
# time-zone mismatch, quoting each side's ``tz`` | ||
elif isinstance(left, ABCDatetimeIndex) and not is_dtype_equal( | ||
left.dtype, right.dtype | ||
): | ||
left_arr = cast("DatetimeArray", left._data) | ||
right_arr = cast("DatetimeArray", right._data) | ||
msg = ( | ||
"left and right must have the same time zone, got " | ||
f"'{left_arr.tz}' and '{right_arr.tz}'" | ||
) | ||
raise ValueError(msg) | ||
|
||
return left, right | ||
|
||
|
||
def _get_combined_data( | ||
left: Union["Index", ArrayLike], right: Union["Index", ArrayLike] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can this method be more strict? e.g. only accept Union[np.ndarray, "DatetimeArray", "TimedeltaArray"] e.g. things that have already been casted? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think we can do quite a bit less back-and-forth casting eventually, yes |
||
) -> Union[np.ndarray, "DatetimeArray", "TimedeltaArray"]: | ||
# Build the single combined array that holds both interval endpoints. | ||
# | ||
# Each side is run through maybe_upcast_datetimelike_array and extract_array, | ||
# reshaped to one column with reshape(-1, 1), and the two columns are joined | ||
# along axis=1, so column 0 comes from ``left`` and column 1 from ``right``. | ||
# ndarray inputs are joined with np.concatenate; any other array type is | ||
# joined with its own ``_concat_same_type``. | ||
# | ||
# Returns the combined array. | ||
# For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray | ||
from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array | ||
|
||
left = maybe_upcast_datetimelike_array(left) | ||
left = extract_array(left, extract_numpy=True) | ||
right = maybe_upcast_datetimelike_array(right) | ||
right = extract_array(right, extract_numpy=True) | ||
|
||
# compare the ``.base`` of the underlying arrays (``_ndarray`` when the | ||
# object has one, otherwise the object itself) to detect shared memory | ||
lbase = getattr(left, "_ndarray", left).base | ||
rbase = getattr(right, "_ndarray", right).base | ||
if lbase is not None and lbase is rbase: | ||
# If these share data, then setitem could corrupt our IA | ||
right = right.copy() | ||
|
||
if isinstance(left, np.ndarray): | ||
assert isinstance(right, np.ndarray) # for mypy | ||
combined = np.concatenate( | ||
[left.reshape(-1, 1), right.reshape(-1, 1)], | ||
axis=1, | ||
) | ||
else: | ||
# typing.cast only informs the type checker; nothing is converted here | ||
left = cast(Union["DatetimeArray", "TimedeltaArray"], left) | ||
right = cast(Union["DatetimeArray", "TimedeltaArray"], right) | ||
combined = type(left)._concat_same_type( | ||
[left.reshape(-1, 1), right.reshape(-1, 1)], | ||
axis=1, | ||
) | ||
return combined |
Uh oh!
There was an error while loading. Please reload this page.