Skip to content

Commit 3821040

Browse files
jschendeljreback
authored andcommitted
BUG: Ensure Index.astype('category') returns a CategoricalIndex (#18677)
1 parent c753e1e commit 3821040

File tree

18 files changed

+201
-30
lines changed

18 files changed

+201
-30
lines changed

doc/source/whatsnew/v0.22.0.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ Conversion
259259
- Fixed a bug where creating a Series from an array that contains both tz-naive and tz-aware values will result in a Series whose dtype is tz-aware instead of object (:issue:`16406`)
260260
- Adding a ``Period`` object to a ``datetime`` or ``Timestamp`` object will now correctly raise a ``TypeError`` (:issue:`17983`)
261261
- Fixed a bug where ``FY5253`` date offsets could incorrectly raise an ``AssertionError`` in arithmetic operatons (:issue:`14774`)
262+
- Bug in :meth:`Index.astype` with a categorical dtype where the resultant index is not converted to a :class:`CategoricalIndex` for all types of index (:issue:`18630`)
262263

263264

264265
Indexing

pandas/core/dtypes/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1934,7 +1934,7 @@ def pandas_dtype(dtype):
19341934
except TypeError:
19351935
pass
19361936

1937-
elif dtype.startswith('interval[') or dtype.startswith('Interval['):
1937+
elif dtype.startswith('interval') or dtype.startswith('Interval'):
19381938
try:
19391939
return IntervalDtype.construct_from_string(dtype)
19401940
except TypeError:

pandas/core/dtypes/dtypes.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,33 @@ def _validate_categories(categories, fastpath=False):
340340

341341
return categories
342342

343+
def _update_dtype(self, dtype):
344+
"""
345+
Returns a CategoricalDtype with categories and ordered taken from dtype
346+
if specified, otherwise falling back to self if unspecified
347+
348+
Parameters
349+
----------
350+
dtype : CategoricalDtype
351+
352+
Returns
353+
-------
354+
new_dtype : CategoricalDtype
355+
"""
356+
if isinstance(dtype, compat.string_types) and dtype == 'category':
357+
# dtype='category' should not change anything
358+
return self
359+
elif not self.is_dtype(dtype):
360+
msg = ('a CategoricalDtype must be passed to perform an update, '
361+
'got {dtype!r}').format(dtype=dtype)
362+
raise ValueError(msg)
363+
364+
# dtype is CDT: keep current categories if None (ordered can't be None)
365+
new_categories = dtype.categories
366+
if new_categories is None:
367+
new_categories = self.categories
368+
return CategoricalDtype(new_categories, dtype.ordered)
369+
343370
@property
344371
def categories(self):
345372
"""

pandas/core/indexes/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,6 +1053,10 @@ def _to_embed(self, keep_tz=False, dtype=None):
10531053

10541054
@Appender(_index_shared_docs['astype'])
10551055
def astype(self, dtype, copy=True):
1056+
if is_categorical_dtype(dtype):
1057+
from .category import CategoricalIndex
1058+
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
1059+
copy=copy)
10561060
return Index(self.values.astype(dtype, copy=copy), name=self.name,
10571061
dtype=dtype)
10581062

pandas/core/indexes/category.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pandas import compat
55
from pandas.compat.numpy import function as nv
66
from pandas.core.dtypes.generic import ABCCategorical, ABCSeries
7+
from pandas.core.dtypes.dtypes import CategoricalDtype
78
from pandas.core.dtypes.common import (
89
is_categorical_dtype,
910
_ensure_platform_int,
@@ -165,8 +166,6 @@ def _create_categorical(self, data, categories=None, ordered=None,
165166
data = Categorical(data, categories=categories, ordered=ordered,
166167
dtype=dtype)
167168
else:
168-
from pandas.core.dtypes.dtypes import CategoricalDtype
169-
170169
if categories is not None:
171170
data = data.set_categories(categories, ordered=ordered)
172171
elif ordered is not None and ordered != data.ordered:
@@ -344,6 +343,12 @@ def astype(self, dtype, copy=True):
344343
if is_interval_dtype(dtype):
345344
from pandas import IntervalIndex
346345
return IntervalIndex.from_intervals(np.array(self))
346+
elif is_categorical_dtype(dtype):
347+
# GH 18630
348+
dtype = self.dtype._update_dtype(dtype)
349+
if dtype == self.dtype:
350+
return self.copy() if copy else self
351+
347352
return super(CategoricalIndex, self).astype(dtype=dtype, copy=copy)
348353

349354
@cache_readonly

pandas/core/indexes/datetimes.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
is_period_dtype,
2121
is_bool_dtype,
2222
is_string_dtype,
23+
is_categorical_dtype,
2324
is_string_like,
2425
is_list_like,
2526
is_scalar,
@@ -35,6 +36,7 @@
3536
from pandas.core.algorithms import checked_add_with_arr
3637

3738
from pandas.core.indexes.base import Index, _index_shared_docs
39+
from pandas.core.indexes.category import CategoricalIndex
3840
from pandas.core.indexes.numeric import Int64Index, Float64Index
3941
import pandas.compat as compat
4042
from pandas.tseries.frequencies import (
@@ -915,6 +917,9 @@ def astype(self, dtype, copy=True):
915917
elif copy is True:
916918
return self.copy()
917919
return self
920+
elif is_categorical_dtype(dtype):
921+
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
922+
copy=copy)
918923
elif is_string_dtype(dtype):
919924
return Index(self.format(), name=self.name, dtype=object)
920925
elif is_period_dtype(dtype):

pandas/core/indexes/interval.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
Interval, IntervalMixin, IntervalTree,
3030
intervals_to_interval_bounds)
3131

32+
from pandas.core.indexes.category import CategoricalIndex
3233
from pandas.core.indexes.datetimes import date_range
3334
from pandas.core.indexes.timedeltas import timedelta_range
3435
from pandas.core.indexes.multi import MultiIndex
@@ -632,8 +633,8 @@ def astype(self, dtype, copy=True):
632633
elif is_object_dtype(dtype):
633634
return Index(self.values, dtype=object)
634635
elif is_categorical_dtype(dtype):
635-
from pandas import Categorical
636-
return Categorical(self, ordered=True)
636+
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
637+
copy=copy)
637638
raise ValueError('Cannot cast IntervalIndex to dtype {dtype}'
638639
.format(dtype=dtype))
639640

pandas/core/indexes/multi.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
from pandas.core.dtypes.common import (
1515
_ensure_int64,
1616
_ensure_platform_int,
17+
is_categorical_dtype,
1718
is_object_dtype,
1819
is_iterator,
1920
is_list_like,
21+
pandas_dtype,
2022
is_scalar)
2123
from pandas.core.dtypes.missing import isna, array_equivalent
2224
from pandas.errors import PerformanceWarning, UnsortedIndexError
@@ -2715,9 +2717,14 @@ def difference(self, other):
27152717

27162718
@Appender(_index_shared_docs['astype'])
27172719
def astype(self, dtype, copy=True):
2718-
if not is_object_dtype(np.dtype(dtype)):
2719-
raise TypeError('Setting %s dtype to anything other than object '
2720-
'is not supported' % self.__class__)
2720+
dtype = pandas_dtype(dtype)
2721+
if is_categorical_dtype(dtype):
2722+
msg = '> 1 ndim Categorical are not supported at this time'
2723+
raise NotImplementedError(msg)
2724+
elif not is_object_dtype(dtype):
2725+
msg = ('Setting {cls} dtype to anything other than object '
2726+
'is not supported').format(cls=self.__class__)
2727+
raise TypeError(msg)
27212728
elif copy is True:
27222729
return self._shallow_copy()
27232730
return self

pandas/core/indexes/numeric.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
is_float_dtype,
88
is_object_dtype,
99
is_integer_dtype,
10+
is_categorical_dtype,
1011
is_bool,
1112
is_bool_dtype,
1213
is_scalar)
@@ -16,6 +17,7 @@
1617
from pandas.core import algorithms
1718
from pandas.core.indexes.base import (
1819
Index, InvalidIndexError, _index_shared_docs)
20+
from pandas.core.indexes.category import CategoricalIndex
1921
from pandas.util._decorators import Appender, cache_readonly
2022
import pandas.core.dtypes.concat as _concat
2123
import pandas.core.indexes.base as ibase
@@ -321,10 +323,13 @@ def astype(self, dtype, copy=True):
321323
values = self._values.astype(dtype, copy=copy)
322324
elif is_object_dtype(dtype):
323325
values = self._values.astype('object', copy=copy)
326+
elif is_categorical_dtype(dtype):
327+
return CategoricalIndex(self, name=self.name, dtype=dtype,
328+
copy=copy)
324329
else:
325-
raise TypeError('Setting %s dtype to anything other than '
326-
'float64 or object is not supported' %
327-
self.__class__)
330+
raise TypeError('Setting {cls} dtype to anything other than '
331+
'float64, object, or category is not supported'
332+
.format(cls=self.__class__))
328333
return Index(values, name=self.name, dtype=dtype)
329334

330335
@Appender(_index_shared_docs['_convert_scalar_indexer'])

pandas/core/indexes/period.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
is_timedelta64_dtype,
1717
is_period_dtype,
1818
is_bool_dtype,
19+
is_categorical_dtype,
1920
pandas_dtype,
2021
_ensure_object)
2122
from pandas.core.dtypes.dtypes import PeriodDtype
2223
from pandas.core.dtypes.generic import ABCSeries
2324

2425
import pandas.tseries.frequencies as frequencies
2526
from pandas.tseries.frequencies import get_freq_code as _gfc
27+
from pandas.core.indexes.category import CategoricalIndex
2628
from pandas.core.indexes.datetimes import DatetimeIndex, Int64Index, Index
2729
from pandas.core.indexes.timedeltas import TimedeltaIndex
2830
from pandas.core.indexes.datetimelike import DatelikeOps, DatetimeIndexOpsMixin
@@ -517,6 +519,9 @@ def astype(self, dtype, copy=True, how='start'):
517519
return self.to_timestamp(how=how).tz_localize(dtype.tz)
518520
elif is_period_dtype(dtype):
519521
return self.asfreq(freq=dtype.freq)
522+
elif is_categorical_dtype(dtype):
523+
return CategoricalIndex(self.values, name=self.name, dtype=dtype,
524+
copy=copy)
520525
raise TypeError('Cannot cast PeriodIndex to dtype %s' % dtype)
521526

522527
@Substitution(klass='PeriodIndex')

0 commit comments

Comments
 (0)