Skip to content

List accessor #55777

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
initial implementation, no documentation
  • Loading branch information
Rohan Jain committed Oct 31, 2023
commit a7f4a848458b9867b85bfb81cb86c6b6b578e11e
7 changes: 5 additions & 2 deletions pandas/core/arrays/arrow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pandas.core.arrays.arrow.accessors import StructAccessor
from pandas.core.arrays.arrow.accessors import (
ListAccessor,
StructAccessor,
)
from pandas.core.arrays.arrow.array import ArrowExtensionArray

__all__ = ["ArrowExtensionArray", "StructAccessor"]
__all__ = ["ArrowExtensionArray", "StructAccessor", "ListAccessor"]
111 changes: 95 additions & 16 deletions pandas/core/arrays/arrow/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from __future__ import annotations

from abc import (
ABCMeta,
abstractmethod,
)
from typing import TYPE_CHECKING

from pandas.compat import pa_version_under10p1
Expand All @@ -19,7 +23,90 @@
)


class StructAccessor:
class ArrowAccessor(metaclass=ABCMeta):
    """
    Abstract base for Series accessors that operate on pyarrow-backed data.

    Concrete accessors provide ``_validation_msg`` (a template containing a
    ``{dtype}`` placeholder) and ``_is_valid_pyarrow_dtype``. The wrapped
    object is validated against both at construction time.
    """

    def __init__(self, data) -> None:
        self._data = data
        self._validate(data)

    @abstractmethod
    def _is_valid_pyarrow_dtype(self, pyarrow_dtype: pa.DataType) -> bool:
        """Return True if the accessor supports ``pyarrow_dtype``."""

    @property
    @abstractmethod
    def _validation_msg(self) -> str:
        """Message template, formatted with ``dtype=...`` when validation fails."""

    def _validate(self, data):
        """
        Check that ``data.dtype`` is an ``ArrowDtype`` accepted by the accessor.

        Raises
        ------
        AttributeError
            If the dtype is not an ``ArrowDtype`` or its pyarrow type is
            rejected by ``_is_valid_pyarrow_dtype``.
        """
        dtype = data.dtype
        supported = isinstance(dtype, ArrowDtype) and self._is_valid_pyarrow_dtype(
            dtype.pyarrow_dtype
        )
        if not supported:
            # AttributeError (rather than TypeError/ValueError) so that inspect
            # can handle Series for which the accessor does not apply.
            raise AttributeError(self._validation_msg.format(dtype=dtype))

    @property
    def _pa_array(self) -> pa.Array:
        """The pyarrow array stored on the wrapped object's ``.array``."""
        return self._data.array._pa_array


class ListAccessor(ArrowAccessor):
    """
    Accessor object for list data properties of the Series values.

    Parameters
    ----------
    data : Series
        Series whose dtype is an ``ArrowDtype`` wrapping a pyarrow list,
        fixed-size list or large list type.
    """

    _validation_msg = (
        "Can only use the '.list' accessor with 'list[pyarrow]' dtype, not {dtype}."
    )

    def __init__(self, data=None) -> None:
        super().__init__(data)

    def _is_valid_pyarrow_dtype(self, pyarrow_dtype: pa.DataType) -> bool:
        """Return True for pyarrow list, fixed-size list and large list types."""
        return (
            pa.types.is_list(pyarrow_dtype)
            or pa.types.is_fixed_size_list(pyarrow_dtype)
            or pa.types.is_large_list(pyarrow_dtype)
        )

    def len(self) -> Series:
        """
        Return the length of each list in the Series.

        Returns
        -------
        Series
            One length per row, carrying the index and name of the original
            Series.
        """
        from pandas import Series

        value_lengths = pc.list_value_length(self._pa_array)
        # The result has one entry per row, so it keeps the caller's index
        # and name instead of silently resetting to a RangeIndex.
        return Series(
            value_lengths,
            dtype=ArrowDtype(value_lengths.type),
            index=self._data.index,
            name=self._data.name,
        )

    def __getitem__(self, key: int | slice) -> Series:
        """
        Index or slice every list in the Series.

        Parameters
        ----------
        key : int or slice
            Position, or range of positions, to take from each list.

        Returns
        -------
        Series
            One element (for an int key) or one sub-list (for a slice) per
            row, carrying the index and name of the original Series.

        Raises
        ------
        ValueError
            If ``key`` is neither an int nor a slice.
        """
        from pandas import Series

        if isinstance(key, int):
            # TODO: Support negative key but pyarrow does not allow
            # element index to be an array.
            element = pc.list_element(self._pa_array, key)
            return Series(
                element,
                dtype=ArrowDtype(element.type),
                index=self._data.index,
                name=self._data.name,
            )
        if isinstance(key, slice):
            # TODO: Support negative start/stop/step, ideally this would be added
            # upstream in pyarrow.
            start, stop, step = key.start, key.stop, key.step
            if start is None:
                # TODO: When adding negative step support
                # this should be set to last element of array
                # when step is negative.
                start = 0
            if step is None:
                step = 1
            sliced = pc.list_slice(self._pa_array, start, stop, step)
            return Series(
                sliced,
                dtype=ArrowDtype(sliced.type),
                index=self._data.index,
                name=self._data.name,
            )
        raise ValueError(f"key must be an int or slice, got {type(key).__name__}")

    def flatten(self) -> Series:
        """
        Flatten the list values into a single Series of elements.

        Returns
        -------
        Series
            The concatenated list elements, with a default index and the
            name of the original Series.
        """
        from pandas import Series

        flattened = pc.list_flatten(self._pa_array)
        # The flattened length generally differs from the number of rows, so
        # the original index cannot be reused as-is; only the name is kept.
        return Series(
            flattened,
            dtype=ArrowDtype(flattened.type),
            name=self._data.name,
        )


class StructAccessor(ArrowAccessor):
"""
Accessor object for structured data properties of the Series values.

Expand All @@ -34,18 +121,10 @@ class StructAccessor:
)

def __init__(self, data=None) -> None:
self._parent = data
self._validate(data)

def _validate(self, data):
dtype = data.dtype
if not isinstance(dtype, ArrowDtype):
# Raise AttributeError so that inspect can handle non-struct Series.
raise AttributeError(self._validation_msg.format(dtype=dtype))
super().__init__(data)

if not pa.types.is_struct(dtype.pyarrow_dtype):
# Raise AttributeError so that inspect can handle non-struct Series.
raise AttributeError(self._validation_msg.format(dtype=dtype))
def _is_valid_pyarrow_dtype(self, pyarrow_dtype: pa.DataType) -> bool:
return pa.types.is_struct(pyarrow_dtype)

@property
def dtypes(self) -> Series:
Expand Down Expand Up @@ -80,7 +159,7 @@ def dtypes(self) -> Series:
Series,
)

pa_type = self._parent.dtype.pyarrow_dtype
pa_type = self._data.dtype.pyarrow_dtype
types = [ArrowDtype(struct.type) for struct in pa_type]
names = [struct.name for struct in pa_type]
return Series(types, index=Index(names))
Expand Down Expand Up @@ -135,7 +214,7 @@ def field(self, name_or_index: str | int) -> Series:
"""
from pandas import Series

pa_arr = self._parent.array._pa_array
pa_arr = self._data.array._pa_array
if isinstance(name_or_index, int):
index = name_or_index
elif isinstance(name_or_index, str):
Expand All @@ -151,7 +230,7 @@ def field(self, name_or_index: str | int) -> Series:
return Series(
field_arr,
dtype=ArrowDtype(field_arr.type),
index=self._parent.index,
index=self._data.index,
name=pa_field.name,
)

Expand Down Expand Up @@ -190,7 +269,7 @@ def explode(self) -> DataFrame:
"""
from pandas import concat

pa_type = self._parent.dtype.pyarrow_dtype
pa_type = self._data.dtype.pyarrow_dtype
return concat(
[self.field(i) for i in range(pa_type.num_fields)], axis="columns"
)
6 changes: 5 additions & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@
from pandas.core.accessor import CachedAccessor
from pandas.core.apply import SeriesApply
from pandas.core.arrays import ExtensionArray
from pandas.core.arrays.arrow import StructAccessor
from pandas.core.arrays.arrow import (
ListAccessor,
StructAccessor,
)
from pandas.core.arrays.categorical import CategoricalAccessor
from pandas.core.arrays.sparse import SparseAccessor
from pandas.core.arrays.string_ import StringDtype
Expand Down Expand Up @@ -5891,6 +5894,7 @@ def to_period(self, freq: str | None = None, copy: bool | None = None) -> Series
plot = CachedAccessor("plot", pandas.plotting.PlotAccessor)
sparse = CachedAccessor("sparse", SparseAccessor)
struct = CachedAccessor("struct", StructAccessor)
list = CachedAccessor("list", ListAccessor)

# ----------------------------------------------------------------------
# Add plotting methods to Series
Expand Down
89 changes: 89 additions & 0 deletions pandas/tests/series/accessors/test_list_accessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import re

import pytest

from pandas import (
ArrowDtype,
Series,
)
import pandas._testing as tm

pa = pytest.importorskip("pyarrow")


@pytest.mark.parametrize(
    "list_dtype",
    (
        pa.list_(pa.int64()),
        pa.list_(pa.int64(), list_size=3),
        pa.large_list(pa.int64()),
    ),
)
def test_list_getitem(list_dtype: pa.DataType):
    """An integer key selects that position from every list; a null row stays null."""
    rows = [[1, 2, 3], [4, None, 5], None]
    ser = Series(rows, dtype=ArrowDtype(list_dtype))

    result = ser.list[1]

    tm.assert_series_equal(result, Series([2, None, None], dtype="int64[pyarrow]"))


def test_list_getitem_slice():
    """A slice with only a start drops the leading element of every list."""
    list_dtype = ArrowDtype(pa.list_(pa.int64()))
    ser = Series([[1, 2, 3], [4, None, 5], None], dtype=list_dtype)

    result = ser.list[1:None:None]

    tm.assert_series_equal(
        result,
        Series([[2, 3], [None, 5], None], dtype=ArrowDtype(pa.list_(pa.int64()))),
    )


def test_list_len():
    """``.list.len()`` gives int32 per-row lengths and null for a null row."""
    list_dtype = ArrowDtype(pa.list_(pa.int64()))
    ser = Series([[1, 2, 3], [4, None], None], dtype=list_dtype)

    result = ser.list.len()

    tm.assert_series_equal(result, Series([3, 2, None], dtype=ArrowDtype(pa.int32())))


def test_list_flatten():
    """``.list.flatten()`` concatenates the list elements and skips null rows."""
    list_dtype = ArrowDtype(pa.list_(pa.int64()))
    ser = Series([[1, 2, 3], [4, None], None], dtype=list_dtype)

    result = ser.list.flatten()

    tm.assert_series_equal(
        result, Series([1, 2, 3, 4, None], dtype=ArrowDtype(pa.int64()))
    )


def test_list_getitem_slice_invalid():
    """A zero slice step is rejected with pyarrow's ``ArrowInvalid``."""
    list_dtype = ArrowDtype(pa.list_(pa.int64()))
    ser = Series([[1, 2, 3], [4, None, 5], None], dtype=list_dtype)

    expected_msg = re.escape("`step` must be >= 1")
    with pytest.raises(pa.lib.ArrowInvalid, match=expected_msg):
        ser.list[1:None:0]


@pytest.mark.parametrize(
    "list_dtype",
    (
        pa.list_(pa.int64()),
        pa.list_(pa.int64(), list_size=3),
        pa.large_list(pa.int64()),
    ),
)
def test_list_getitem_invalid_index(list_dtype: pa.DataType):
    """Out-of-bounds integer keys and non-int/slice keys raise."""
    rows = [[1, 2, 3], [4, None, 5], None]
    ser = Series(rows, dtype=ArrowDtype(list_dtype))

    # Negative and too-large positions are rejected by pyarrow.
    with pytest.raises(pa.lib.ArrowInvalid, match="Index -1 is out of bounds"):
        ser.list[-1]
    with pytest.raises(pa.lib.ArrowInvalid, match="Index 5 is out of bounds"):
        ser.list[5]

    # Any other key type is rejected by the accessor.
    with pytest.raises(ValueError, match="key must be an int or slice, got str"):
        ser.list["abc"]
ser.list["abc"]
24 changes: 4 additions & 20 deletions pandas/tests/series/accessors/test_struct_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,12 @@
pa = pytest.importorskip("pyarrow")


def test_struct_accessor_dtypes():
def test_list_getitem():
ser = Series(
[],
dtype=ArrowDtype(
pa.struct(
[
("int_col", pa.int64()),
("string_col", pa.string()),
(
"struct_col",
pa.struct(
[
("int_col", pa.int64()),
("float_col", pa.float64()),
]
),
),
]
)
),
[[1, 2, 3], [4, None], None],
dtype=ArrowDtype(pa.list_(pa.int64())),
)
actual = ser.struct.dtypes
actual = ser.list[1]
expected = Series(
[
ArrowDtype(pa.int64()),
Expand Down