Skip to content

✨: Add stubs for major protocols #27

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/array_api_typing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
"""Static typing support for the array API standard."""

__all__ = (
"Array",
"ArrayNamespace",
"DType",
"Device",
"HasArrayNamespace",
"__version__",
"__version_tuple__",
"signature_types",
)

from ._namespace import HasArrayNamespace
from . import signature_types
from ._array import Array
from ._misc_objects import Device, DType
from ._namespace import ArrayNamespace, HasArrayNamespace
from ._version import version as __version__, version_tuple as __version_tuple__
12 changes: 12 additions & 0 deletions src/array_api_typing/_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Static typing support for array API arrays."""

from typing import Protocol

from ._namespace import HasArrayNamespace


class Array(HasArrayNamespace, Protocol):
"""An Array API array of homogenously-typed numbers."""

# TODO(https://github.com/data-apis/array-api-typing/issues/23): Populate this
# protocol with methods defined by the Array API specification.
6 changes: 6 additions & 0 deletions src/array_api_typing/_misc_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Static typing support for miscellaneous objects in the array API."""

from typing import TypeAlias

Device: TypeAlias = object # The device on which an Array API array is stored.
DType: TypeAlias = object # The type of the numbers contained in an Array API array."""
49 changes: 45 additions & 4 deletions src/array_api_typing/_namespace.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,54 @@
__all__ = ("HasArrayNamespace",)

from types import ModuleType
from typing import Protocol, final
from __future__ import annotations

from typing import TYPE_CHECKING, Protocol
from typing_extensions import TypeVar

T = TypeVar("T", bound=object, default=ModuleType) # PEP 696 default
if TYPE_CHECKING:
# This condition exists to prevent a circular import: _array imports _namespace for
# HasArrayNamespace. Therefore, _namespace cannot import _array except when
# type-checking. The type variable depends on Array, so we create a dummy type
# variable without the same bounds and default for this case. In Python 3.13, this
# is no longer be necessary.
from typing_extensions import Buffer

from ._array import Array
from ._misc_objects import Device, DType
from .signature_types import NestedSequence

A = TypeVar("A", bound=Array, default=Array) # PEP 696 default
else:
A = TypeVar("A")


class ArrayNamespace(Protocol[A]):
"""An Array API namespace."""

def asarray(
self,
obj: Array | complex | NestedSequence[complex] | Buffer,
/,
*,
dtype: DType | None = None,
device: Device | None = None,
copy: bool | None = None,
) -> A: ...

def astype(
self,
x: A,
dtype: DType,
/,
*,
copy: bool = True,
device: Device | None = None,
) -> A: ...


T = TypeVar("T", bound=ArrayNamespace, default=ArrayNamespace) # PEP 696 default


@final
class HasArrayNamespace(Protocol[T]): # type: ignore[misc] # see python/mypy#17288
"""Protocol for classes that have an `__array_namespace__` method.

Expand Down
7 changes: 7 additions & 0 deletions src/array_api_typing/signature_types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Types that appear in public function signatures."""

__all__ = [
"NestedSequence",
]

from ._signature_types import NestedSequence
77 changes: 77 additions & 0 deletions src/array_api_typing/signature_types/_signature_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Protocol, TypeVar, runtime_checkable

if TYPE_CHECKING:
from collections.abc import Iterator

_T_co = TypeVar("_T_co", covariant=True)


@runtime_checkable
class NestedSequence(Protocol[_T_co]):
"""A protocol for representing nested sequences.

Warning:
-------
`NestedSequence` currently does not work in combination with type variables,
*e.g.* ``def func(a: NestedSequnce[T]) -> T: ...``.

See Also:
--------
collections.abc.Sequence:
ABCs for read-only and mutable :term:`sequences`.

Examples:
--------
.. code-block:: python

>>> from typing import TYPE_CHECKING
>>> import numpy as np
>>> import array_api_typing as xpt

>>> def get_dtype(seq: xpt.NestedSequence[float]) -> np.dtype[np.float64]:
... return np.asarray(seq).dtype

>>> a = get_dtype([1.0])
>>> b = get_dtype([[1.0]])
>>> c = get_dtype([[[1.0]]])
>>> d = get_dtype([[[[1.0]]]])

>>> if TYPE_CHECKING:
... reveal_locals()
... # note: Revealed local types are:
... # note: a: numpy.dtype[numpy.floating[numpy._typing._64Bit]]
... # note: b: numpy.dtype[numpy.floating[numpy._typing._64Bit]]
... # note: c: numpy.dtype[numpy.floating[numpy._typing._64Bit]]
... # note: d: numpy.dtype[numpy.floating[numpy._typing._64Bit]]

"""

def __len__(self, /) -> int:
"""Implement ``len(self)``."""
raise NotImplementedError

def __getitem__(self, index: int, /) -> _T_co | NestedSequence[_T_co]:
"""Implement ``self[x]``."""
raise NotImplementedError

def __contains__(self, x: object, /) -> bool:
"""Implement ``x in self``."""
raise NotImplementedError

def __iter__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]:
"""Implement ``iter(self)``."""
raise NotImplementedError

def __reversed__(self, /) -> Iterator[_T_co | NestedSequence[_T_co]]:
"""Implement ``reversed(self)``."""
raise NotImplementedError

def count(self, value: object, /) -> int:
"""Return the number of occurrences of `value`."""
raise NotImplementedError

def index(self, value: object, /) -> int:
"""Return the first index of `value`."""
raise NotImplementedError
Loading