
fix: change return type of Series.loc[scalar] #40


Merged · 10 commits · Oct 4, 2023
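In short (a minimal sketch of the new behavior, not part of the PR; the table path is an assumption, while the column names and the `-2345` key mirror the system tests below): looking up a unique scalar label now executes the query and returns the value itself, while list keys keep returning deferred BigFrames objects.

```python
import bigframes.pandas as bpd

# Hypothetical table; "rowindex" / "date_col" follow the test fixtures.
df = bpd.read_gbq("my-project.my_dataset.scalars").set_index("rowindex")

# Before this change: a one-row bigframes Series that still had to be downloaded.
# After this change: the value itself, computed eagerly.
value = df["date_col"].loc[-2345]       # -> scalar

# A single unique row of a DataFrame now comes back as a pandas Series.
row = df.loc[-2345]                     # -> pandas.Series

# Wrapping the key in a list keeps the result as a deferred BigFrames object.
deferred = df["date_col"].loc[[-2345]]  # -> bigframes Series
```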
81 changes: 53 additions & 28 deletions bigframes/core/indexers.py
@@ -15,7 +15,7 @@
from __future__ import annotations

import typing
from typing import Tuple
from typing import Tuple, Union

import ibis
import pandas as pd
@@ -29,20 +29,19 @@
import bigframes.series

if typing.TYPE_CHECKING:
LocSingleKey = typing.Union[bigframes.series.Series, indexes.Index, slice]
LocSingleKey = Union[
bigframes.series.Series, indexes.Index, slice, bigframes.core.scalar.Scalar
]


class LocSeriesIndexer:
def __init__(self, series: bigframes.series.Series):
self._series = series

def __getitem__(self, key) -> bigframes.series.Series:
"""
Only indexing by a boolean bigframes.series.Series or list of index entries is currently supported
"""
return typing.cast(
bigframes.series.Series, _loc_getitem_series_or_dataframe(self._series, key)
)
def __getitem__(
self, key
) -> Union[bigframes.core.scalar.Scalar, bigframes.series.Series]:
return _loc_getitem_series_or_dataframe(self._series, key)

def __setitem__(self, key, value) -> None:
# TODO(swast): support MultiIndex
@@ -84,7 +83,7 @@ def __init__(self, series: bigframes.series.Series):

def __getitem__(
self, key
) -> bigframes.core.scalar.Scalar | bigframes.series.Series:
) -> Union[bigframes.core.scalar.Scalar, bigframes.series.Series]:
"""
Index series using integer offsets. Currently supports index by key type:

@@ -103,13 +102,17 @@ def __init__(self, dataframe: bigframes.dataframe.DataFrame):
self._dataframe = dataframe

@typing.overload
def __getitem__(self, key: LocSingleKey) -> bigframes.dataframe.DataFrame:
def __getitem__(
self, key: LocSingleKey
) -> Union[bigframes.dataframe.DataFrame, pd.Series]:
...

# Technically this is wrong since we can have duplicate column labels, but
# this is expected to be rare.
@typing.overload
def __getitem__(self, key: Tuple[LocSingleKey, str]) -> bigframes.series.Series:
def __getitem__(
self, key: Tuple[LocSingleKey, str]
) -> Union[bigframes.series.Series, bigframes.core.scalar.Scalar]:
...

def __getitem__(self, key):
@@ -173,7 +176,7 @@ class ILocDataFrameIndexer:
def __init__(self, dataframe: bigframes.dataframe.DataFrame):
self._dataframe = dataframe

def __getitem__(self, key) -> bigframes.dataframe.DataFrame | pd.Series:
def __getitem__(self, key) -> Union[bigframes.dataframe.DataFrame, pd.Series]:
"""
Index dataframe using integer offsets. Currently supports index by key type:

@@ -188,21 +191,26 @@ def __getitem__(self, key) -> bigframes.dataframe.DataFrame | pd.Series:
@typing.overload
def _loc_getitem_series_or_dataframe(
series_or_dataframe: bigframes.series.Series, key
) -> bigframes.series.Series:
) -> Union[bigframes.core.scalar.Scalar, bigframes.series.Series]:
...


@typing.overload
def _loc_getitem_series_or_dataframe(
series_or_dataframe: bigframes.dataframe.DataFrame, key
) -> bigframes.dataframe.DataFrame:
) -> Union[bigframes.dataframe.DataFrame, pd.Series]:
...


def _loc_getitem_series_or_dataframe(
series_or_dataframe: bigframes.dataframe.DataFrame | bigframes.series.Series,
series_or_dataframe: Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
key: LocSingleKey,
) -> bigframes.dataframe.DataFrame | bigframes.series.Series:
) -> Union[
bigframes.dataframe.DataFrame,
bigframes.series.Series,
pd.Series,
bigframes.core.scalar.Scalar,
]:
if isinstance(key, bigframes.series.Series) and key.dtype == "boolean":
return series_or_dataframe[key]
elif isinstance(key, bigframes.series.Series):
@@ -222,7 +230,7 @@ def _loc_getitem_series_or_dataframe(
# TODO(henryjsolberg): support MultiIndex
if len(key) == 0: # type: ignore
return typing.cast(
typing.Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
series_or_dataframe.iloc[0:0],
)

@@ -258,11 +266,22 @@ def _loc_getitem_series_or_dataframe(
)
keys_df = keys_df.set_index(index_name, drop=True)
keys_df.index.name = None
return _perform_loc_list_join(series_or_dataframe, keys_df)
result = _perform_loc_list_join(series_or_dataframe, keys_df)
pandas_result = result.to_pandas()
# Although loc[scalar_key] can match multiple rows when scalar_key is not
# unique, we download the result here and, when the key is unique, return
# the computed individual value (as a scalar or pandas Series), since
# unique index keys are expected to be the more common case.
# loc[[scalar_key]] can still be used to retrieve a one-item DataFrame or Series.
if len(pandas_result) == 1:
return pandas_result.iloc[0]
# when the key is not unique, we return a bigframes data type
# as usual for methods that return dataframes/series
return result
else:
raise TypeError(
"Invalid argument type. loc currently only supports indexing with a "
"boolean bigframes Series, a list of index entries or a single index entry. "
"Invalid argument type. Expected bigframes.Series, bigframes.Index, "
"list, : (empty slice), or scalar. "
f"{constants.FEEDBACK_LINK}"
)

@@ -284,9 +303,9 @@ def _perform_loc_list_join(


def _perform_loc_list_join(
series_or_dataframe: bigframes.dataframe.DataFrame | bigframes.series.Series,
series_or_dataframe: Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
keys_df: bigframes.dataframe.DataFrame,
) -> bigframes.series.Series | bigframes.dataframe.DataFrame:
) -> Union[bigframes.series.Series, bigframes.dataframe.DataFrame]:
# right join based on the old index so that the matching rows from the user's
# original dataframe will be duplicated and reordered appropriately
original_index_names = series_or_dataframe.index.names
@@ -309,20 +328,26 @@ def _perform_loc_list_join(
@typing.overload
def _iloc_getitem_series_or_dataframe(
series_or_dataframe: bigframes.series.Series, key
) -> bigframes.series.Series | bigframes.core.scalar.Scalar:
) -> Union[bigframes.series.Series, bigframes.core.scalar.Scalar]:
...


@typing.overload
def _iloc_getitem_series_or_dataframe(
series_or_dataframe: bigframes.dataframe.DataFrame, key
) -> bigframes.dataframe.DataFrame | pd.Series:
) -> Union[bigframes.dataframe.DataFrame, pd.Series]:
...


def _iloc_getitem_series_or_dataframe(
series_or_dataframe: bigframes.dataframe.DataFrame | bigframes.series.Series, key
) -> bigframes.dataframe.DataFrame | bigframes.series.Series | bigframes.core.scalar.Scalar | pd.Series:
series_or_dataframe: Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
key,
) -> Union[
bigframes.dataframe.DataFrame,
bigframes.series.Series,
bigframes.core.scalar.Scalar,
pd.Series,
]:
if isinstance(key, int):
internal_slice_result = series_or_dataframe._slice(key, key + 1, 1)
result_pd_df = internal_slice_result.to_pandas()
@@ -334,7 +359,7 @@ def _iloc_getitem_series_or_dataframe(
elif pd.api.types.is_list_like(key):
if len(key) == 0:
return typing.cast(
typing.Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
Union[bigframes.dataframe.DataFrame, bigframes.series.Series],
series_or_dataframe.iloc[0:0],
)
df = series_or_dataframe
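Since the scalar-key path above downloads the result to decide between a scalar and a deferred object, a duplicated label still comes back as a BigFrames Series. A small sketch of both paths (not part of the PR), assuming an in-memory frame loaded with `read_pandas`:

```python
import bigframes.pandas as bpd
import pandas as pd

# Hypothetical data: the label "b" appears twice in the index.
pdf = pd.DataFrame({"x": [1, 2, 3]}, index=["a", "b", "b"])
df = bpd.read_pandas(pdf)

df["x"].loc["a"]    # unique key    -> scalar 1 (computed eagerly)
df["x"].loc["b"]    # duplicate key -> two-row bigframes Series (still deferred)
df["x"].loc[["a"]]  # list key      -> one-row bigframes Series, never a scalar
```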
8 changes: 5 additions & 3 deletions bigframes/ml/model_selection.py
@@ -17,6 +17,7 @@
https://scikit-learn.org/stable/modules/classes.html#module-sklearn.model_selection."""


import typing
from typing import List, Union

from bigframes.ml import utils
@@ -79,9 +80,10 @@ def train_test_split(
train_index = split_dfs[0].index
test_index = split_dfs[1].index

split_dfs += [
df.loc[index] for df in dfs[1:] for index in (train_index, test_index)
]
split_dfs += typing.cast(
List[bpd.DataFrame],
[df.loc[index] for df in dfs[1:] for index in (train_index, test_index)],
)

# convert back to Series.
results: List[Union[bpd.DataFrame, bpd.Series]] = []
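The `typing.cast` added to `train_test_split` only narrows the static type: `DataFrame.loc` is now annotated as a Union, but an `Index` key still yields a DataFrame at runtime, so the cast keeps mypy satisfied without changing behavior. A sketch of the same pattern (the helper name is illustrative, not from the PR):

```python
import typing
from typing import List

import bigframes.pandas as bpd


def select_rows(dfs: List[bpd.DataFrame], index) -> List[bpd.DataFrame]:
    # loc is typed as a Union (DataFrame | pandas Series | scalar), but an
    # Index key always produces a DataFrame at runtime; cast narrows the
    # static type only, with no runtime conversion or check.
    return typing.cast(List[bpd.DataFrame], [df.loc[index] for df in dfs])
```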
2 changes: 1 addition & 1 deletion tests/system/small/test_dataframe.py
@@ -2081,7 +2081,7 @@ def test_loc_single_index_no_duplicate(scalars_df_index, scalars_pandas_df_index
bf_result = scalars_df_index.loc[index]
pd_result = scalars_pandas_df_index.loc[index]
pd.testing.assert_series_equal(
bf_result.to_pandas().iloc[0, :],
bf_result,
pd_result,
)

6 changes: 3 additions & 3 deletions tests/system/small/test_series.py
@@ -118,7 +118,7 @@ def test_series_get_with_default_index(scalars_dfs):
scalars_df, scalars_pandas_df = scalars_dfs
bf_result = scalars_df[col_name].get(key)
pd_result = scalars_pandas_df[col_name].get(key)
assert bf_result.to_pandas().iloc[0] == pd_result
assert bf_result == pd_result


@pytest.mark.parametrize(
@@ -157,7 +157,7 @@ def test_series___getitem___with_default_index(scalars_dfs):
scalars_df, scalars_pandas_df = scalars_dfs
bf_result = scalars_df[col_name][key]
pd_result = scalars_pandas_df[col_name][key]
assert bf_result.to_pandas().iloc[0] == pd_result
assert bf_result == pd_result


@pytest.mark.parametrize(
@@ -2652,7 +2652,7 @@ def test_loc_single_index_no_duplicate(scalars_df_index, scalars_pandas_df_index
index = -2345
bf_result = scalars_df_index.date_col.loc[index]
pd_result = scalars_pandas_df_index.date_col.loc[index]
assert bf_result.to_pandas().iloc[0] == pd_result
assert bf_result == pd_result


def test_series_bool_interpretation_error(scalars_df_index):