Skip to content

Commit 781307e

Browse files
feat: add idxmin, idxmax to series, dataframe (#74)
* feat: add idxmin, idxmax to series, dataframe
* 🦉 Updates from OwlBot post-processor. See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md
* 🦉 Updates from OwlBot post-processor. See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md
---------
Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 17afac9 commit 781307e

File tree

8 files changed

+182
-16
lines changed

8 files changed

+182
-16
lines changed

bigframes/core/block_transforms.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import pandas as pd
1919

20+
import bigframes.constants as constants
2021
import bigframes.core as core
2122
import bigframes.core.blocks as blocks
2223
import bigframes.core.ordering as ordering
@@ -576,3 +577,53 @@ def align_columns(
576577
left_final = left_block.select_columns(left_column_ids)
577578
right_final = right_block.select_columns(right_column_ids)
578579
return left_final, right_final
580+
581+
582+
def idxmin(block: blocks.Block) -> blocks.Block:
    """Return a block holding, per value column, the index label of its minimum.

    Delegates to ``_idx_extrema`` in "min" mode.
    """
    return _idx_extrema(block, min_or_max="min")
584+
585+
586+
def idxmax(block: blocks.Block) -> blocks.Block:
    """Return a block holding, per value column, the index label of its maximum.

    Delegates to ``_idx_extrema`` in "max" mode.
    """
    return _idx_extrema(block, min_or_max="max")
588+
589+
590+
def _idx_extrema(
    block: blocks.Block, min_or_max: typing.Literal["min", "max"]
) -> blocks.Block:
    """Find, for every value column, the index label of its min or max.

    Each value column is ordered (ascending for "min", descending for "max",
    with the index column as a tie-breaker) and the first index value under
    that ordering is taken through a window operation. The per-column results
    are then stacked into a single-column block.

    Args:
        block: Input block; must have exactly one index column.
        min_or_max: Which extremum to locate.

    Returns:
        A single-column block, labelled with the original index name, holding
        the located index label for each original value column.

    Raises:
        NotImplementedError: If the block does not have exactly one index
            column (multi-index is not supported yet).
    """
    if len(block.index_columns) != 1:
        # TODO: Need support for tuple dtype
        # Name the operation that was actually requested (idxmin or idxmax).
        raise NotImplementedError(
            f"idx{min_or_max} not supported for multi-index. {constants.FEEDBACK_LINK}"
        )

    original_block = block
    # Both values are loop-invariant: a single index column is guaranteed by
    # the check above, and the direction depends only on the requested mode.
    idx_col = original_block.index_columns[0]
    direction = (
        ordering.OrderingDirection.ASC
        if min_or_max == "min"
        else ordering.OrderingDirection.DESC
    )

    result_cols = []
    for value_col in original_block.value_columns:
        # Order by the value first, then by the index so ties resolve
        # deterministically; the first index value is the answer.
        order_refs = [
            ordering.OrderingColumnReference(value_col, direction),
            ordering.OrderingColumnReference(idx_col),
        ]
        window_spec = core.WindowSpec(ordering=order_refs)
        block, result_col = block.apply_window_op(
            idx_col, agg_ops.first_op, window_spec
        )
        result_cols.append(result_col)

    block = block.select_columns(result_cols).with_column_labels(
        original_block.column_labels
    )
    # Stack the entire column axis to produce single-column result
    # Assumption: uniform dtype for stackability
    return block.aggregate_all_and_stack(
        agg_ops.AnyValueOp(), dtype=block.dtypes[0]
    ).with_column_labels([original_block.index.name])

bigframes/dataframe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,6 +1642,12 @@ def agg(
16421642

16431643
aggregate = agg
16441644

1645+
def idxmin(self) -> bigframes.series.Series:
    """Wrap the ``block_ops.idxmin`` transform of this frame's block in a Series."""
    result_block = block_ops.idxmin(self._block)
    return bigframes.series.Series(result_block)
1647+
1648+
def idxmax(self) -> bigframes.series.Series:
    """Wrap the ``block_ops.idxmax`` transform of this frame's block in a Series."""
    result_block = block_ops.idxmax(self._block)
    return bigframes.series.Series(result_block)
1650+
16451651
def describe(self) -> DataFrame:
16461652
df_numeric = self._drop_non_numeric(keep_bool=False)
16471653
if len(df_numeric.columns) == 0:

bigframes/series.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,34 @@ def argmin(self) -> int:
887887
scalars.Scalar, Series(block.select_column(row_nums)).iloc[0]
888888
)
889889

890+
def idxmax(self) -> blocks.Label:
    """Return the index label of the row that sorts first by descending value.

    Rows are ordered by the value column (descending), then by the index
    columns to break ties; the index of the first row is materialized with
    ``to_pandas()`` and its first element returned.
    """
    by_value = OrderingColumnReference(
        self._value_column, direction=OrderingDirection.DESC
    )
    by_index = [
        OrderingColumnReference(idx_col) for idx_col in self._block.index_columns
    ]
    first_row = self._block.order_by([by_value, *by_index]).slice(0, 1)
    return indexes.Index._from_block(first_row).to_pandas()[0]
904+
905+
def idxmin(self) -> blocks.Label:
    """Return the index label of the row that sorts first by value.

    Rows are ordered by the value column (default direction), then by the
    index columns to break ties; the index of the first row is materialized
    with ``to_pandas()`` and its first element returned.
    """
    by_value = OrderingColumnReference(self._value_column)
    by_index = [
        OrderingColumnReference(idx_col) for idx_col in self._block.index_columns
    ]
    first_row = self._block.order_by([by_value, *by_index]).slice(0, 1)
    return indexes.Index._from_block(first_row).to_pandas()[0]
917+
890918
@property
891919
def is_monotonic_increasing(self) -> bool:
892920
return typing.cast(

tests/system/small/test_dataframe.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,6 +1292,34 @@ def test_df_update(overwrite, filter_func):
12921292
pd.testing.assert_frame_equal(bf_df1.to_pandas(), pd_df1)
12931293

12941294

1295+
def test_df_idxmin():
    """DataFrame.idxmin agrees with pandas on data with a null and a tied column."""
    data = {"a": [1, 2, 3], "b": [7, None, 3], "c": [4, 4, 4]}
    pd_df = pd.DataFrame(data, index=["x", "y", "z"])
    bf_df = dataframe.DataFrame(pd_df)

    actual = bf_df.idxmin().to_pandas()
    expected = pd_df.idxmin()

    pd.testing.assert_series_equal(
        actual, expected, check_index_type=False, check_dtype=False
    )
1307+
1308+
1309+
def test_df_idxmax():
    """DataFrame.idxmax agrees with pandas on data with a null and a tied column."""
    data = {"a": [1, 2, 3], "b": [7, None, 3], "c": [4, 4, 4]}
    pd_df = pd.DataFrame(data, index=["x", "y", "z"])
    bf_df = dataframe.DataFrame(pd_df)

    actual = bf_df.idxmax().to_pandas()
    expected = pd_df.idxmax()

    pd.testing.assert_series_equal(
        actual, expected, check_index_type=False, check_dtype=False
    )
1321+
1322+
12951323
@pytest.mark.parametrize(
12961324
("join", "axis"),
12971325
[

tests/system/small/test_multiindex.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ def test_reset_multi_index(scalars_df_index, scalars_pandas_df_index):
4141
pandas.testing.assert_frame_equal(bf_result, pd_result)
4242

4343

44+
def test_series_multi_index_idxmin(scalars_df_index, scalars_pandas_df_index):
    """Series.idxmin on a two-level index returns the same label as pandas."""
    bf_series = scalars_df_index.set_index(["bool_col", "int64_too"])["float64_col"]
    pd_series = scalars_pandas_df_index.set_index(["bool_col", "int64_too"])[
        "float64_col"
    ]

    bf_result = bf_series.idxmin()
    pd_result = pd_series.idxmin()

    assert bf_result == pd_result
53+
54+
4455
def test_binop_series_series_matching_multi_indices(
4556
scalars_df_index, scalars_pandas_df_index
4657
):

tests/system/small/test_series.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2468,6 +2468,18 @@ def test_argmax(scalars_df_index, scalars_pandas_df_index):
24682468
assert bf_result == pd_result
24692469

24702470

2471+
def test_series_idxmin(scalars_df_index, scalars_pandas_df_index):
    """Series.idxmin on ``string_col`` returns the same label as pandas."""
    actual = scalars_df_index.string_col.idxmin()
    expected = scalars_pandas_df_index.string_col.idxmin()
    assert actual == expected
2475+
2476+
2477+
def test_series_idxmax(scalars_df_index, scalars_pandas_df_index):
    """Series.idxmax on ``int64_too`` returns the same label as pandas."""
    actual = scalars_df_index.int64_too.idxmax()
    expected = scalars_pandas_df_index.int64_too.idxmax()
    assert actual == expected
2481+
2482+
24712483
def test_getattr_attribute_error_when_pandas_has(scalars_df_index):
24722484
# asof is implemented in pandas but not in bigframes
24732485
with pytest.raises(AttributeError):

third_party/bigframes_vendored/pandas/core/frame.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,6 +1805,28 @@ def nsmallest(self, n: int, columns, keep: str = "first"):
18051805
"""
18061806
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
18071807

1808+
def idxmin(self):
    """
    Return the index label of the first occurrence of the minimum in each column.

    NA/null values are excluded.

    Returns:
        Series: Index labels of the per-column minima.
    """
    raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
1818+
1819+
def idxmax(self):
    """
    Return the index label of the first occurrence of the maximum in each column.

    NA/null values are excluded.

    Returns:
        Series: Index labels of the per-column maxima.
    """
    raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
1829+
18081830
def nunique(self):
18091831
"""
18101832
Count number of distinct elements in specified axis.

third_party/bigframes_vendored/pandas/core/series.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import numpy as np
99
from pandas._libs import lib
1010
from pandas._typing import Axis, FilePath, NaPosition, WriteBuffer
11-
import pandas.io.formats.format as fmt
1211

1312
from bigframes import constants
1413
from third_party.bigframes_vendored.pandas.core.generic import NDFrame
@@ -151,21 +150,6 @@ def to_string(
151150
str or None: String representation of Series if ``buf=None``,
152151
otherwise None.
153152
"""
154-
formatter = fmt.SeriesFormatter(
155-
self,
156-
name=name,
157-
length=length,
158-
header=header,
159-
index=index,
160-
dtype=dtype,
161-
na_rep=na_rep,
162-
float_format=float_format,
163-
min_rows=min_rows,
164-
max_rows=max_rows,
165-
)
166-
result = formatter.to_string()
167-
168-
# catch contract violations
169153
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
170154

171155
def to_markdown(
@@ -475,6 +459,30 @@ def duplicated(self, keep="first") -> Series:
475459
"""
476460
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
477461

462+
def idxmin(self) -> Hashable:
    """
    Return the row label of the minimum value.

    When several rows share the minimum value, the label of the first such
    row is returned.

    Returns:
        Hashable: Row label of the minimum value.
    """
    raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
473+
474+
def idxmax(self) -> Hashable:
    """
    Return the row label of the maximum value.

    When several rows share the maximum value, the label of the first such
    row is returned.

    Returns:
        Hashable: Row label of the maximum value.
    """
    raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
485+
478486
def round(self, decimals: int = 0) -> Series:
479487
"""
480488
Round each value in a Series to the given number of decimals.

0 commit comments

Comments
 (0)