Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add getitem to array protocol #8406

Merged
merged 12 commits into from
Dec 12, 2023
38 changes: 35 additions & 3 deletions xarray/namedarray/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)


_dtype = np.dtype
_DType = TypeVar("_DType", bound=np.dtype[Any])
_DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any])
# A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic`
Expand Down Expand Up @@ -55,9 +55,15 @@ def dtype(self) -> _DType_co:
_Dims = tuple[_Dim, ...]

_DimsLike = Union[str, Iterable[_Dim]]
_AttrsLike = Union[Mapping[Any, Any], None]

_dtype = np.dtype
# https://data-apis.org/array-api/latest/API_specification/indexing.html
# TODO: np.array_api doesn't allow None for some reason, maybe they're
# recommending to use expand_dims?
_IndexKey = Union[int, slice, "ellipsis"]
_IndexKeys = tuple[Union[_IndexKey], ...]
_IndexKeyLike = Union[_IndexKey, _IndexKeys]

_AttrsLike = Union[Mapping[Any, Any], None]


class _SupportsReal(Protocol[_T_co]):
Expand Down Expand Up @@ -99,6 +105,25 @@ class _arrayfunction(
Corresponds to np.ndarray.
"""

@overload
def __getitem__(
self, key: _arrayfunction[Any, Any] | tuple[_arrayfunction[Any, Any], ...], /
) -> _arrayfunction[Any, _DType_co]:
...

@overload
def __getitem__(self, key: _IndexKeyLike, /) -> Any:
...

def __getitem__(
self,
key: _IndexKeyLike
| _arrayfunction[Any, Any]
| tuple[_arrayfunction[Any, Any], ...],
/,
) -> _arrayfunction[Any, _DType_co] | Any:
...

@overload
def __array__(self, dtype: None = ..., /) -> np.ndarray[Any, _DType_co]:
...
Expand Down Expand Up @@ -151,6 +176,13 @@ class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType
Corresponds to np.ndarray.
"""

def __getitem__(
self,
key: _IndexKeyLike | Any, # TODO: Any should be _arrayapi
Copy link
Contributor Author

@Illviljan Illviljan Nov 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about this one. it should be the same class or subclass but with a integer dtype.
typing.Self is close, but then I guess it will inherit the dtype from the original class which is not correct.

/,
) -> _arrayapi[Any, Any]:
...

def __array_namespace__(self) -> ModuleType:
...

Expand Down
14 changes: 14 additions & 0 deletions xarray/tests/test_namedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
_AttrsLike,
_DimsLike,
_DType,
_IndexKeyLike,
_Shape,
duckarray,
)
Expand Down Expand Up @@ -53,6 +54,19 @@ class CustomArrayIndexable(
ExplicitlyIndexed,
Generic[_ShapeType_co, _DType_co],
):
def __getitem__(
self, key: _IndexKeyLike | CustomArrayIndexable[Any, Any], /
) -> CustomArrayIndexable[Any, _DType_co]:
if isinstance(key, CustomArrayIndexable):
if isinstance(key.array, type(self.array)):
# TODO: key.array is duckarray here, can it be narrowed down further?
# an _arrayapi cannot be used on a _arrayfunction for example.
return type(self)(array=self.array[key.array]) # type: ignore[index]
else:
raise TypeError("key must have the same array type as self")
else:
return type(self)(array=self.array[key])

def __array_namespace__(self) -> ModuleType:
return np

Expand Down
Loading