Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] generic select function #1187

Merged
merged 13 commits into from
Nov 8, 2022
Prev Previous commit
Next Next commit
add DropLabel for dropping columns
  • Loading branch information
samukweku committed Nov 2, 2022
commit 0325ada509c1185bf218d9ec6641cf5863e61667
2 changes: 1 addition & 1 deletion janitor/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@
from .transform_columns import transform_column, transform_columns
from .truncate_datetime import truncate_datetime_dataframe
from .update_where import update_where
from .utils import patterns, unionize_dataframe_categories
from .utils import patterns, unionize_dataframe_categories, DropLabel
45 changes: 45 additions & 0 deletions janitor/functions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
Pattern,
Union,
Callable,
Any,
)
from pandas.core.dtypes.generic import ABCPandasArray, ABCExtensionArray
from pandas.core.common import is_bool_indexer
from dataclasses import dataclass

import pandas as pd
from janitor.utils import check, _expand_grid
Expand Down Expand Up @@ -268,6 +270,19 @@ def _select_callable(arg, func: Callable, axis=None):
return bools


@dataclass
class DropLabel:
    """
    Marker wrapper used inside the `select` syntax to exclude labels.

    The wrapped `label` is handed unchanged to `_select_index`, so it may be
    anything that `_select_index` itself accepts. The `DropLabel` handler
    registered on `_select_index` then returns the integer positions
    that do *not* match the wrapped label(s).

    :param label: Label(s) to be dropped from the index.
    :returns: A dataclass.
    """

    # The selection to exclude; stored as given, without validation.
    label: Any


@singledispatch
def _select_index(arg, df, axis):
"""
Expand All @@ -283,6 +298,27 @@ def _select_index(arg, df, axis):
raise KeyError(f"No match was returned for {arg}") from exc


@_select_index.register(DropLabel)  # noqa: F811
def _column_sel_dispatch(cols, df, axis):  # noqa: F811
    """
    Resolve a `DropLabel` into the positions that should be kept.

    The wrapped label is first resolved through `_select_index`; the
    result is then inverted against every position on the chosen axis.

    :param cols: A `DropLabel` instance; only its `label` attribute is read.
    :param df: Object whose `axis` attribute exposes a `.size`.
    :param axis: Name of the attribute on `df` to select on
        (presumably "index" or "columns" -- confirm against callers).
    :returns: An array of integers: the positions not matched by the label.
    """
    selected = _select_index(cols.label, df, axis)
    positions = np.arange(getattr(df, axis).size)

    # Normalise the three result shapes handled here (a single integer,
    # a slice, a list-like) into something that can be compared against
    # `positions`.
    if isinstance(selected, int):
        selected = [selected]
    elif isinstance(selected, slice):
        selected = positions[selected]
    elif is_list_like(selected):
        selected = np.asanyarray(selected)

    # A boolean mask is inverted directly; anything else is treated as a
    # collection of positions and removed via a set difference.
    return (
        positions[~selected]
        if is_bool_dtype(selected)
        else np.setdiff1d(positions, selected)
    )


@_select_index.register(str) # noqa: F811
def _index_dispatch(arg, df, axis): # noqa: F811
"""
Expand Down Expand Up @@ -485,6 +521,15 @@ def _index_dispatch(arg, df, axis): # noqa: F811

return arg

# treat multiple DropLabel instances as a single unit
checks = (isinstance(entry, DropLabel) for entry in arg)
if sum(checks) > 1:
drop_labels = (entry for entry in arg if isinstance(entry, DropLabel))
drop_labels = [entry.label for entry in drop_labels]
drop_labels = DropLabel(drop_labels)
arg = [entry for entry in arg if not isinstance(entry, DropLabel)]
arg.append(drop_labels)

indices = [_select_index(entry, df, axis) for entry in arg]

# single entry does not need to be combined
Expand Down
28 changes: 27 additions & 1 deletion tests/functions/test_select_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pandas.testing import assert_frame_equal
from itertools import product

from janitor.functions.utils import patterns
from janitor.functions.utils import patterns, DropLabel


@pytest.mark.functions
Expand All @@ -25,6 +25,32 @@ def test_select_column_names(dataframe, invert, expected):
assert_frame_equal(df, dataframe[expected])


@pytest.mark.functions
@pytest.mark.parametrize(
    "invert,expected",
    [
        (True, ["a", "Bell__Chart", "cities"]),
        (False, ["decorated-elephant", "animals@#$%^"]),
    ],
)
def test_select_column_names_droplabel(dataframe, invert, expected):
    """
    Select with a single `DropLabel` wrapping a list of column names.

    With `invert=False` the wrapped columns are expected to be absent from
    the result; with `invert=True` they are expected to be the result.
    """
    dropped = DropLabel(["a", "Bell__Chart", "cities"])
    result = dataframe.select_columns(dropped, invert=invert)

    assert_frame_equal(result, dataframe[expected])


@pytest.mark.functions
def test_select_column_names_droplabel_multiple(dataframe):
    """
    Select with several `DropLabel` arguments, one per column name.

    The result is expected to equal the frame with all of those columns
    dropped.
    """
    names = ["a", "Bell__Chart", "cities"]
    droppers = tuple(map(DropLabel, names))
    result = dataframe.select_columns(*droppers)

    assert_frame_equal(result, dataframe.drop(columns=names))


@pytest.mark.functions
@pytest.mark.parametrize(
"invert,expected",
Expand Down