Skip to content

ENH: Add argmax and argmin to ExtensionArray #27801

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Jul 8, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
8303674
Add argmax, max, argmin, min to EA
makbigc Aug 7, 2019
44838f7
Remove argmax, argmin, max, min from ArrowEA
makbigc Aug 9, 2019
8bc9573
Fix black error
makbigc Aug 9, 2019
bc69a4a
Add issue number to the test
makbigc Aug 10, 2019
4787c63
update
makbigc Aug 10, 2019
8bd82f5
merge for update
makbigc Dec 2, 2019
e84268b
Move the whatsnew entry to v1
makbigc Dec 2, 2019
f9c9fea
Add test for categorical
makbigc Dec 3, 2019
8b7585f
Add func doc and pre-check for zero length
makbigc Dec 4, 2019
41e8ce4
Add min and max to StringArray
makbigc Dec 4, 2019
20ca0a2
Add test for empty array
makbigc Dec 4, 2019
1aa7422
Resolve black format
makbigc Dec 5, 2019
f2b6958
merge again
makbigc Jan 18, 2020
7d81cc5
Fix test
makbigc Jan 18, 2020
d18c8bb
Fix lint error
makbigc Jan 18, 2020
3530a6a
Change the error message
makbigc Jan 19, 2020
8d8506a
Move the whatsnew entry from v1 to v1.1
makbigc Jan 19, 2020
e7e7c86
merge for update
makbigc Jan 22, 2020
10f9b27
merge for update
makbigc Jan 29, 2020
ccded0b
merge again
makbigc Feb 2, 2020
9aea8fe
merge for update
makbigc Feb 6, 2020
5636d2a
merge for update
makbigc Feb 7, 2020
4db39e0
merge again
makbigc Feb 14, 2020
2036b25
Refactor max and min
makbigc Feb 14, 2020
d81a8f8
merge again
makbigc Feb 25, 2020
58d46d0
merge for update
makbigc Feb 26, 2020
6b88790
Merge remote-tracking branch 'upstream/master' into enh-24382
jorisvandenbossche May 9, 2020
8cd7169
fixup merge
jorisvandenbossche May 9, 2020
810ac56
Remove min/max for now
jorisvandenbossche May 9, 2020
6b030aa
argmin/argmax implementation based on _values_for_argsort
jorisvandenbossche May 9, 2020
e5a6d8c
test clean-up + update docstring
jorisvandenbossche May 9, 2020
d7a49c1
Merge remote-tracking branch 'upstream/master' into enh-24382
jorisvandenbossche May 9, 2020
65e1e4c
simplify test_sparse override
jorisvandenbossche May 9, 2020
1b64e2f
Merge remote-tracking branch 'upstream/master' into enh-24382
jorisvandenbossche May 22, 2020
2cdf16b
Merge remote-tracking branch 'upstream/master' into enh-24382
jorisvandenbossche Jun 20, 2020
7c79f5c
Merge remote-tracking branch 'upstream/master' into enh-24382
jorisvandenbossche Jul 8, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
argmin/argmax implementation based on _values_for_argsort
  • Loading branch information
jorisvandenbossche committed May 9, 2020
commit 6b030aad3f67049fc251b91ecf34922dfaeb79c0
14 changes: 3 additions & 11 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pandas.core import ops
from pandas.core.algorithms import _factorize_array, unique
from pandas.core.missing import backfill_1d, pad_1d
from pandas.core.sorting import nargsort
from pandas.core.sorting import nargminmax, nargsort

_extension_array_shared_docs: Dict[str, str] = dict()

Expand Down Expand Up @@ -520,10 +520,7 @@ def argmin(self):
--------
ExtensionArray.argmax
"""
if len(self) == 0:
raise ValueError("attempt to get argmin of an empty sequence")

return self.argsort()[0]
return nargminmax(self, "argmin")

def argmax(self):
"""
Expand All @@ -538,12 +535,7 @@ def argmax(self):
--------
ExtensionArray.argmin
"""

if len(self) == 0:
raise ValueError("attempt to get argmax of an empty sequence")

no_nan = self.isna().sum()
return self.argsort()[-1 - no_nan]
return nargminmax(self, "argmax")

def fillna(self, value=None, method=None, limit=None):
"""
Expand Down
27 changes: 27 additions & 0 deletions pandas/core/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,33 @@ def nargsort(
return indexer


def nargminmax(values, method: str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don’t think these
should be in base.py

rather in array_ops

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I put them here is because I think it makes sense to keep it close to nargsort, since the code is very similar (using the same approach with the idx/non_nan_idx etc)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don’t think these should be in base.py

BTW, this is not base.py, but core/sorting.py, which groups a whole bunch of functionality related to sortable values.
(it might make sense to move sorting.py into the array_ops submodule, but I would do that as a separate move)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

k that makes sense (and let’s move sorting.py) as a followon

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

annotate values as EA?

"""
Implementation of np.argmin/argmax but for ExtensionArray and which
handles missing values.

Parameters
----------
values : ExtensionArray
method : {"argmax", "argmin"}

Returns
-------
int
"""
assert method in {"argmax", "argmin"}
func = np.argmax if method == "argmax" else np.argmin

mask = np.asarray(isna(values))
values = values._values_for_argsort()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

im not wild about relying on non-public attrs here. could we have the EA method pass values and mask?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_values_for_argsort is a "public developer" API (it's part of the EA interface), that's the entire point of it.

I know we have had the discussion about the point of _values_for_argsort and in principle we could also do without. But at this point, we have that method, it is used for argsort as well, so I think it is most logical that I use it here. And we can continue that general discussion about _values_for_argsort elsewhere.


idx = np.arange(len(values))
non_nans = values[~mask]
non_nan_idx = idx[~mask]

return non_nan_idx[func(non_nans)]


def ensure_key_mapped_multiindex(index, key: Callable, level=None):
"""
Returns a new MultiIndex in which key has been applied
Expand Down
37 changes: 26 additions & 11 deletions pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,26 +75,41 @@ def test_argsort_missing(self, data_missing_for_sorting):
expected = pd.Series(np.array([1, -1, 0], dtype=np.int64))
self.assert_series_equal(result, expected)

def test_argmax(self, data_missing_for_sorting):
def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting, na_value):
# GH 24382
result = data_missing_for_sorting.argmax()
expected = 0
assert result == expected

def test_argmin(self, data_missing_for_sorting):
# GH 24382
result = data_missing_for_sorting.argmin()
expected = 2
assert result == expected
# data_for_sorting -> [B, C, A] with A < B < C
assert data_for_sorting.argmax() == 1
assert data_for_sorting.argmin() == 2

# with repeated values -> first occurence
data = data_for_sorting.take([2, 0, 0, 1, 1, 2])
assert data.argmax() == 3
assert data.argmin() == 0

# with missing values
# data_missing_for_sorting -> [B, NA, A] with A < B and NA missing.
assert data_missing_for_sorting.argmax() == 0
assert data_missing_for_sorting.argmin() == 2

@pytest.mark.parametrize(
"method", ["argmax", "argmin"],
)
def test_extremize_empty_array(self, method, data_missing_for_sorting):
def test_argmin_argmax_empty_array(self, method, data):
# GH 24382
err_msg = "attempt to get"
with pytest.raises(ValueError, match=err_msg):
getattr(data_missing_for_sorting[:0], method)()
getattr(data[:0], method)()

@pytest.mark.parametrize(
"method", ["argmax", "argmin"],
)
def test_argmin_argmax_all_na(self, method, data, na_value):
# all missing with skipna=True is the same as emtpy
err_msg = "attempt to get"
data_na = type(data)._from_sequence([na_value, na_value], dtype=data.dtype)
with pytest.raises(ValueError, match=err_msg):
getattr(data_na, method)()

@pytest.mark.parametrize(
"na_position, expected",
Expand Down
17 changes: 17 additions & 0 deletions pandas/tests/extension/test_boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,23 @@ def test_searchsorted(self, data_for_sorting, as_series):
def test_value_counts(self, all_data, dropna):
return super().test_value_counts(all_data, dropna)

def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting):
# override because there are only 2 unique values

# data_for_sorting -> [B, C, A] with A < B < C -> here True, True, False
assert data_for_sorting.argmax() == 0
assert data_for_sorting.argmin() == 2

# with repeated values -> first occurence
data = data_for_sorting.take([2, 0, 0, 1, 1, 2])
assert data.argmax() == 1
assert data.argmin() == 0

# with missing values
# data_missing_for_sorting -> [B, NA, A] with A < B and NA missing.
assert data_missing_for_sorting.argmax() == 0
assert data_missing_for_sorting.argmin() == 2


class TestCasting(base.BaseCastingTests):
pass
Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/extension/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,19 @@ def test_shift_0_periods(self, data):
data._sparse_values[0] = data._sparse_values[1]
assert result._sparse_values[0] != result._sparse_values[1]

@pytest.mark.parametrize(
"method", ["argmax", "argmin"],
)
def test_argmin_argmax_all_na(self, method, data, na_value):
# overriding because Sparse[int64, 0] cannot handle na_value
if data.dtype.fill_value == 0:
pytest.skip("missing values not supported with Sparse[int64, 0]")

err_msg = "attempt to get"
data_na = type(data)._from_sequence([na_value, na_value], dtype=data.dtype)
with pytest.raises(ValueError, match=err_msg):
getattr(data_na, method)()


class TestCasting(BaseSparseTests, base.BaseCastingTests):
def test_astype_object_series(self, all_data):
Expand Down