Skip to content

Commit

Permalink
REF: use sanitize_array in Index.__new__ (pandas-dev#49718)
Browse files Browse the repository at this point in the history
* REF: Index.__new__ use sanitize_array

* REF: _wrapped_sanitize

* re-use wrapped_sanitize

* cln

* REF: share

* avoid extra copy

* troubleshoot CI

* pylint fixup
  • Loading branch information
jbrockmendel authored and MarcoGorelli committed Nov 18, 2022
1 parent b5fdac0 commit a55f1cf
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 89 deletions.
7 changes: 7 additions & 0 deletions pandas/core/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ def sanitize_array(
copy: bool = False,
*,
allow_2d: bool = False,
strict_ints: bool = False,
) -> ArrayLike:
"""
Sanitize input data to an ndarray or ExtensionArray, copy if specified,
Expand All @@ -512,6 +513,8 @@ def sanitize_array(
copy : bool, default False
allow_2d : bool, default False
If False, raise if we have a 2D Arraylike.
strict_ints : bool, default False
If False, silently ignore failures to cast float data to int dtype.
Returns
-------
Expand Down Expand Up @@ -581,6 +584,8 @@ def sanitize_array(
# DataFrame would call np.array(data, dtype=dtype, copy=copy),
# which would cast to the integer dtype even if the cast is lossy.
# See GH#40110.
if strict_ints:
raise

# We ignore the dtype arg and return floating values,
# e.g. test_constructor_floating_data_int_dtype
Expand Down Expand Up @@ -624,6 +629,8 @@ def sanitize_array(
subarr = _try_cast(data, dtype, copy)
except ValueError:
if is_integer_dtype(dtype):
if strict_ints:
raise
casted = np.array(data, copy=False)
if casted.dtype.kind == "f":
# GH#40110 match the behavior we have if we passed
Expand Down
133 changes: 45 additions & 88 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@
find_common_type,
infer_dtype_from,
maybe_cast_pointwise_result,
maybe_infer_to_datetimelike,
np_can_hold_element,
)
from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -116,7 +115,6 @@
DatetimeTZDtype,
ExtensionDtype,
IntervalDtype,
PandasDtype,
PeriodDtype,
)
from pandas.core.dtypes.generic import (
Expand Down Expand Up @@ -208,6 +206,22 @@
_dtype_obj = np.dtype("object")


def _wrapped_sanitize(cls, data, dtype: DtypeObj | None, copy: bool):
    """
    Call sanitize_array with wrapping for differences between Index/Series.

    Parameters
    ----------
    cls : type
        The Index subclass being constructed; used only to raise the
        class-appropriate scalar-data error.
    data : array-like
        The raw input being sanitized.
    dtype : DtypeObj or None
        Target dtype, if one was requested by the caller.
    copy : bool
        Whether sanitize_array should copy the input.

    Returns
    -------
    np.ndarray or ExtensionArray
        Sanitized 1-D array, re-boxed as a datetime-like ExtensionArray
        when appropriate.

    Raises
    ------
    TypeError
        Via cls._raise_scalar_data_error when scalar data was passed.
    ValueError
        With an Index-specific message for 2-D input; any other
        ValueError from sanitize_array is re-raised unchanged.
    """
    try:
        # strict_ints=True: Index construction should surface lossy
        # float->int casts instead of silently falling back to floats
        # the way Series does.
        arr = sanitize_array(data, None, dtype=dtype, copy=copy, strict_ints=True)
    except ValueError as err:
        # Translate Series-flavored messages into Index-flavored errors.
        # NOTE: matching on message substrings is fragile by design here;
        # keep these strings in sync with sanitize_array.
        if "index must be specified when data is not list-like" in str(err):
            # _raise_scalar_data_error itself raises; the outer `raise ... from err`
            # is kept so the exception chain records the original cause.
            raise cls._raise_scalar_data_error(data) from err
        if "Data must be 1-dimensional" in str(err):
            raise ValueError("Index data must be 1-dimensional") from err
        raise
    # Box datetime64/timedelta64 ndarrays into their ExtensionArray wrappers.
    arr = ensure_wrapped_if_datetimelike(arr)
    return arr


def _maybe_return_indexers(meth: F) -> F:
"""
Decorator to simplify 'return_indexers' checks in Index.join.
Expand Down Expand Up @@ -422,21 +436,13 @@ def __new__(
tupleize_cols: bool = True,
) -> Index:

from pandas.core.arrays import PandasArray
from pandas.core.indexes.range import RangeIndex

name = maybe_extract_name(name, data, cls)

if dtype is not None:
dtype = pandas_dtype(dtype)

if type(data) is PandasArray:
# ensure users don't accidentally put a PandasArray in an index,
# but don't unpack StringArray
data = data.to_numpy()
if isinstance(dtype, PandasDtype):
dtype = dtype.numpy_dtype

data_dtype = getattr(data, "dtype", None)

# range
Expand All @@ -448,28 +454,10 @@ def __new__(

elif is_ea_or_datetimelike_dtype(dtype):
# non-EA dtype indexes have special casting logic, so we punt here
klass = cls._dtype_to_subclass(dtype)
if klass is not Index:
return klass(data, dtype=dtype, copy=copy, name=name)

ea_cls = dtype.construct_array_type()
data = ea_cls._from_sequence(data, dtype=dtype, copy=copy)
return Index._simple_new(data, name=name)
pass

elif is_ea_or_datetimelike_dtype(data_dtype):
data_dtype = cast(DtypeObj, data_dtype)
klass = cls._dtype_to_subclass(data_dtype)
if klass is not Index:
result = klass(data, copy=copy, name=name)
if dtype is not None:
return result.astype(dtype, copy=False)
return result
elif dtype is not None:
# GH#45206
data = data.astype(dtype, copy=False)

data = extract_array(data, extract_numpy=True)
return Index._simple_new(data, name=name)
pass

# index-like
elif (
Expand All @@ -483,42 +471,25 @@ def __new__(
if isinstance(data, ABCMultiIndex):
data = data._values

if dtype is not None:
# we need to avoid having numpy coerce
if data.dtype.kind not in ["i", "u", "f", "b", "c", "m", "M"]:
# GH#11836 we need to avoid having numpy coerce
# things that look like ints/floats to ints unless
# they are actually ints, e.g. '0' and 0.0
# should not be coerced
# GH 11836
data = sanitize_array(data, None, dtype=dtype, copy=copy)

dtype = data.dtype

if data.dtype.kind in ["i", "u", "f"]:
# maybe coerce to a sub-class
arr = data
elif data.dtype.kind in ["b", "c"]:
# No special subclass, and Index._ensure_array won't do this
# for us.
arr = np.asarray(data)
else:
arr = com.asarray_tuplesafe(data, dtype=_dtype_obj)

if dtype is None:
arr = maybe_infer_to_datetimelike(arr)
arr = ensure_wrapped_if_datetimelike(arr)
dtype = arr.dtype

klass = cls._dtype_to_subclass(arr.dtype)
arr = klass._ensure_array(arr, dtype, copy)
return klass._simple_new(arr, name)
data = com.asarray_tuplesafe(data, dtype=_dtype_obj)

elif is_scalar(data):
raise cls._raise_scalar_data_error(data)
elif hasattr(data, "__array__"):
return Index(np.asarray(data), dtype=dtype, copy=copy, name=name)
elif not is_list_like(data) and not isinstance(data, memoryview):
# 2022-11-16 the memoryview check is only necessary on some CI
# builds, not clear why
raise cls._raise_scalar_data_error(data)

else:

if tupleize_cols and is_list_like(data):
if tupleize_cols:
# GH21470: convert iterable to list before determining if empty
if is_iterator(data):
data = list(data)
Expand All @@ -531,12 +502,24 @@ def __new__(
return MultiIndex.from_tuples(data, names=name)
# other iterable of some kind

subarr = com.asarray_tuplesafe(data, dtype=_dtype_obj)
if dtype is None:
# with e.g. a list [1, 2, 3] casting to numeric is _not_ deprecated
subarr = _maybe_cast_data_without_dtype(subarr)
dtype = subarr.dtype
return Index(subarr, dtype=dtype, copy=copy, name=name)
if not isinstance(data, (list, tuple)):
# we allow set/frozenset, which Series/sanitize_array does not, so
# cast to list here
data = list(data)
if len(data) == 0:
# unlike Series, we default to object dtype:
data = np.array(data, dtype=object)

if len(data) and isinstance(data[0], tuple):
# Ensure we get 1-D array of tuples instead of 2D array.
data = com.asarray_tuplesafe(data, dtype=_dtype_obj)

arr = _wrapped_sanitize(cls, data, dtype, copy)
klass = cls._dtype_to_subclass(arr.dtype)

# _ensure_array _may_ be unnecessary once Int64Index etc are gone
arr = klass._ensure_array(arr, arr.dtype, copy=False)
return klass._simple_new(arr, name)

@classmethod
def _ensure_array(cls, data, dtype, copy: bool):
Expand Down Expand Up @@ -7048,32 +7031,6 @@ def maybe_extract_name(name, obj, cls) -> Hashable:
return name


def _maybe_cast_data_without_dtype(subarr: npt.NDArray[np.object_]) -> ArrayLike:
    """
    Infer a supported dtype for object-dtype data when no dtype was passed.

    Parameters
    ----------
    subarr : np.ndarray[object]
        Object-dtype ndarray whose contents may represent a richer type.

    Returns
    -------
    np.ndarray or ExtensionArray
        The converted array, wrapped as a datetime-like ExtensionArray
        when the inferred dtype calls for it.
    """
    # Let the C-level inference routine try datetime/timedelta/period/interval
    # conversions; an all-NaT result defaults to datetime64[ns].  Whatever
    # comes back is then re-boxed if it is datetime-like.
    return ensure_wrapped_if_datetimelike(
        lib.maybe_convert_objects(
            subarr,
            convert_datetime=True,
            convert_timedelta=True,
            convert_period=True,
            convert_interval=True,
            dtype_if_all_nat=np.dtype("datetime64[ns]"),
        )
    )


def get_unanimous_names(*indexes: Index) -> tuple[Hashable, ...]:
"""
Return common name if all indices agree, otherwise None (level-by-level).
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/indexes/datetimes/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,7 @@ def test_constructor_no_precision_raises(self):
with pytest.raises(ValueError, match=msg):
DatetimeIndex(["2000"], dtype="datetime64")

msg = "The 'datetime64' dtype has no unit. Please pass in"
with pytest.raises(ValueError, match=msg):
Index(["2000"], dtype="datetime64")

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/indexes/interval/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def test_constructor_errors(self, klass):

# scalar
msg = (
r"IntervalIndex\(...\) must be called with a collection of "
r"(IntervalIndex|Index)\(...\) must be called with a collection of "
"some kind, 5 was passed"
)
with pytest.raises(TypeError, match=msg):
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/indexes/timedeltas/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def test_constructor_no_precision_raises(self):
with pytest.raises(ValueError, match=msg):
TimedeltaIndex(["2000"], dtype="timedelta64")

msg = "The 'timedelta64' dtype has no unit. Please pass in"
with pytest.raises(ValueError, match=msg):
pd.Index(["2000"], dtype="timedelta64")

Expand Down

0 comments on commit a55f1cf

Please sign in to comment.