Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test array api protocol #7902

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions xarray/_array_api/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Vendored from https://github.com/data-apis/array-api/pull/589
"""
Types for type annotations used in the array API standard.

The type variables should be replaced with the actual types for a given
library, e.g., for NumPy TypeVar('array') would be replaced with ndarray.
"""
from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
List,
Literal,
Optional,
Protocol,
Tuple,
TypeVar,
Union,
)

if TYPE_CHECKING:
from xarray._array_api.array_object import Array
from xarray._array_api.data_types import DType

array = TypeVar("array", bound="Array")
device = TypeVar("device")
dtype = TypeVar("dtype", bound="DType")
SupportsDLPack = TypeVar("SupportsDLPack")
SupportsBufferProtocol = TypeVar("SupportsBufferProtocol")
PyCapsule = TypeVar("PyCapsule")
# ellipsis cannot actually be imported from anywhere, so include a dummy here
# to keep pyflakes happy. https://github.com/python/typeshed/issues/3556
ellipsis = TypeVar("ellipsis")


@dataclass
class finfo_object:
"""Dataclass returned by `finfo`."""

bits: int
eps: float
max: float
min: float
smallest_normal: float


@dataclass
class iinfo_object:
"""Dataclass returned by `iinfo`."""

bits: int
max: int
min: int


_T_co = TypeVar("_T_co", covariant=True)


class NestedSequence(Protocol[_T_co]):
def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]:
...

def __len__(self, /) -> int:
...


__all__ = [
"Any",
"List",
"Literal",
"NestedSequence",
"Optional",
"PyCapsule",
"SupportsBufferProtocol",
"SupportsDLPack",
"Tuple",
"Union",
"Sequence",
"array",
"device",
"dtype",
"ellipsis",
"finfo_object",
"iinfo_object",
"Enum",
]
Loading