Skip to content

Commit

Permalink
Let _get_dtype accept Categoricals and CategoricalIndex (#16887)
Browse files Browse the repository at this point in the history
  • Loading branch information
topper-123 authored and jreback committed Jul 13, 2017
1 parent 63536f4 commit 25384ba
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
1 change: 0 additions & 1 deletion doc/source/whatsnew/v0.21.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ Conversion
^^^^^^^^^^



Indexing
^^^^^^^^

Expand Down
4 changes: 3 additions & 1 deletion pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ExtensionDtype)
from .generic import (ABCCategorical, ABCPeriodIndex,
ABCDatetimeIndex, ABCSeries,
ABCSparseArray, ABCSparseSeries)
ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex)
from .inference import is_string_like
from .inference import * # noqa

Expand Down Expand Up @@ -1713,6 +1713,8 @@ def _get_dtype(arr_or_dtype):
return PeriodDtype.construct_from_string(arr_or_dtype)
elif is_interval_dtype(arr_or_dtype):
return IntervalDtype.construct_from_string(arr_or_dtype)
elif isinstance(arr_or_dtype, (ABCCategorical, ABCCategoricalIndex)):
return arr_or_dtype.dtype

if hasattr(arr_or_dtype, 'dtype'):
arr_or_dtype = arr_or_dtype.dtype
Expand Down
6 changes: 3 additions & 3 deletions pandas/tests/dtypes/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,16 +532,16 @@ def test_is_complex_dtype():
(float, np.dtype(float)),
('float64', np.dtype('float64')),
(np.dtype('float64'), np.dtype('float64')),
pytest.mark.xfail((str, np.dtype('<U')), ),
(str, np.dtype(str)),
(pd.Series([1, 2], dtype=np.dtype('int16')), np.dtype('int16')),
(pd.Series(['a', 'b']), np.dtype(object)),
(pd.Index([1, 2]), np.dtype('int64')),
(pd.Index(['a', 'b']), np.dtype(object)),
('category', 'category'),
(pd.Categorical(['a', 'b']).dtype, CategoricalDtype()),
pytest.mark.xfail((pd.Categorical(['a', 'b']), CategoricalDtype()),),
(pd.Categorical(['a', 'b']), CategoricalDtype()),
(pd.CategoricalIndex(['a', 'b']).dtype, CategoricalDtype()),
pytest.mark.xfail((pd.CategoricalIndex(['a', 'b']), CategoricalDtype()),),
(pd.CategoricalIndex(['a', 'b']), CategoricalDtype()),
(pd.DatetimeIndex([1, 2]), np.dtype('<M8[ns]')),
(pd.DatetimeIndex([1, 2]).dtype, np.dtype('<M8[ns]')),
('<M8[ns]', np.dtype('<M8[ns]')),
Expand Down

0 comments on commit 25384ba

Please sign in to comment.