From 1a1edfdf234307b6a81dc815ebd7698c0182b497 Mon Sep 17 00:00:00 2001 From: Rubtsowa Date: Thu, 25 Nov 2021 11:53:21 +0300 Subject: [PATCH] FEAT-#3303: Add __getitem__ for Resampler (#3613) Co-authored-by: Alexey Prutskov Signed-off-by: Maria Rubtsova --- modin/pandas/base.py | 36 ++++++++++++++++++++ modin/pandas/test/dataframe/test_default.py | 37 +++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/modin/pandas/base.py b/modin/pandas/base.py index c7f6a0abeb5..f556d1ae3f4 100644 --- a/modin/pandas/base.py +++ b/modin/pandas/base.py @@ -3127,6 +3127,42 @@ def __init__( ] self.__groups = self.__get_groups(*self.resample_args) + def __getitem__(self, key): + """ + Get ``Resampler`` based on `key` columns of original dataframe. + + Parameters + ---------- + key : str or list + String or list of selections. + + Returns + ------- + modin.pandas.BasePandasDataset + New ``Resampler`` based on `key` columns subset + of the original dataframe. + """ + + def _get_new_resampler(key): + subset = self._dataframe[key] + resampler = type(self)(subset, *self.resample_args) + return resampler + + from .series import Series + + if isinstance( + key, (list, tuple, Series, pandas.Series, pandas.Index, np.ndarray) + ): + if len(self._dataframe.columns.intersection(key)) != len(key): + missed_keys = list(set(key).difference(self._dataframe.columns)) + raise KeyError(f"Columns not found: {str(sorted(missed_keys))[1:-1]}") + return _get_new_resampler(list(key)) + + if key not in self._dataframe: + raise KeyError(f"Column not found: {key}") + + return _get_new_resampler(key) + def __get_groups( self, rule, diff --git a/modin/pandas/test/dataframe/test_default.py b/modin/pandas/test/dataframe/test_default.py index 08478611f4a..d4d3c5afddd 100644 --- a/modin/pandas/test/dataframe/test_default.py +++ b/modin/pandas/test/dataframe/test_default.py @@ -762,6 +762,43 @@ def test_resample_specific(rule, closed, label, on, level): ) +@pytest.mark.parametrize( + "columns", + [ + "volume", + "date", + ["volume"], + ["price", "date"], + ("volume",), + pandas.Series(["volume"]), + pandas.Index(["volume"]), + ["volume", "volume", "volume"], + ["volume", "price", "date"], + ], + ids=[ + "column", + "missed_column", + "list", + "missed_column", + "tuple", + "series", + "index", + "duplicate_column", + "missed_columns", + ], +) +def test_resample_getitem(columns): + index = pandas.date_range("1/1/2013", periods=9, freq="T") + data = { + "price": range(9), + "volume": range(10, 19), + } + eval_general( + *create_test_dfs(data, index=index), + lambda df: df.resample("3T")[columns].mean(), + ) + + @pytest.mark.parametrize("data", test_data_values, ids=test_data_keys) @pytest.mark.parametrize("index", ["default", "ndarray", "has_duplicates"]) @pytest.mark.parametrize("axis", [0, 1])