|
31 | 31 | DtypeObj, |
32 | 32 | IntervalClosedType, |
33 | 33 | npt, |
| 34 | + NumpyIndexT, |
34 | 35 | ) |
35 | 36 | from pandas.errors import InvalidIndexError |
36 | 37 | from pandas.util._decorators import ( |
|
47 | 48 | ) |
48 | 49 | from pandas.core.dtypes.common import ( |
49 | 50 | ensure_platform_int, |
| 51 | + is_array_like, |
50 | 52 | is_datetime64tz_dtype, |
51 | 53 | is_datetime_or_timedelta_dtype, |
52 | 54 | is_dtype_equal, |
@@ -146,6 +148,24 @@ def _new_IntervalIndex(cls, d): |
146 | 148 | return cls.from_arrays(**d) |
147 | 149 |
|
148 | 150 |
|
| 151 | +def maybe_convert_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT: |
| 152 | + # IntervalTree only supports 64 bit numpy array |
| 153 | + |
| 154 | + if not is_array_like(arr): |
| 155 | + return arr |
| 156 | + dtype = arr.dtype |
| 157 | + if not np.issubclass_(dtype.type, np.number): |
| 158 | + return arr |
| 159 | + elif is_signed_integer_dtype(dtype) and dtype != np.int64: |
| 160 | + return arr.astype(np.int64) |
| 161 | + elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64: |
| 162 | + return arr.astype(np.uint64) |
| 163 | + elif is_float_dtype(dtype) and dtype != np.float64: |
| 164 | + return arr.astype(np.float64) |
| 165 | + else: |
| 166 | + return arr |
| 167 | + |
| 168 | + |
149 | 169 | @Appender( |
150 | 170 | _interval_shared_docs["class"] |
151 | 171 | % { |
@@ -343,7 +363,9 @@ def from_tuples( |
343 | 363 | @cache_readonly |
344 | 364 | def _engine(self) -> IntervalTree: # type: ignore[override] |
345 | 365 | left = self._maybe_convert_i8(self.left) |
| 366 | + left = maybe_convert_numeric_to_64bit(left) |
346 | 367 | right = self._maybe_convert_i8(self.right) |
| 368 | + right = maybe_convert_numeric_to_64bit(right) |
347 | 369 | return IntervalTree(left, right, closed=self.closed) |
348 | 370 |
|
349 | 371 | def __contains__(self, key: Any) -> bool: |
@@ -520,13 +542,12 @@ def _maybe_convert_i8(self, key): |
520 | 542 | The original key if no conversion occurred, int if converted scalar, |
521 | 543 | Int64Index if converted list-like. |
522 | 544 | """ |
523 | | - original = key |
524 | 545 | if is_list_like(key): |
525 | 546 | key = ensure_index(key) |
526 | | - key = self._maybe_convert_numeric_to_64bit(key) |
| 547 | + key = maybe_convert_numeric_to_64bit(key) |
527 | 548 |
|
528 | 549 | if not self._needs_i8_conversion(key): |
529 | | - return original |
| 550 | + return key |
530 | 551 |
|
531 | 552 | scalar = is_scalar(key) |
532 | 553 | if is_interval_dtype(key) or isinstance(key, Interval): |
@@ -569,20 +590,6 @@ def _maybe_convert_i8(self, key): |
569 | 590 |
|
570 | 591 | return key_i8 |
571 | 592 |
|
572 | | - def _maybe_convert_numeric_to_64bit(self, idx: Index) -> Index: |
573 | | - # IntervalTree only supports 64 bit numpy array |
574 | | - dtype = idx.dtype |
575 | | - if np.issubclass_(dtype.type, np.number): |
576 | | - return idx |
577 | | - elif is_signed_integer_dtype(dtype) and dtype != np.int64: |
578 | | - return idx.astype(np.int64) |
579 | | - elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64: |
580 | | - return idx.astype(np.uint64) |
581 | | - elif is_float_dtype(dtype) and dtype != np.float64: |
582 | | - return idx.astype(np.float64) |
583 | | - else: |
584 | | - return idx |
585 | | - |
586 | 593 | def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"): |
587 | 594 | if not self.is_non_overlapping_monotonic: |
588 | 595 | raise KeyError( |
|
0 commit comments