Skip to content

Commit

Permalink
BUG: IntervalIndex set op bugs for empty results (pandas-dev#19112)
Browse files Browse the repository at this point in the history
  • Loading branch information
jschendel authored and jreback committed Jan 12, 2018
1 parent 8912efc commit 5853b79
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 6 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.23.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ Indexing
- Bug in :func:`MultiIndex.__contains__` where non-tuple keys would return ``True`` even if they had been dropped (:issue:`19027`)
- Bug in :func:`MultiIndex.set_labels` which would cause casting (and potentially clipping) of the new labels if the ``level`` argument is not 0 or a list like [0, 1, ... ] (:issue:`19057`)
- Bug in ``str.extractall`` when there were no matches empty :class:`Index` was returned instead of appropriate :class:`MultiIndex` (:issue:`19034`)
- Bug in :class:`IntervalIndex` where set operations that returned an empty ``IntervalIndex`` had the wrong dtype (:issue:`19101`)

I/O
^^^
Expand Down
21 changes: 19 additions & 2 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pandas.core.dtypes.missing import notna, isna
from pandas.core.dtypes.generic import ABCDatetimeIndex, ABCPeriodIndex
from pandas.core.dtypes.dtypes import IntervalDtype
from pandas.core.dtypes.cast import maybe_convert_platform
from pandas.core.dtypes.cast import maybe_convert_platform, find_common_type
from pandas.core.dtypes.common import (
_ensure_platform_int,
is_list_like,
Expand All @@ -16,6 +16,7 @@
is_integer_dtype,
is_float_dtype,
is_interval_dtype,
is_object_dtype,
is_scalar,
is_float,
is_number,
Expand Down Expand Up @@ -1289,9 +1290,25 @@ def func(self, other):
msg = ('can only do set operations between two IntervalIndex '
'objects that are closed on the same side')
other = self._as_like_interval_index(other, msg)

# GH 19016: ensure set op will not return a prohibited dtype
subtypes = [self.dtype.subtype, other.dtype.subtype]
common_subtype = find_common_type(subtypes)
if is_object_dtype(common_subtype):
msg = ('can only do {op} between two IntervalIndex '
'objects that have compatible dtypes')
raise TypeError(msg.format(op=op_name))

result = getattr(self._multiindex, op_name)(other._multiindex)
result_name = self.name if self.name == other.name else None
return type(self).from_tuples(result.values, closed=self.closed,

# GH 19101: ensure empty results have correct dtype
if result.empty:
result = result.values.astype(self.dtype.subtype)
else:
result = result.values

return type(self).from_tuples(result, closed=self.closed,
name=result_name)
return func

Expand Down
60 changes: 56 additions & 4 deletions pandas/tests/indexes/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,16 @@ def test_union(self, closed):
tm.assert_index_equal(index.union(index), index)
tm.assert_index_equal(index.union(index[:1]), index)

# GH 19101: empty result, same dtype
index = IntervalIndex(np.array([], dtype='int64'), closed=closed)
result = index.union(index)
tm.assert_index_equal(result, index)

# GH 19101: empty result, different dtypes
other = IntervalIndex(np.array([], dtype='float64'), closed=closed)
result = index.union(other)
tm.assert_index_equal(result, index)

def test_intersection(self, closed):
index = self.create_index(closed=closed)
other = IntervalIndex.from_breaks(range(5, 13), closed=closed)
Expand All @@ -893,14 +903,48 @@ def test_intersection(self, closed):

tm.assert_index_equal(index.intersection(index), index)

# GH 19101: empty result, same dtype
other = IntervalIndex.from_breaks(range(300, 314), closed=closed)
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
result = index.intersection(other)
tm.assert_index_equal(result, expected)

# GH 19101: empty result, different dtypes
breaks = np.arange(300, 314, dtype='float64')
other = IntervalIndex.from_breaks(breaks, closed=closed)
result = index.intersection(other)
tm.assert_index_equal(result, expected)

def test_difference(self, closed):
index = self.create_index(closed=closed)
tm.assert_index_equal(index.difference(index[:1]), index[1:])

# GH 19101: empty result, same dtype
result = index.difference(index)
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
tm.assert_index_equal(result, expected)

# GH 19101: empty result, different dtypes
other = IntervalIndex.from_arrays(index.left.astype('float64'),
index.right, closed=closed)
result = index.difference(other)
tm.assert_index_equal(result, expected)

def test_symmetric_difference(self, closed):
idx = self.create_index(closed=closed)
result = idx[1:].symmetric_difference(idx[:-1])
expected = IntervalIndex([idx[0], idx[-1]])
index = self.create_index(closed=closed)
result = index[1:].symmetric_difference(index[:-1])
expected = IntervalIndex([index[0], index[-1]])
tm.assert_index_equal(result, expected)

# GH 19101: empty result, same dtype
result = index.symmetric_difference(index)
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
tm.assert_index_equal(result, expected)

# GH 19101: empty result, different dtypes
other = IntervalIndex.from_arrays(index.left.astype('float64'),
index.right, closed=closed)
result = index.symmetric_difference(other)
tm.assert_index_equal(result, expected)

@pytest.mark.parametrize('op_name', [
Expand All @@ -909,17 +953,25 @@ def test_set_operation_errors(self, closed, op_name):
index = self.create_index(closed=closed)
set_op = getattr(index, op_name)

# test errors
# non-IntervalIndex
msg = ('can only do set operations between two IntervalIndex objects '
'that are closed on the same side')
with tm.assert_raises_regex(ValueError, msg):
set_op(Index([1, 2, 3]))

# mixed closed
for other_closed in {'right', 'left', 'both', 'neither'} - {closed}:
other = self.create_index(closed=other_closed)
with tm.assert_raises_regex(ValueError, msg):
set_op(other)

# GH 19016: incompatible dtypes
other = interval_range(Timestamp('20180101'), periods=9, closed=closed)
msg = ('can only do {op} between two IntervalIndex objects that have '
'compatible dtypes').format(op=op_name)
with tm.assert_raises_regex(TypeError, msg):
set_op(other)

def test_isin(self, closed):
index = self.create_index(closed=closed)

Expand Down

0 comments on commit 5853b79

Please sign in to comment.